├── res ├── 5-10.png ├── 5-4.png └── 5-8.png ├── experiments ├── 0 │ ├── 01.jpg │ └── 01.png ├── 1 │ ├── 17.jpg │ └── 17.png ├── 2 │ ├── 92.jpg │ └── 92.png ├── 3 │ ├── 3.jpg │ └── 3.png ├── 4 │ ├── 40c044db053a4c486421b4123ccdc542.jpg │ └── 40c044db053a4c486421b4123ccdc542.png ├── utils.py ├── cal_pca.py ├── cal_moments.py └── cal_histogram.py ├── model ├── __init__.py ├── mobilenet.py ├── flops.py ├── fast_scnn.py ├── lednet.py ├── enet.py ├── dfanet.py └── hlnet.py ├── .gitignore ├── README.md ├── benchmark.py ├── metric.py ├── test.py ├── data_loader.py ├── train.py ├── pipline_test.py └── LICENSE /res/5-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/res/5-10.png -------------------------------------------------------------------------------- /res/5-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/res/5-4.png -------------------------------------------------------------------------------- /res/5-8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/res/5-8.png -------------------------------------------------------------------------------- /experiments/0/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/0/01.jpg -------------------------------------------------------------------------------- /experiments/0/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/0/01.png -------------------------------------------------------------------------------- /experiments/1/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/1/17.jpg -------------------------------------------------------------------------------- /experiments/1/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/1/17.png -------------------------------------------------------------------------------- /experiments/2/92.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/2/92.jpg -------------------------------------------------------------------------------- /experiments/2/92.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/2/92.png -------------------------------------------------------------------------------- /experiments/3/3.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/3/3.jpg -------------------------------------------------------------------------------- /experiments/3/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/3/3.png -------------------------------------------------------------------------------- /experiments/4/40c044db053a4c486421b4123ccdc542.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/4/40c044db053a4c486421b4123ccdc542.jpg -------------------------------------------------------------------------------- /experiments/4/40c044db053a4c486421b4123ccdc542.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACKYLUO1991/Face-skin-hair-segmentaiton-and-skin-color-evaluation/HEAD/experiments/4/40c044db053a4c486421b4123ccdc542.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2020/3/27 20:11 4 | # @Author : JackyLUO 5 | # @E-mail : lingluo@stumail.neu.edu.cn 6 | # @Site : 7 | # @File : __init__.py 8 | # @Software: PyCharm -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from sklearn.metrics import confusion_matrix 4 | 5 | 6 | def plot_confusion_matrix(y_true, y_pred, classes, 7 | normalize=False, 8 | title=None, 9 | cmap=plt.cm.Blues): 10 | """ 11 | This function prints and plots the confusion matrix. 12 | Normalization can be applied by setting `normalize=True`. 13 | """ 14 | if not title: 15 | if normalize: 16 | title = 'Normalized confusion matrix' 17 | else: 18 | title = 'Confusion matrix, without normalization' 19 | 20 | # Compute confusion matrix 21 | cm = confusion_matrix(y_true, y_pred) 22 | if normalize: 23 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 24 | print("Normalized confusion matrix") 25 | else: 26 | print('Confusion matrix, without normalization') 27 | 28 | fig, ax = plt.subplots() 29 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 30 | ax.figure.colorbar(im, ax=ax) 31 | # We want to show all ticks... 32 | ax.set(xticks=np.arange(cm.shape[1]), 33 | yticks=np.arange(cm.shape[0]), 34 | # ... and label them with the respective list entries 35 | xticklabels=classes, yticklabels=classes, 36 | title=title, 37 | ylabel='True label', 38 | xlabel='Predicted label') 39 | 40 | # Rotate the tick labels and set their alignment. 41 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 42 | rotation_mode="anchor") 43 | 44 | # Loop over data dimensions and create text annotations. 45 | fmt = '.2f' if normalize else 'd' 46 | thresh = cm.max() / 2. 
47 | for i in range(cm.shape[0]): 48 | for j in range(cm.shape[1]): 49 | ax.text(j, i, format(cm[i, j], fmt), 50 | ha="center", va="center", 51 | color="white" if cm[i, j] > thresh else "black") 52 | fig.tight_layout() 53 | return ax 54 | 55 | 56 | class Histogram: 57 | '''Histogram base class''' 58 | 59 | def __init__(self, bins): 60 | self.bins = bins 61 | 62 | def describe(self, image, mask): 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## HLNet: A Unified Framework for Real-Time Segmentation and Facial Skin Tones Evaluation 2 | 3 | ## Abstract: 4 | Real-time semantic segmentation plays a crucial role in industrial applications, such as 5 | autonomous driving, the beauty industry, and so on. It is a challenging problem to balance the 6 | relationship between speed and segmentation performance. To address such a complex task, this 7 | paper introduces an efficient convolutional neural network (CNN) architecture named HLNet for 8 | devices with limited resources. Based on high-quality design modules, HLNet better integrates 9 | high-dimensional and low-dimensional information while obtaining sufficient receptive fields, which 10 | achieves remarkable results on three benchmark datasets. To our knowledge, the accuracy of skin 11 | tone classification is usually unsatisfactory due to the influence of external environmental factors such 12 | as illumination and background impurities. Therefore, we use HLNet to obtain accurate face regions, 13 | and further use color moment algorithm to extract its color features. Specifically, for a 224 × 224 14 | input, using our HLNet, we achieve 78.39% mean IoU on Figaro1k dataset at over 17 FPS in the case 15 | of the CPU environment. We further use the masked color moment for skin tone grade evaluation 16 | and approximate 80% classification accuracy demonstrate the feasibility of the proposed method. 17 | 18 | ## The latest open source work: 19 | https://github.com/JACKYLUO1991/FaceParsing. 20 | 21 | ## **Problem correction:** 22 | *It is worth noting that some training sets are mistaken for test sets in image file copying, which leads to high results in arXiv. The current version has been corrected.* 23 | 24 | ## Demos 25 |
26 | *(demo images: see res/5-4.png, res/5-8.png and res/5-10.png)*
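
## Pipeline sketch

For reference, the snippet below strings the repository's pieces together the way the abstract describes: segment the face, keep the skin region, describe it with color moments, and score the skin-tone grade. It is a minimal sketch rather than a drop-in script: the checkpoint, image and classifier paths are placeholders, the feature order simply mirrors `experiments/cal_moments.py`, and `load_model` may additionally need the `custom_objects` dictionary used in `test.py`.

```
import cv2 as cv
import numpy as np
from keras.models import load_model
from keras.applications.imagenet_utils import preprocess_input as pinput
from sklearn.externals import joblib

# Placeholder paths: point these at your own trained files.
seg_model = load_model('weights/hlnet/model.h5', compile=False)
skin_clf = joblib.load('skinColor.pkl')  # RandomForest saved by experiments/cal_moments.py

image = cv.imread('face.jpg')
h, w = image.shape[:2]
inp = pinput(cv.resize(image, (256, 256))[np.newaxis].astype(np.float32))

# For the 3-class models, channel 1 of the prediction is skin (0: background, 2: hair).
pred = np.argmax(seg_model.predict(inp)[0], axis=-1).astype(np.uint8)
skin_mask = cv.resize(pred, (w, h), interpolation=cv.INTER_NEAREST) == 1

# Color moments (mean, std, skewness) per YCrCb channel over the skin pixels,
# in the same [means, stds, skews] order as experiments/cal_moments.py.
ycrcb = cv.cvtColor(image, cv.COLOR_BGR2YCrCb)
means, stds, skews = [], [], []
for c in cv.split(ycrcb):
    pixels = c[skin_mask].astype(np.float64)
    means.append(pixels.mean())
    stds.append(pixels.std())
    skews.append(np.mean(np.abs(pixels - pixels.mean()) ** 3) ** (1. / 3))

print('skin tone grade:', skin_clf.predict([means + stds + skews])[0])
```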
27 | 28 | ## Please cited: 29 | ``` 30 | @article{feng2020hlnet, 31 | title={HLNet: A Unified Framework for Real-Time Segmentation and Facial Skin Tones Evaluation}, 32 | author={Feng, Xinglong and Gao, Xianwen and Luo, Ling}, 33 | journal={Symmetry}, 34 | volume={12}, 35 | number={11}, 36 | pages={1812}, 37 | year={2020}, 38 | publisher={Multidisciplinary Digital Publishing Institute} 39 | } 40 | ``` 41 | 42 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 4 | 5 | import time 6 | import numpy as np 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | from model.hlnet import HLNet 11 | from model.dfanet import DFANet 12 | from model.enet import ENet 13 | from model.lednet import LEDNet 14 | from model.mobilenet import MobileNet 15 | from model.fast_scnn import Fast_SCNN 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--image_size", '-i', 19 | help="image size", type=int, default=256) 20 | parser.add_argument("--batch_size", '-b', 21 | help="batch size", type=int, default=3) 22 | parser.add_argument("--model_name", help="model's name", 23 | choices=['hlnet', 'fastscnn', 'lednet', 'dfanet', 'enet', 'mobilenet'], 24 | type=str, default='hlnet') 25 | parser.add_argument("--nums", help="output num", 26 | type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | IMG_SIZE = args.image_size 30 | CLS_NUM = args.nums 31 | 32 | 33 | def get_model(name): 34 | if name == 'hlnet': 35 | model = HLNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM) 36 | elif name == 'fastscnn': 37 | model = Fast_SCNN(num_classes=CLS_NUM, input_shape=(IMG_SIZE, IMG_SIZE, 3)).model() 38 | elif name == 'lednet': 39 | model = LEDNet(groups=2, classes=CLS_NUM, input_shape=(IMG_SIZE, IMG_SIZE, 3)).model() 40 | elif name == 'dfanet': 41 | model = DFANet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM, size_factor=2) 42 | elif name == 'enet': 43 | model = ENet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM) 44 | elif name == 'mobilenet': 45 | model = MobileNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM) 46 | else: 47 | raise NameError("No corresponding model...") 48 | 49 | return model 50 | 51 | 52 | def main(): 53 | """Benchmark your model in your local pc.""" 54 | 55 | model = get_model(args.model_name) 56 | inputs = np.random.randn(args.batch_size, args.image_size, args.image_size, 3) 57 | 58 | time_per_batch = [] 59 | 60 | for i in tqdm(range(500)): 61 | start = time.time() 62 | model.predict(inputs, batch_size=args.batch_size) 63 | elapsed = time.time() - start 64 | time_per_batch.append(elapsed) 65 | 66 | time_per_batch = np.array(time_per_batch) 67 | 68 | # Remove the first item 69 | print(time_per_batch[1:].mean()) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import tensorflow as tf 3 | from keras.utils.generic_utils import get_custom_objects 4 | 5 | CLS_NUM = 2 # should be modified according to class number 6 | 7 | SMOOTH = K.epsilon() 8 | 9 | # https: // blog.csdn.net/majinlei121/article/details/78965435 10 | def mean_iou(y_true, y_pred, cls_num=CLS_NUM): 11 | result = 0 12 | nc = tf.cast(tf.shape(y_true)[-1], tf.float32) 
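    # Per-class IoU accumulated below: n_ii / (t_i - n_ii + sum_j n_ji), where
    # n_ji is the count of class-j pixels predicted as class i and t_i is the
    # total count of class-i pixels; the mean over the nc classes is returned.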
13 | for i in range(cls_num): 14 | # nii = number of pixels of classe i predicted to belong to class i 15 | nii = tf.reduce_sum(tf.round(tf.multiply( 16 | y_true[:, :, :, i], y_pred[:, :, :, i]))) 17 | ti = tf.reduce_sum(y_true[:, :, :, i]) # number of pixels of class i 18 | loc_sum = 0 19 | for j in range(cls_num): 20 | # number of pixels of classe j predicted to belong to class i 21 | nji = tf.reduce_sum(tf.round(tf.multiply( 22 | y_true[:, :, :, j], y_pred[:, :, :, i]))) 23 | loc_sum += nji 24 | result += nii / (ti - nii + loc_sum) 25 | return (1 / nc) * result 26 | 27 | 28 | def mean_accuracy(y_true, y_pred, cls_num=CLS_NUM): 29 | result = 0 30 | nc = tf.cast(tf.shape(y_true)[-1], tf.float32) 31 | for i in range(cls_num): 32 | nii = tf.reduce_sum(tf.round(tf.multiply( 33 | y_true[:, :, :, i], y_pred[:, :, :, i]))) 34 | ti = tf.reduce_sum(y_true[:, :, :, i]) 35 | if ti != 0: 36 | result += (nii / ti) 37 | return (1 / nc) * result 38 | 39 | 40 | def frequency_weighted_iou(y_true, y_pred, cls_num=CLS_NUM): 41 | result = 0 42 | for i in range(cls_num): 43 | nii = tf.reduce_sum(tf.round(tf.multiply( 44 | y_true[:, :, :, i], y_pred[:, :, :, i]))) 45 | ti = tf.reduce_sum(y_true[:, :, :, i]) 46 | loc_sum = 0 47 | for j in range(cls_num): 48 | nji = tf.reduce_sum(tf.round(tf.multiply( 49 | y_true[:, :, :, j], y_pred[:, :, :, i]))) 50 | loc_sum += nji 51 | result += (loc_sum * nii) / (ti - nii + loc_sum) 52 | sum_ti = tf.reduce_sum(y_true[:, :, :, :]) 53 | return (1 / sum_ti) * result 54 | 55 | 56 | def pixel_accuracy(y_true, y_pred): 57 | # nii = number of pixels of classe i predicted to belong to class i 58 | sum_nii = tf.reduce_sum(tf.round(tf.multiply( 59 | y_true[:, :, :, :], y_pred[:, :, :, :]))) 60 | # ti = number of pixels of class i 61 | sum_ti = tf.reduce_sum(y_true[:, :, :, :]) 62 | return sum_nii / sum_ti 63 | 64 | 65 | get_custom_objects().update({ 66 | 'pixel_accuracy': pixel_accuracy, 67 | 'frequency_weighted_iou': frequency_weighted_iou, 68 | 'mean_accuracy': mean_accuracy, 69 | 'mean_iou': mean_iou 70 | }) 71 | -------------------------------------------------------------------------------- /experiments/cal_pca.py: -------------------------------------------------------------------------------- 1 | # color-auto-correlogram 2 | # https://blog.csdn.net/u013066730/article/details/53609859 3 | from __future__ import print_function, division 4 | 5 | import numpy as np 6 | import cv2 as cv 7 | import sys 8 | import os 9 | import tqdm 10 | import time 11 | import csv 12 | import pandas as pd 13 | from sklearn import svm 14 | from sklearn.cluster import KMeans 15 | from sklearn.decomposition import PCA 16 | from sklearn.model_selection import train_test_split 17 | from sklearn.metrics import classification_report 18 | from sklearn.ensemble import RandomForestClassifier 19 | import matplotlib.pyplot as plt 20 | 21 | import logging 22 | logging.basicConfig(level=logging.INFO, 23 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 24 | 25 | from utils import * 26 | from imutils import paths 27 | 28 | 29 | class RGBHistogram(Histogram): 30 | '''RGB Histogram''' 31 | 32 | def __init__(self, bins): 33 | super().__init__(bins) 34 | 35 | def describe(self, image, mask): 36 | image = cv.cvtColor(image, cv.COLOR_BGR2RGB) 37 | hist = cv.calcHist([image], [0, 1, 2], mask, 38 | self.bins, [0, 256, 0, 256, 0, 256]) 39 | hist = hist / np.sum(hist) 40 | 41 | # 512 dimensions 42 | return hist.flatten() 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | logger = logging.getLogger(__name__) 
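    # Pipeline: 512-dimensional masked RGB histograms -> PCA down to K_ClUSTER
    # components -> RandomForest over the 5 skin-tone grades. Expects folders
    # "0".."4" in the working directory, each holding paired <name>.jpg images
    # and <name>.png masks.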
48 | 49 | CLASSES = 5 50 | K_ClUSTER = 15 51 | 52 | images_list = [] 53 | masks_list = [] 54 | features_list = [] 55 | classes_list = [] 56 | 57 | s1 = time.time() 58 | for i in range(0, CLASSES): 59 | for imgpath in sorted(paths.list_images(str(i))): 60 | if os.path.splitext(imgpath)[-1] == '.jpg': 61 | images_list.append(imgpath) 62 | classes_list.append(int(i)) 63 | elif os.path.splitext(imgpath)[-1] == '.png': 64 | masks_list.append(imgpath) 65 | else: 66 | raise ValueError("type error...") 67 | s2 = time.time() 68 | logger.info(f"Time use: {s2 - s1} s") 69 | 70 | hist = RGBHistogram([8, 8, 8]) 71 | 72 | for image_path, mask_path in tqdm.tqdm(zip(images_list, masks_list)): 73 | image = cv.imread(image_path) 74 | mask = cv.imread(mask_path, 0) 75 | features = hist.describe(image, mask) 76 | features_list.append(features) 77 | 78 | logger.info(f"Time use: {time.time() - s2} s") 79 | logger.info("Data process ready...") 80 | 81 | assert len(features_list) == len(classes_list) 82 | 83 | # PCA Dimensionality Reduction 84 | pca = PCA(n_components=K_ClUSTER, random_state=2019) 85 | # pca.fit(features_list) 86 | # logger.info(pca.explained_variance_ratio_) 87 | newX = pca.fit_transform(features_list) 88 | 89 | X_train, X_test, y_train, y_test = train_test_split( 90 | newX, classes_list, test_size=0.2, random_state=2019) 91 | 92 | clf = RandomForestClassifier(n_estimators=180, random_state=2019) 93 | y_pred = clf.fit(X_train, y_train).predict(X_test) 94 | 95 | classify_report = classification_report(y_test, y_pred) 96 | logger.info('\n' + classify_report) 97 | 98 | np.set_printoptions(precision=2) 99 | plot_confusion_matrix(y_test, y_pred, classes=['0', '1', 100 | '2', '3', '4'], title='Confusion matrix') 101 | plt.show() 102 | -------------------------------------------------------------------------------- /model/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2020/3/27 17:43 4 | # @Author : JackyLUO 5 | # @E-mail : lingluo@stumail.neu.edu.cn 6 | # @Site : 7 | # @File : mobilenet.py 8 | # @Software: PyCharm 9 | 10 | from keras.models import * 11 | from keras.layers import * 12 | 13 | 14 | def conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): 15 | filters = int(filters * alpha) 16 | x = ZeroPadding2D(padding=(1, 1), name='conv1_pad')(inputs) 17 | x = Conv2D(filters, kernel, padding='valid', use_bias=False, strides=strides, name='conv1')(x) 18 | x = BatchNormalization(axis=3, name='conv1_bn')(x) 19 | return ReLU(6, name='conv1_relu')(x) 20 | 21 | 22 | def depthwise_conv_block(inputs, pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1), block_id=1): 23 | pointwise_conv_filters = int(pointwise_conv_filters * alpha) 24 | x = ZeroPadding2D((1, 1), name='conv_pad_%d' % block_id)(inputs) 25 | x = DepthwiseConv2D((3, 3), padding='valid', depth_multiplier=depth_multiplier, strides=strides, use_bias=False, 26 | name='conv_dw_%d' % block_id)(x) 27 | x = BatchNormalization(axis=3, name='conv_dw_%d_bn' % block_id)(x) 28 | x = ReLU(6, name='conv_dw_%d_relu' % block_id)(x) 29 | x = Conv2D(pointwise_conv_filters, (1, 1), padding='same', use_bias=False, strides=(1, 1), 30 | name='conv_pw_%d' % block_id)(x) 31 | x = BatchNormalization(axis=3, name='conv_pw_%d_bn' % block_id)(x) 32 | return ReLU(6, name='conv_pw_%d_relu' % block_id)(x) 33 | 34 | 35 | def MobileNet(input_shape, cls_num, alpha=0.5): 36 | inputs = Input(input_shape) 37 | x = conv_block(inputs, 16, alpha, 
strides=(2, 2)) 38 | x = depthwise_conv_block(x, 16, alpha, 6, block_id=1) 39 | f1 = x 40 | x = depthwise_conv_block(x, 32, alpha, 6, strides=(2, 2), block_id=2) 41 | x = depthwise_conv_block(x, 32, alpha, 6, block_id=3) 42 | f2 = x 43 | x = depthwise_conv_block(x, 64, alpha, 6, strides=(2, 2), block_id=4) 44 | x = depthwise_conv_block(x, 64, alpha, 6, block_id=5) 45 | f3 = x 46 | x = depthwise_conv_block(x, 128, alpha, 6, strides=(2, 2), block_id=6) 47 | x = depthwise_conv_block(x, 128, alpha, 6, block_id=7) 48 | x = depthwise_conv_block(x, 128, alpha, 6, block_id=8) 49 | x = depthwise_conv_block(x, 128, alpha, 6, block_id=9) 50 | x = depthwise_conv_block(x, 128, alpha, 6, block_id=10) 51 | x = depthwise_conv_block(x, 128, alpha, 6, block_id=11) 52 | 53 | o = x 54 | o = Conv2D(128, (3, 3), activation='relu', padding='same')(o) 55 | o = BatchNormalization()(o) 56 | # decode 57 | o = UpSampling2D((2, 2))(o) 58 | o = concatenate([o, f3], axis=-1) 59 | o = Conv2D(64, (3, 3), padding='same')(o) 60 | o = BatchNormalization()(o) 61 | 62 | o = UpSampling2D((2, 2))(o) 63 | o = concatenate([o, f2], axis=-1) 64 | o = Conv2D(32, (3, 3), padding='same')(o) 65 | o = BatchNormalization()(o) 66 | 67 | o = UpSampling2D((2, 2))(o) 68 | o = concatenate([o, f1], axis=-1) 69 | 70 | o = Conv2D(16, (3, 3), padding='same')(o) 71 | o = BatchNormalization()(o) 72 | 73 | o = Conv2D(cls_num, (3, 3), padding='same')(o) 74 | o = UpSampling2D((2, 2))(o) 75 | o = Activation('softmax')(o) 76 | 77 | return Model(inputs, o) 78 | 79 | 80 | if __name__ == '__main__': 81 | from flops import get_flops 82 | 83 | model = MobileNet(input_shape=(256, 256, 3), cls_num=3) 84 | model.summary() 85 | 86 | get_flops(model, True) 87 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | from keras.models import load_model 4 | from keras.applications.imagenet_utils import preprocess_input as pinput 5 | 6 | import cv2 as cv 7 | import numpy as np 8 | import os 9 | import argparse 10 | from metric import * 11 | import glob 12 | from model.fast_scnn import resize_image 13 | from segmentation_models.losses import * 14 | 15 | import warnings 16 | 17 | warnings.filterwarnings('ignore') 18 | 19 | import tensorflow as tf 20 | 21 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 22 | 23 | IMG_SIZE = None 24 | 25 | 26 | def vis_parsing_maps(im, parsing_anno, data_name): 27 | part_colors = [[255, 255, 255], [0, 255, 0], [255, 0, 0]] 28 | 29 | if data_name == 'figaro1k': 30 | part_colors = [[255, 255, 255], [255, 0, 0]] 31 | 32 | im = np.array(im) 33 | vis_im = im.copy().astype(np.uint8) 34 | vis_parsing_anno_color = np.zeros( 35 | (parsing_anno.shape[0], parsing_anno.shape[1], 3)) 36 | 37 | for pi in range(len(part_colors)): 38 | index = np.where(parsing_anno == pi) 39 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 40 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 41 | 42 | # Guided filter 43 | # vis_parsing_anno_color = cv.ximgproc.guidedFilter( 44 | # guide=vis_im, src=vis_parsing_anno_color, radius=4, eps=50, dDepth=-1) 45 | vis_im = cv.addWeighted(vis_im, 0.7, vis_parsing_anno_color, 0.3, 0) 46 | 47 | return vis_im 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--image_size", 54 | help="size of image", type=int, default=256) 55 | 
parser.add_argument("--model_path", 56 | help="the path of model", type=str, 57 | default='./weights/celebhair/exper/fastscnn/model.h5') 58 | args = parser.parse_args() 59 | 60 | IMG_SIZE = args.image_size 61 | MODEL_PATH = args.model_path 62 | 63 | if MODEL_PATH.split('/')[-2] == 'lednet': 64 | from model.lednet import LEDNet 65 | 66 | model = LEDNet(2, 3, (256, 256, 3)).model() 67 | model.load_weights(MODEL_PATH) 68 | 69 | else: 70 | model = load_model(MODEL_PATH, custom_objects={'mean_accuracy': mean_accuracy, 71 | 'mean_iou': mean_iou, 72 | 'frequency_weighted_iou': frequency_weighted_iou, 73 | 'pixel_accuracy': pixel_accuracy, 74 | 'categorical_crossentropy_plus_dice_loss': cce_dice_loss, 75 | 'resize_image': resize_image}) 76 | 77 | data_name = MODEL_PATH.split('/')[2] 78 | 79 | for img_path in glob.glob(os.path.join("./demo", data_name, "*.jpg")): 80 | img_basename = os.path.basename(img_path) 81 | name = os.path.splitext(img_basename)[0] 82 | 83 | org_img = cv.imread(img_path) 84 | try: 85 | h, w, _ = org_img.shape 86 | except: 87 | raise IOError("Reading image error...") 88 | 89 | img_resize = cv.resize(org_img, (IMG_SIZE, IMG_SIZE)) 90 | img = img_resize[np.newaxis, :] 91 | # pre-processing 92 | img = pinput(img) 93 | 94 | result_map = np.argmax(model.predict(img)[0], axis=-1) 95 | out = vis_parsing_maps(img_resize, result_map, data_name) 96 | out = cv.resize(out, (w, h), interpolation=cv.INTER_NEAREST) 97 | 98 | cv.imwrite(os.path.join("./demo", data_name, "{}.png").format(name), out) 99 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import random 5 | import glob 6 | 7 | from keras.utils import Sequence 8 | from keras.applications.imagenet_utils import preprocess_input as pinput 9 | 10 | 11 | class HairGenerator(Sequence): 12 | 13 | def __init__(self, 14 | transformer, 15 | root_dir, 16 | mode='Training', 17 | nb_classes=3, 18 | batch_size=4, 19 | backbone=None, 20 | shuffle=False): 21 | 22 | # backbone fit for segmentation_models,have been deleted now... 23 | assert mode in ['Training', 'Testing'], "Data set selection error..." 
24 | 25 | self.image_path_list = sorted( 26 | glob.glob(os.path.join(root_dir, 'Original', mode, '*'))) 27 | self.mask_path_list = sorted( 28 | glob.glob(os.path.join(root_dir, 'GT', mode, '*'))) 29 | self.transformer = transformer 30 | self.batch_size = batch_size 31 | self.nb_classes = nb_classes 32 | self.shuffle = shuffle 33 | self.mode = mode 34 | self.backbone = backbone 35 | 36 | def __getitem__(self, idx): 37 | images, masks = [], [] 38 | 39 | for (image_path, mask_path) in zip(self.image_path_list[idx * self.batch_size: (idx + 1) * self.batch_size], 40 | self.mask_path_list[idx * self.batch_size: (idx + 1) * self.batch_size]): 41 | image = cv2.imread(image_path, 1) 42 | mask = cv2.imread(mask_path, 0) 43 | 44 | image = self._padding(image) 45 | mask = self._padding(mask) 46 | 47 | # augumentation 48 | augmentation = self.transformer(image=image, mask=mask) 49 | image = augmentation['image'] 50 | mask = self._get_result_map(augmentation['mask']) 51 | 52 | images.append(image) 53 | masks.append(mask) 54 | 55 | images = np.array(images) 56 | masks = np.array(masks) 57 | images = pinput(images) 58 | 59 | return images, masks 60 | 61 | def __len__(self): 62 | """Steps required per epoch""" 63 | return len(self.image_path_list) // self.batch_size 64 | 65 | def _padding(self, image): 66 | shape = image.shape 67 | h, w = shape[:2] 68 | width = np.max([h, w]) 69 | padd_h = (width - h) // 2 70 | padd_w = (width - w) // 2 71 | if len(shape) == 3: 72 | padd_tuple = ((padd_h, width - h - padd_h), 73 | (padd_w, width - w - padd_w), (0, 0)) 74 | else: 75 | padd_tuple = ((padd_h, width - h - padd_h), (padd_w, width - w - padd_w)) 76 | image = np.pad(image, padd_tuple, 'constant') 77 | return image 78 | 79 | def on_epoch_end(self): 80 | """Shuffle image order""" 81 | if self.shuffle: 82 | c = list(zip(self.image_path_list, self.mask_path_list)) 83 | random.shuffle(c) 84 | self.image_path_list, self.mask_path_list = zip(*c) 85 | 86 | def _get_result_map(self, mask): 87 | """Processing mask data""" 88 | 89 | # mask.shape[0]: row, mask.shape[1]: column 90 | result_map = np.zeros((mask.shape[1], mask.shape[0], self.nb_classes)) 91 | # 0 (background pixel), 128 (face area pixel) or 255 (hair area pixel). 
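        # result_map is one-hot with shape (H, W, nb_classes); for 3 classes the
        # channel order is 0: background, 1: skin, 2: hair.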
92 | skin = (mask == 128) 93 | hair = (mask == 255) 94 | 95 | if self.nb_classes == 2: 96 | # hair = (mask > 128) 97 | background = np.logical_not(hair) 98 | result_map[:, :, 0] = np.where(background, 1, 0) 99 | result_map[:, :, 1] = np.where(hair, 1, 0) 100 | elif self.nb_classes == 3: 101 | background = np.logical_not(hair + skin) 102 | result_map[:, :, 0] = np.where(background, 1, 0) 103 | result_map[:, :, 1] = np.where(skin, 1, 0) 104 | result_map[:, :, 2] = np.where(hair, 1, 0) 105 | else: 106 | raise Exception("error...") 107 | 108 | return result_map 109 | -------------------------------------------------------------------------------- /experiments/cal_moments.py: -------------------------------------------------------------------------------- 1 | # https://www.cnblogs.com/klchang/p/6512310.html 2 | from __future__ import print_function, division 3 | 4 | import cv2 as cv 5 | import numpy as np 6 | import tqdm 7 | import time 8 | import os 9 | import sys 10 | import logging 11 | logging.basicConfig(level=logging.INFO, 12 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 13 | 14 | import matplotlib.pyplot as plt 15 | from sklearn.externals import joblib 16 | from sklearn.metrics import classification_report 17 | from sklearn.model_selection import train_test_split 18 | from sklearn.ensemble import RandomForestClassifier 19 | from imblearn.over_sampling import SMOTE 20 | 21 | from utils import * 22 | from imutils import paths 23 | 24 | 25 | def color_moments(image, mask, color_space): 26 | """ 27 | function: Color Moment Features 28 | image: raw image 29 | mask: image mask 30 | color_space: 'rgb' or 'lab' or 'ycrcb' or 'hsv' 31 | """ 32 | assert image.shape[:2] == mask.shape 33 | assert color_space.lower() in ['lab', 'rgb', 'ycrcb', 'hsv'] 34 | 35 | if color_space.lower() == 'rgb': 36 | image = cv.cvtColor(image, cv.COLOR_BGR2RGB) 37 | elif color_space.lower() == 'hsv': 38 | image = cv.cvtColor(image, cv.COLOR_BGR2HSV) 39 | elif color_space.lower() == 'lab': 40 | image = cv.cvtColor(image, cv.COLOR_BGR2LAB) 41 | elif color_space.lower() == 'ycrcb': 42 | image = cv.cvtColor(image, cv.COLOR_BGR2YCrCb) 43 | else: 44 | raise ValueError("Color space error...") 45 | 46 | # Split image channels info 47 | c1, c2, c3 = cv.split(image) 48 | color_feature = [] 49 | 50 | # Only process mask != 0 channel region 51 | c1 = c1[np.where(mask != 0)] 52 | c2 = c2[np.where(mask != 0)] 53 | c3 = c3[np.where(mask != 0)] 54 | 55 | # Extract mean 56 | mean_1 = np.mean(c1) 57 | mean_2 = np.mean(c2) 58 | mean_3 = np.mean(c3) 59 | 60 | # Extract variance 61 | variance_1 = np.std(c1) 62 | variance_2 = np.std(c2) 63 | variance_3 = np.std(c3) 64 | 65 | # Extract skewness 66 | skewness_1 = np.mean(np.abs(c1 - mean_1) ** 3) ** (1. / 3) 67 | skewness_2 = np.mean(np.abs(c1 - mean_2) ** 3) ** (1. / 3) 68 | skewness_3 = np.mean(np.abs(c1 - mean_3) ** 3) ** (1. 
/ 3) 69 | 70 | color_feature.extend( 71 | [mean_1, mean_2, mean_3, variance_1, variance_2, 72 | variance_3, skewness_1, skewness_2, skewness_3]) 73 | 74 | return color_feature 75 | 76 | 77 | if __name__ == "__main__": 78 | 79 | logger = logging.getLogger(__name__) 80 | 81 | CLASSES = 5 82 | 83 | images_list = [] 84 | masks_list = [] 85 | features_list = [] 86 | classes_list = [] 87 | 88 | s1 = time.time() 89 | for i in range(0, CLASSES): 90 | for imgpath in sorted(paths.list_images(str(i))): 91 | if os.path.splitext(imgpath)[-1] == '.jpg': 92 | images_list.append(imgpath) 93 | classes_list.append(int(i)) 94 | elif os.path.splitext(imgpath)[-1] == '.png': 95 | masks_list.append(imgpath) 96 | else: 97 | raise ValueError("type error...") 98 | s2 = time.time() 99 | logger.info(f"Time use: {s2 - s1} s") 100 | 101 | for image_path, mask_path in tqdm.tqdm(zip(images_list, masks_list)): 102 | image = cv.imread(image_path) 103 | mask = cv.imread(mask_path, 0) 104 | features = color_moments(image, mask, color_space='ycrcb') 105 | features_list.append(features) 106 | 107 | logger.info(f"Time use: {time.time() - s2} s") 108 | logger.info("Data process ready...") 109 | 110 | # Resampling 111 | sm = SMOTE(sampling_strategy='all', random_state=2019) 112 | features_list, classes_list = sm.fit_resample(features_list, classes_list) 113 | 114 | X_train, X_test, y_train, y_test = train_test_split( 115 | features_list, classes_list, test_size=0.2, random_state=2019) 116 | 117 | clf = RandomForestClassifier(n_estimators=180, random_state=2019) 118 | y_pred = clf.fit(X_train, y_train).predict(X_test) 119 | joblib.dump(clf, 'skinColor.pkl') 120 | 121 | classify_report = classification_report(y_test, y_pred) 122 | logger.info('\n' + classify_report) 123 | 124 | np.set_printoptions(precision=2) 125 | plot_confusion_matrix(y_test, y_pred, classes=['0', '1', 126 | '2', '3', '4'], title='Confusion matrix') 127 | plt.show() 128 | -------------------------------------------------------------------------------- /model/flops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2020/3/27 17:49 4 | # @Author : JackyLUO 5 | # @E-mail : lingluo@stumail.neu.edu.cn 6 | # @Site : 7 | # @File : flops.py 8 | # @Software: PyCharm 9 | 10 | # https://github.com/ckyrkou/Keras_FLOP_Estimator 11 | 12 | import keras.backend as K 13 | 14 | 15 | def get_flops(model, table=False): 16 | if table: 17 | print('%25s | %16s | %16s | %16s | %16s | %6s | %6s' % ( 18 | 'Layer Name', 'Input Shape', 'Output Shape', 'Kernel Size', 'Filters', 'Strides', 'FLOPS')) 19 | print('-' * 170) 20 | 21 | t_flops = 0 22 | t_macc = 0 23 | 24 | for l in model.layers: 25 | 26 | o_shape, i_shape, strides, ks, filters = ['', '', ''], ['', '', ''], [1, 1], [0, 0], [0, 0] 27 | flops = 0 28 | macc = 0 29 | name = l.name 30 | 31 | factor = 1e9 32 | 33 | if 'InputLayer' in str(l): 34 | i_shape = l.input.get_shape()[1:4].as_list() 35 | o_shape = i_shape 36 | 37 | if 'Reshape' in str(l): 38 | i_shape = l.input.get_shape()[1:4].as_list() 39 | o_shape = l.output.get_shape()[1:4].as_list() 40 | 41 | if 'Add' in str(l) or 'Maximum' in str(l) or 'Concatenate' in str(l): 42 | i_shape = l.input[0].get_shape()[1:4].as_list() + [len(l.input)] 43 | o_shape = l.output.get_shape()[1:4].as_list() 44 | flops = (len(l.input) - 1) * i_shape[0] * i_shape[1] * i_shape[2] 45 | 46 | if 'Average' in str(l) and 'pool' not in str(l): 47 | i_shape = l.input[0].get_shape()[1:4].as_list() + 
[len(l.input)] 48 | o_shape = l.output.get_shape()[1:4].as_list() 49 | flops = len(l.input) * i_shape[0] * i_shape[1] * i_shape[2] 50 | 51 | if 'BatchNormalization' in str(l): 52 | i_shape = l.input.get_shape()[1:4].as_list() 53 | o_shape = l.output.get_shape()[1:4].as_list() 54 | 55 | bflops = 1 56 | for i in range(len(i_shape)): 57 | bflops *= i_shape[i] 58 | flops /= factor 59 | 60 | if 'Activation' in str(l) or 'activation' in str(l): 61 | i_shape = l.input.get_shape()[1:4].as_list() 62 | o_shape = l.output.get_shape()[1:4].as_list() 63 | bflops = 1 64 | for i in range(len(i_shape)): 65 | bflops *= i_shape[i] 66 | flops /= factor 67 | 68 | if 'pool' in str(l) and ('Global' not in str(l)): 69 | i_shape = l.input.get_shape()[1:4].as_list() 70 | strides = l.strides 71 | ks = l.pool_size 72 | flops = ((i_shape[0] / strides[0]) * (i_shape[1] / strides[1]) * (ks[0] * ks[1] * i_shape[2])) 73 | 74 | if 'Flatten' in str(l): 75 | i_shape = l.input.shape[1:4].as_list() 76 | flops = 1 77 | out_vec = 1 78 | for i in range(len(i_shape)): 79 | flops *= i_shape[i] 80 | out_vec *= i_shape[i] 81 | o_shape = flops 82 | flops = 0 83 | 84 | if 'Dense' in str(l): 85 | print(l.input) 86 | i_shape = l.input.shape[1:4].as_list()[0] 87 | if i_shape is None: 88 | i_shape = out_vec 89 | 90 | o_shape = l.output.shape[1:4].as_list() 91 | flops = 2 * (o_shape[0] * i_shape) 92 | macc = flops / 2 93 | 94 | if 'Padding' in str(l): 95 | flops = 0 96 | 97 | if 'Global' in str(l): 98 | i_shape = l.input.get_shape()[1:4].as_list() 99 | flops = ((i_shape[0]) * (i_shape[1]) * (i_shape[2])) 100 | o_shape = [l.output.get_shape()[1:4].as_list(), 1, 1] 101 | out_vec = o_shape 102 | 103 | if 'Conv2D' in str(l) and 'DepthwiseConv2D' not in str(l) and 'SeparableConv2D' not in str(l): 104 | strides = l.strides 105 | ks = l.kernel_size 106 | filters = l.filters 107 | # if 'Conv2DTranspose' in str(l): 108 | # i_shape = list(K.int_shape(l.input)[1:4]) 109 | # o_shape = list(K.int_shape(l.output)[1:4]) 110 | # else: 111 | i_shape = l.input.get_shape()[1:4].as_list() 112 | o_shape = l.output.get_shape()[1:4].as_list() 113 | 114 | if filters is None: 115 | filters = i_shape[2] 116 | 117 | flops = 2 * ((filters * ks[0] * ks[1] * i_shape[2]) * ( 118 | (i_shape[0] / strides[0]) * (i_shape[1] / strides[1]))) 119 | macc = flops / 2 120 | 121 | if 'Conv2D' in str(l) and 'DepthwiseConv2D' in str(l) and 'SeparableConv2D' not in str(l): 122 | strides = l.strides 123 | ks = l.kernel_size 124 | filters = l.filters 125 | i_shape = l.input.get_shape()[1:4].as_list() 126 | o_shape = l.output.get_shape()[1:4].as_list() 127 | 128 | if filters is None: 129 | filters = i_shape[2] 130 | 131 | flops = 2 * ((ks[0] * ks[1] * i_shape[2]) * ((i_shape[0] / strides[0]) * ( 132 | i_shape[1] / strides[1]))) / factor 133 | macc = flops / 2 134 | 135 | t_macc += macc 136 | 137 | t_flops += flops 138 | 139 | if table: 140 | print('%25s | %16s | %16s | %16s | %16s | %6s | %5.4f' % ( 141 | name, str(i_shape), str(o_shape), str(ks), str(filters), str(strides), flops)) 142 | t_flops = t_flops / factor 143 | 144 | print('Total FLOPS (x 10^-9): %10.8f G' % (t_flops)) 145 | print('Total MACCs: %10.8f\n' % (t_macc)) 146 | 147 | return 148 | -------------------------------------------------------------------------------- /model/fast_scnn.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import tensorflow as tf 3 | 4 | 5 | def resize_image(image): 6 | return tf.image.resize_images(image, (256, 256)) 7 | 8 | 9 | class 
Fast_SCNN: 10 | 11 | def __init__(self, num_classes=3, input_shape=(256, 256, 3)): 12 | self.classes = num_classes 13 | self.input_shape = input_shape 14 | self.height = input_shape[0] 15 | self.width = input_shape[1] 16 | 17 | def conv_block(self, inputs, conv_type, kernel, kernel_size, strides, padding='same', relu=True): 18 | if conv_type == 'ds': 19 | x = keras.layers.SeparableConv2D(kernel, kernel_size, padding=padding, strides=strides)(inputs) 20 | else: 21 | x = keras.layers.Conv2D(kernel, kernel_size, padding=padding, strides=strides)(inputs) 22 | 23 | x = keras.layers.BatchNormalization()(x) 24 | 25 | if relu: 26 | x = keras.layers.ReLU()(x) 27 | 28 | return x 29 | 30 | def learning_to_downsample(self): 31 | # Input Layer 32 | self.input_layer = keras.layers.Input(shape=self.input_shape, name='input_layer') 33 | 34 | self.lds_layer = self.conv_block(self.input_layer, 'conv', 32, (3, 3), strides=(2, 2)) 35 | self.lds_layer = self.conv_block(self.lds_layer, 'ds', 48, (3, 3), strides=(2, 2)) 36 | self.lds_layer = self.conv_block(self.lds_layer, 'ds', 64, (3, 3), strides=(2, 2)) 37 | 38 | def global_feature_extractor(self): 39 | self.gfe_layer = self.bottleneck_block(self.lds_layer, 64, (3, 3), t=6, strides=2, n=3) 40 | self.gfe_layer = self.bottleneck_block(self.gfe_layer, 96, (3, 3), t=6, strides=2, n=3) 41 | self.gfe_layer = self.bottleneck_block(self.gfe_layer, 128, (3, 3), t=6, strides=1, n=3) 42 | # self.gfe_layer = self.pyramid_pooling_block(self.gfe_layer, [2, 4, 6, 8]) 43 | self.gfe_layer = self.pyramid_pooling_block(self.gfe_layer, [1, 2, 4]) 44 | 45 | def _res_bottleneck(self, inputs, filters, kernel, t, s, r=False): 46 | tchannel = keras.backend.int_shape(inputs)[-1] * t 47 | 48 | x = self.conv_block(inputs, 'conv', tchannel, (1, 1), strides=(1, 1)) 49 | 50 | x = keras.layers.DepthwiseConv2D(kernel, strides=(s, s), depth_multiplier=1, padding='same')(x) 51 | x = keras.layers.BatchNormalization()(x) 52 | x = keras.layers.ReLU()(x) 53 | 54 | x = self.conv_block(x, 'conv', filters, (1, 1), strides=(1, 1), padding='same', relu=False) 55 | 56 | if r: 57 | x = keras.layers.add([x, inputs]) 58 | return x 59 | 60 | def bottleneck_block(self, inputs, filters, kernel, t, strides, n): 61 | x = self._res_bottleneck(inputs, filters, kernel, t, strides) 62 | 63 | for i in range(1, n): 64 | x = self._res_bottleneck(x, filters, kernel, t, 1, True) 65 | 66 | return x 67 | 68 | def pyramid_pooling_block(self, input_tensor, bin_sizes): 69 | concat_list = [input_tensor] 70 | w = self.width // 32 71 | h = self.height // 32 72 | 73 | for bin_size in bin_sizes: 74 | x = keras.layers.AveragePooling2D(pool_size=(bin_size, bin_size), 75 | strides=(bin_size, bin_size))(input_tensor) 76 | x = keras.layers.Conv2D(128, (3, 3), strides=2, padding='same')(x) 77 | x = keras.layers.BatchNormalization()(x) 78 | x = keras.layers.ReLU()(x) 79 | x = keras.layers.UpSampling2D(size=(bin_size * 2, bin_size * 2))(x) 80 | concat_list.append(x) 81 | 82 | return keras.layers.concatenate(concat_list) 83 | 84 | def feature_fusion(self): 85 | ff_layer1 = self.conv_block(self.lds_layer, 'conv', 128, (1, 1), padding='same', strides=(1, 1), relu=False) 86 | 87 | ff_layer2 = keras.layers.UpSampling2D((4, 4))(self.gfe_layer) 88 | ff_layer2 = keras.layers.DepthwiseConv2D((3, 3), strides=(1, 1), depth_multiplier=1, padding='same')(ff_layer2) 89 | ff_layer2 = keras.layers.BatchNormalization()(ff_layer2) 90 | ff_layer2 = keras.layers.ReLU()(ff_layer2) 91 | ff_layer2 = keras.layers.Conv2D(128, (1, 1), strides=1, padding='same', 
activation=None)(ff_layer2) 92 | 93 | self.ff_final = keras.layers.add([ff_layer1, ff_layer2]) 94 | self.ff_final = keras.layers.BatchNormalization()(self.ff_final) 95 | self.ff_final = keras.layers.ReLU()(self.ff_final) 96 | 97 | def classifier(self): 98 | self.classifier = keras.layers.SeparableConv2D(128, (3, 3), padding='same', strides=(1, 1), 99 | name='DSConv1_classifier')(self.ff_final) 100 | self.classifier = keras.layers.BatchNormalization()(self.classifier) 101 | self.classifier = keras.layers.ReLU()(self.classifier) 102 | 103 | self.classifier = keras.layers.SeparableConv2D(128, (3, 3), padding='same', strides=(1, 1), 104 | name='DSConv2_classifier')(self.classifier) 105 | self.classifier = keras.layers.BatchNormalization()(self.classifier) 106 | self.classifier = keras.layers.ReLU()(self.classifier) 107 | 108 | self.classifier = self.conv_block(self.classifier, 'conv', self.classes, (1, 1), strides=(1, 1), padding='same', 109 | relu=False) 110 | self.classifier = keras.layers.Lambda(lambda image: resize_image(image), name='Resize')(self.classifier) 111 | self.classifier = keras.layers.Dropout(0.3)(self.classifier) 112 | 113 | def activation(self, activation='softmax'): 114 | x = keras.layers.Activation(activation, 115 | name=activation)(self.classifier) 116 | return x 117 | 118 | def model(self, activation='softmax'): 119 | self.learning_to_downsample() 120 | self.global_feature_extractor() 121 | self.feature_fusion() 122 | self.classifier() 123 | self.output_layer = self.activation(activation) 124 | 125 | model = keras.Model(inputs=self.input_layer, 126 | outputs=self.output_layer, 127 | name='Fast_SCNN') 128 | return model 129 | 130 | 131 | if __name__ == '__main__': 132 | from flops import get_flops 133 | 134 | model = Fast_SCNN(num_classes=3, input_shape=(256, 256, 3)).model() 135 | model.summary() 136 | 137 | get_flops(model) 138 | -------------------------------------------------------------------------------- /experiments/cal_histogram.py: -------------------------------------------------------------------------------- 1 | 2 | # 参考资料: 3 | # https://www.cnblogs.com/maybe2030/p/4585705.html 4 | # https://blog.csdn.net/zhu_hongji/article/details/80443585 5 | # https://blog.csdn.net/wsp_1138886114/article/details/80660014 6 | # https://blog.csdn.net/gfjjggg/article/details/87919658 7 | # https://baike.baidu.com/item/%E9%A2%9C%E8%89%B2%E7%9F%A9/19426187?fr=aladdin 8 | # https://blog.csdn.net/langyuewu/article/details/4144139 9 | from __future__ import print_function, division 10 | 11 | from sklearn import svm 12 | from imblearn.over_sampling import SMOTE 13 | from sklearn.metrics import classification_report, confusion_matrix 14 | from sklearn.externals import joblib 15 | from sklearn.neural_network import MLPClassifier 16 | from sklearn.ensemble import RandomForestClassifier 17 | from sklearn.model_selection import KFold, cross_val_score, train_test_split 18 | import cv2 as cv 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | import os 22 | import sys 23 | import time 24 | from imutils import paths 25 | import logging 26 | logging.basicConfig(level=logging.INFO, 27 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 28 | 29 | from utils import * 30 | 31 | 32 | class RGBHistogram(Histogram): 33 | '''RGB Histogram''' 34 | 35 | def __init__(self, bins): 36 | super().__init__(bins) 37 | 38 | def describe(self, image, mask): 39 | hist_b = cv.calcHist([image], [0], mask, self.bins, 40 | [0, 256]) 41 | hist_g = cv.calcHist([image], [1], mask, self.bins, 
42 | [0, 256]) 43 | hist_r = cv.calcHist([image], [2], mask, self.bins, 44 | [0, 256]) 45 | hist_b = hist_b / np.sum(hist_b) 46 | hist_g = hist_g / np.sum(hist_g) 47 | hist_r = hist_r / np.sum(hist_r) 48 | 49 | # 24 dimensions 50 | return np.concatenate([hist_b, hist_g, hist_r], axis=0).reshape(-1) 51 | 52 | 53 | 54 | class HSVHistogram(Histogram): 55 | '''HSV Histogram''' 56 | 57 | def __init__(self, bins): 58 | super().__init__(bins) 59 | 60 | def describe(self, image, mask): 61 | image = cv.cvtColor(image, cv.COLOR_BGR2HSV) 62 | hist_h = cv.calcHist([image], [0], mask, self.bins, 63 | [0, 180]) 64 | hist_s = cv.calcHist([image], [1], mask, self.bins, 65 | [0, 256]) 66 | hist_v = cv.calcHist([image], [2], mask, self.bins, 67 | [0, 256]) 68 | hist_h = hist_h / np.sum(hist_h) 69 | hist_s = hist_s / np.sum(hist_s) 70 | hist_v = hist_v / np.sum(hist_v) 71 | 72 | # 24 dimensions 73 | return np.concatenate([hist_h, hist_s, hist_v], axis=0).reshape(-1) 74 | 75 | 76 | class YCrCbHistogram(Histogram): 77 | '''YCrCb Histogram''' 78 | 79 | def __init__(self, bins): 80 | super().__init__(bins) 81 | 82 | def describe(self, image, mask): 83 | image = cv.cvtColor(image, cv.COLOR_BGR2YCrCb) 84 | hist_y = cv.calcHist([image], [0], mask, self.bins, 85 | [0, 256]) 86 | hist_cr = cv.calcHist([image], [1], mask, self.bins, 87 | [0, 256]) 88 | hist_cb = cv.calcHist([image], [2], mask, self.bins, 89 | [0, 256]) 90 | hist_y = hist_y / np.sum(hist_y) 91 | hist_cr = hist_cr / np.sum(hist_cr) 92 | hist_cb = hist_cb / np.sum(hist_cb) 93 | 94 | # 24 dimensions 95 | return np.concatenate([hist_y, hist_cr, hist_cb], axis=0).reshape(-1) 96 | 97 | 98 | if __name__ == "__main__": 99 | 100 | logger = logging.getLogger(__name__) 101 | 102 | CLASSES = 5 103 | 104 | images_list = [] 105 | masks_list = [] 106 | features_list = [] 107 | classes_list = [] 108 | 109 | hist = YCrCbHistogram([8]) 110 | 111 | s1 = time.time() 112 | for i in range(0, CLASSES): 113 | for imgpath in sorted(paths.list_images(str(i))): 114 | if os.path.splitext(imgpath)[-1] == '.jpg': 115 | images_list.append(imgpath) 116 | classes_list.append(i) 117 | elif os.path.splitext(imgpath)[-1] == '.png': 118 | masks_list.append(imgpath) 119 | else: 120 | raise ValueError("type error...") 121 | s2 = time.time() 122 | logger.info(f"Time use: {s2 - s1} s") 123 | 124 | for image_path, mask_path in zip(images_list, masks_list): 125 | # print(image_path, mask_path) 126 | image = cv.imread(image_path) 127 | mask = cv.imread(mask_path, 0) 128 | features = hist.describe(image, mask) 129 | # print(features) 130 | features_list.append(features) 131 | 132 | logger.info(f"Time use: {time.time() - s2} s") 133 | logger.info("Data process ready...") 134 | 135 | # Resampling 136 | sm = SMOTE(sampling_strategy='all', random_state=2019) 137 | features_list, classes_list = sm.fit_resample(features_list, classes_list) 138 | 139 | # Machine learning algorithm 140 | # clf = MLPClassifier(solver='lbfgs', alpha=1e-5, 141 | # hidden_layer_sizes=(8, ), random_state=2019) 142 | clf = RandomForestClassifier(n_estimators=180, random_state=2019) 143 | # kf = KFold(n_splits=CLASSES, random_state=2019, shuffle=True).\ 144 | # get_n_splits(features_list) 145 | # scores = cross_val_score(clf, features_list, classes_list, 146 | # scoring='accuracy', cv=kf) 147 | # score = scores.mean() 148 | # logger.info(f"KFold score: {score}") 149 | 150 | # Split train and test dataset 151 | X_train, X_test, y_train, y_test = train_test_split( 152 | features_list, classes_list, test_size=0.2, 
random_state=2019) 153 | y_pred = clf.fit(X_train, y_train).predict(X_test) 154 | 155 | classify_report = classification_report(y_test, y_pred) 156 | logger.info('\n' + classify_report) 157 | 158 | np.set_printoptions(precision=2) 159 | plot_confusion_matrix(y_test, y_pred, classes=['0', '1', 160 | '2', '3', '4'], title='Confusion matrix') 161 | plt.show() 162 | 163 | # Save model 164 | # https://blog.csdn.net/qiang12qiang12/article/details/81001839 165 | # How to load model: 166 | # 1. clf = joblib.load('models/histogram.pkl') 167 | # 2. clf.predict(X_test) 168 | 169 | # joblib.dump(clf, 'models/histogram.pkl') 170 | -------------------------------------------------------------------------------- /model/lednet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2020/3/27 17:42 4 | # @Author : JackyLUO 5 | # @E-mail : lingluo@stumail.neu.edu.cn 6 | # @Site : 7 | # @File : lednet.py 8 | # @Software: PyCharm 9 | 10 | from keras import layers, models 11 | import tensorflow as tf 12 | 13 | 14 | class LEDNet: 15 | def __init__(self, groups, classes, input_shape): 16 | self.groups = groups 17 | self.classes = classes 18 | self.input_shape = input_shape 19 | 20 | def ss_bt(self, x, dilation, strides=(1, 1), padding='same'): 21 | x1, x2 = self.channel_split(x) 22 | filters = (int(x.shape[-1]) // self.groups) 23 | x1 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides, padding=padding)(x1) 24 | x1 = layers.Activation('relu')(x1) 25 | x1 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides, padding=padding)(x1) 26 | x1 = layers.BatchNormalization()(x1) 27 | x1 = layers.Activation('relu')(x1) 28 | x1 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides, padding=padding, dilation_rate=(dilation, 1))( 29 | x1) 30 | x1 = layers.Activation('relu')(x1) 31 | x1 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides, padding=padding, dilation_rate=(1, dilation))( 32 | x1) 33 | x1 = layers.BatchNormalization()(x1) 34 | x1 = layers.Activation('relu')(x1) 35 | 36 | x2 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides, padding=padding)(x2) 37 | x2 = layers.Activation('relu')(x2) 38 | x2 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides, padding=padding)(x2) 39 | x2 = layers.BatchNormalization()(x2) 40 | x2 = layers.Activation('relu')(x2) 41 | x2 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides, padding=padding, dilation_rate=(1, dilation))( 42 | x2) 43 | x2 = layers.Activation('relu')(x2) 44 | x2 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides, padding=padding, dilation_rate=(dilation, 1))( 45 | x2) 46 | x2 = layers.BatchNormalization()(x2) 47 | x2 = layers.Activation('relu')(x2) 48 | x_concat = layers.concatenate([x1, x2], axis=-1) 49 | x_add = layers.add([x, x_concat]) 50 | output = self.channel_shuffle(x_add) 51 | return output 52 | 53 | def channel_shuffle(self, x): 54 | n, h, w, c = x.shape.as_list() 55 | x_reshaped = layers.Reshape([h, w, self.groups, int(c // self.groups)])(x) 56 | x_transposed = layers.Permute((1, 2, 4, 3))(x_reshaped) 57 | output = layers.Reshape([h, w, c])(x_transposed) 58 | return output 59 | 60 | def channel_split(self, x): 61 | def splitter(y): 62 | # keras Lambda saving bug!!! 
63 | # x_left = layers.Lambda(lambda y: y[:, :, :, :int(int(y.shape[-1]) // self.groups)])(x) 64 | # x_right = layers.Lambda(lambda y: y[:, :, :, int(int(y.shape[-1]) // self.groups):])(x) 65 | # return x_left, x_right 66 | return tf.split(y, num_or_size_splits=self.groups, axis=-1) 67 | 68 | return layers.Lambda(lambda y: splitter(y))(x) 69 | 70 | def down_sample(self, x, filters): 71 | x_filters = int(x.shape[-1]) 72 | x_conv = layers.Conv2D(filters - x_filters, kernel_size=3, strides=(2, 2), padding='same')(x) 73 | x_pool = layers.MaxPool2D()(x) 74 | x = layers.concatenate([x_conv, x_pool], axis=-1) 75 | x = layers.BatchNormalization()(x) 76 | x = layers.Activation('relu')(x) 77 | return x 78 | 79 | def apn_module(self, x): 80 | 81 | def right(x): 82 | x = layers.AveragePooling2D()(x) 83 | x = layers.Conv2D(self.classes, kernel_size=1, padding='same')(x) 84 | x = layers.BatchNormalization()(x) 85 | x = layers.Activation('relu')(x) 86 | x = layers.UpSampling2D(interpolation='bilinear')(x) 87 | return x 88 | 89 | def conv(x, filters, kernel_size, stride): 90 | x = layers.Conv2D(filters, kernel_size=kernel_size, strides=(stride, stride), padding='same')(x) 91 | x = layers.BatchNormalization()(x) 92 | x = layers.Activation('relu')(x) 93 | return x 94 | 95 | x_7 = conv(x, int(x.shape[-1]), 7, stride=2) 96 | x_5 = conv(x_7, int(x.shape[-1]), 5, stride=2) 97 | x_3 = conv(x_5, int(x.shape[-1]), 3, stride=2) 98 | 99 | x_3_1 = conv(x_3, self.classes, 3, stride=1) 100 | x_3_1_up = layers.UpSampling2D(interpolation='bilinear')(x_3_1) 101 | x_5_1 = conv(x_5, self.classes, 5, stride=1) 102 | x_3_5 = layers.add([x_5_1, x_3_1_up]) 103 | x_3_5_up = layers.UpSampling2D(interpolation='bilinear')(x_3_5) 104 | x_7_1 = conv(x_7, self.classes, 3, stride=1) 105 | x_3_5_7 = layers.add([x_7_1, x_3_5_up]) 106 | x_3_5_7_up = layers.UpSampling2D(interpolation='bilinear')(x_3_5_7) 107 | 108 | x_middle = conv(x, self.classes, 1, stride=1) 109 | x_middle = layers.multiply([x_3_5_7_up, x_middle]) 110 | 111 | x_right = right(x) 112 | x_middle = layers.add([x_middle, x_right]) 113 | return x_middle 114 | 115 | def encoder(self, x): 116 | x = self.down_sample(x, filters=32) 117 | for _ in range(3): 118 | x = self.ss_bt(x, dilation=1) 119 | 120 | x = self.down_sample(x, filters=64) 121 | for _ in range(2): 122 | x = self.ss_bt(x, dilation=1) 123 | 124 | x = self.down_sample(x, filters=128) 125 | 126 | dilation_rate = [1, 2, 5, 9, 2, 5, 9, 17] 127 | for dilation in dilation_rate: 128 | x = self.ss_bt(x, dilation=dilation) 129 | return x 130 | 131 | def decoder(self, x): 132 | x = self.apn_module(x) 133 | x = layers.UpSampling2D(size=8, interpolation='bilinear')(x) 134 | x = layers.Conv2D(self.classes, kernel_size=3, padding='same')(x) 135 | x = layers.BatchNormalization()(x) 136 | x = layers.Activation('softmax')(x) 137 | return x 138 | 139 | def model(self): 140 | inputs = layers.Input(shape=self.input_shape) 141 | encoder_out = self.encoder(inputs) 142 | output = self.decoder(encoder_out) 143 | return models.Model(inputs, output) 144 | 145 | 146 | if __name__ == '__main__': 147 | from flops import get_flops 148 | 149 | model = LEDNet(2, 3, (256, 256, 3)).model() 150 | model.summary() 151 | 152 | get_flops(model) 153 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from data_loader import HairGenerator 3 | from keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard, 
LearningRateScheduler 4 | import os 5 | import warnings 6 | from keras import optimizers 7 | from keras.regularizers import l2 8 | from metric import * 9 | from segmentation_models.losses import * 10 | import numpy as np 11 | 12 | from albumentations import * 13 | from model.hlnet import HLNet 14 | from model.dfanet import DFANet 15 | from model.enet import ENet 16 | from model.lednet import LEDNet 17 | from model.mobilenet import MobileNet 18 | from model.fast_scnn import Fast_SCNN 19 | 20 | warnings.filterwarnings("ignore") 21 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 22 | 23 | import tensorflow as tf 24 | 25 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--batch_size", '-b', 29 | help="batch size", type=int, default=64) 30 | parser.add_argument("--image_size", '-i', 31 | help="image size", type=int, default=256) 32 | parser.add_argument("--backbone", '-bb', 33 | help="backbone of the network", type=str, default=None) 34 | parser.add_argument("--epoches", '-e', help="epoch size", 35 | type=int, default=150) 36 | parser.add_argument("--model_name", help="model's name", 37 | choices=['hlnet', 'fastscnn', 'lednet', 'dfanet', 'enet', 'mobilenet'], 38 | type=str, default='hlnet') 39 | parser.add_argument("--learning_rate", help="learning rate", type=float, default=2.5e-3) 40 | parser.add_argument("--checkpoints", 41 | help="where is the checkpoint", type=str, default='./weights') 42 | parser.add_argument("--class_number", 43 | help="number of output", type=int, default=3) 44 | parser.add_argument("--data_dir", 45 | help="path of dataset", type=str, default='./data/CelebA') 46 | args = parser.parse_args() 47 | 48 | 49 | def get_model(name): 50 | if name == 'hlnet': 51 | model = HLNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM) 52 | elif name == 'fastscnn': 53 | model = Fast_SCNN(num_classes=CLS_NUM, input_shape=(IMG_SIZE, IMG_SIZE, 3)).model() 54 | elif name == 'lednet': 55 | model = LEDNet(groups=2, classes=CLS_NUM, input_shape=(IMG_SIZE, IMG_SIZE, 3)).model() 56 | elif name == 'dfanet': 57 | model = DFANet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM, size_factor=2) 58 | elif name == 'enet': 59 | model = ENet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM) 60 | elif name == 'mobilenet': 61 | model = MobileNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM) 62 | else: 63 | raise NameError("No corresponding model...") 64 | 65 | return model 66 | 67 | 68 | class PolyDecay: 69 | '''Exponential decay strategy implementation''' 70 | 71 | def __init__(self, initial_lr, power, n_epochs): 72 | self.initial_lr = initial_lr 73 | self.power = power 74 | self.n_epochs = n_epochs 75 | 76 | def scheduler(self, epoch): 77 | return self.initial_lr * np.power(1.0 - 1.0 * epoch / self.n_epochs, self.power) 78 | 79 | 80 | def set_regularization(model, 81 | kernel_regularizer=None, 82 | bias_regularizer=None, 83 | activity_regularizer=None): 84 | '''Parameter regularization processing to prevent model overfitting''' 85 | for layer in model.layers: 86 | if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'): 87 | layer.kernel_regularizer = kernel_regularizer 88 | 89 | if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'): 90 | layer.bias_regularizer = bias_regularizer 91 | 92 | if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'): 93 | layer.activity_regularizer = activity_regularizer 94 | 95 | 96 | def main(): 97 | config = 
tf.ConfigProto() 98 | config.gpu_options.allow_growth = True 99 | session = tf.Session(config=config) 100 | 101 | global IMG_SIZE 102 | global CLS_NUM 103 | 104 | ROOT_DIR = args.data_dir 105 | BACKBONE = args.backbone 106 | BATCH_SIZE = args.batch_size 107 | IMG_SIZE = args.image_size 108 | EPOCHS = args.epoches 109 | LR = args.learning_rate 110 | CHECKPOINT = args.checkpoints 111 | CLS_NUM = args.class_number 112 | MODEL_NAME = args.model_name 113 | 114 | train_transformer = Compose([ # GaussNoise(p=0.2), 115 | ShiftScaleRotate( 116 | shift_limit=0.1, scale_limit=0.2, rotate_limit=20, p=0.5), 117 | HorizontalFlip(p=0.5), 118 | # HueSaturationValue(p=0.5), 119 | # RandomBrightnessContrast(0.5), 120 | # GridDistortion(distort_limit=0.2, p=0.5), 121 | Resize(height=IMG_SIZE, width=IMG_SIZE, always_apply=True), 122 | ]) 123 | val_transformer = Compose( 124 | [Resize(height=IMG_SIZE, width=IMG_SIZE, always_apply=True)]) 125 | 126 | train_generator = HairGenerator( 127 | train_transformer, ROOT_DIR, mode='Training', batch_size=BATCH_SIZE, nb_classes=CLS_NUM, 128 | backbone=BACKBONE, shuffle=True) 129 | 130 | val_generator = HairGenerator( 131 | val_transformer, ROOT_DIR, mode='Testing', batch_size=BATCH_SIZE, nb_classes=CLS_NUM, 132 | backbone=BACKBONE) 133 | 134 | # Loading models 135 | model = get_model(MODEL_NAME) 136 | set_regularization(model, kernel_regularizer=l2(2e-5)) 137 | model.compile(optimizer=optimizers.SGD(lr=LR, momentum=0.98), 138 | loss=cce_dice_loss, metrics=[mean_iou, frequency_weighted_iou, mean_accuracy, pixel_accuracy]) 139 | 140 | CHECKPOINT = CHECKPOINT + '/' + MODEL_NAME 141 | if not os.path.exists(CHECKPOINT): 142 | os.makedirs(CHECKPOINT) 143 | 144 | checkpoint = ModelCheckpoint(filepath=os.path.join(CHECKPOINT, 'model-{epoch:03d}.h5'), 145 | monitor='val_loss', 146 | save_best_only=True, 147 | verbose=1) 148 | tensorboard = TensorBoard(log_dir=os.path.join(CHECKPOINT, 'logs')) 149 | csvlogger = CSVLogger( 150 | os.path.join(CHECKPOINT, "result.csv")) 151 | 152 | lr_decay = LearningRateScheduler(PolyDecay(LR, 0.9, EPOCHS).scheduler) 153 | 154 | model.fit_generator( 155 | train_generator, 156 | len(train_generator), 157 | validation_data=val_generator, 158 | validation_steps=len(val_generator), 159 | epochs=EPOCHS, 160 | verbose=1, 161 | callbacks=[checkpoint, tensorboard, csvlogger, lr_decay] 162 | ) 163 | 164 | K.clear_session() 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /pipline_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | from keras.models import load_model 4 | import numpy as np 5 | import time 6 | import cv2 as cv 7 | import os 8 | import sys 9 | import argparse 10 | from sklearn.externals import joblib 11 | import matplotlib.pyplot as plt 12 | from keras.applications.imagenet_utils import preprocess_input as pinput 13 | from keras import backend as K 14 | 15 | import tensorflow as tf 16 | tf.logging.set_verbosity(tf.logging.ERROR) 17 | 18 | from segmentation_models.backbones import get_preprocessing 19 | from model.hlnet import HLRNet 20 | from model.hrnet import HRNet 21 | from segmentation_models import PSPNet, Unet, FPN, Linknet 22 | from mtcnn.mtcnn import MTCNN 23 | from metric import * 24 | from imutils import paths 25 | 26 | IMG_SIZE = None 27 | 28 | 29 | def color_moments(image, mask, color_space): 30 | """ 31 | function: Color Moment Features 32 | 
image: raw image 33 | mask: image mask 34 | color_space: 'rgb' or 'lab' or 'ycrcb' or 'hsv' 35 | """ 36 | assert image.shape[:2] == mask.shape 37 | assert color_space.lower() in ['lab', 'rgb', 'ycrcb', 'hsv'] 38 | 39 | if color_space.lower() == 'rgb': 40 | image = cv.cvtColor(image, cv.COLOR_BGR2RGB) 41 | elif color_space.lower() == 'hsv': 42 | image = cv.cvtColor(image, cv.COLOR_BGR2HSV) 43 | elif color_space.lower() == 'lab': 44 | image = cv.cvtColor(image, cv.COLOR_BGR2LAB) 45 | elif color_space.lower() == 'ycrcb': 46 | image = cv.cvtColor(image, cv.COLOR_BGR2YCrCb) 47 | else: 48 | raise ValueError("Color space error...") 49 | 50 | # Split image channels info 51 | c1, c2, c3 = cv.split(image) 52 | color_feature = [] 53 | 54 | # Only process mask != 0 channel region 55 | c1 = c1[np.where(mask != 0)] 56 | c2 = c2[np.where(mask != 0)] 57 | c3 = c3[np.where(mask != 0)] 58 | 59 | # Extract mean 60 | mean_1 = np.mean(c1) 61 | mean_2 = np.mean(c2) 62 | mean_3 = np.mean(c3) 63 | 64 | # Extract variance 65 | variance_1 = np.std(c1) 66 | variance_2 = np.std(c2) 67 | variance_3 = np.std(c3) 68 | 69 | # Extract skewness 70 | skewness_1 = np.mean(np.abs(c1 - mean_1) ** 3) ** (1. / 3) 71 | skewness_2 = np.mean(np.abs(c2 - mean_2) ** 3) ** (1. / 3) 72 | skewness_3 = np.mean(np.abs(c3 - mean_3) ** 3) ** (1. / 3) 73 | 74 | color_feature.extend( 75 | [mean_1, mean_2, mean_3, variance_1, variance_2, 76 | variance_3, skewness_1, skewness_2, skewness_3]) 77 | 78 | return color_feature 79 | 80 | 81 | def _result_map_toimg(result_map): 82 | '''show result map''' 83 | img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8) 84 | 85 | argmax_id = np.argmax(result_map, axis=-1) 86 | background = (argmax_id == 0) 87 | skin = (argmax_id == 1) 88 | hair = (argmax_id == 2) 89 | 90 | img[:, :, 0] = np.where(background, 255, 0) 91 | img[:, :, 1] = np.where(skin, 255, 0) 92 | img[:, :, 2] = np.where(hair, 255, 0) 93 | 94 | return img 95 | 96 | 97 | def imcrop(img, x1, y1, x2, y2): 98 | if x1 < 0 or y1 < 0 or x2 > img.shape[1] or y2 > img.shape[0]: 99 | img, x1, x2, y1, y2 = pad_img_to_fit_bbox(img, x1, x2, y1, y2) 100 | return img[y1:y2, x1:x2, :] 101 | 102 | 103 | def pad_img_to_fit_bbox(img, x1, x2, y1, y2): 104 | img = cv.copyMakeBorder(img, - min(0, y1), max(y2 - img.shape[0], 0), 105 | -min(0, x1), max(x2 - img.shape[1], 0), cv.BORDER_REPLICATE) 106 | y2 += -min(0, y1) 107 | y1 += -min(0, y1) 108 | x2 += -min(0, x1) 109 | x1 += -min(0, x1) 110 | return img, x1, x2, y1, y2 111 | 112 | 113 | if __name__ == '__main__': 114 | 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--image_size", '-is', 117 | help="size of image", type=int, default=224) 118 | parser.add_argument("--backbone", '-bb', 119 | help="backbone of image", type=str, default='seresnet18') 120 | parser.add_argument("--model_path", '-mp', 121 | help="the path of model", type=str, 122 | default='./checkpoints/CelebA/HLNet/model-222-0.159.h5') 123 | parser.add_argument("--margin", 124 | help="margin of image", type=float, default=0.3) 125 | parser.add_argument('--use_design', action='store_false') 126 | args = parser.parse_args() 127 | 128 | IMG_SIZE = args.image_size 129 | MODEL_PATH = args.model_path 130 | BACKBONE = args.backbone 131 | USE_DESIGN = args.use_design 132 | 133 | detector = MTCNN() 134 | clf = joblib.load('./experiments/skinGrade/skinColor.pkl') 135 | model = load_model(MODEL_PATH, custom_objects={'mean_accuracy': mean_accuracy, 136 | 'mean_iou': mean_iou, 137 | 'frequency_weighted_iou': frequency_weighted_iou, 138 |
'pixel_accuracy': pixel_accuracy}) 139 | colorHue = ['Ivory white', 'Porcelain white', 140 | 'natural color', 'Yellowish', 'Black'] 141 | 142 | for img_path in paths.list_images("./data/Testing"): 143 | t = time.time() 144 | 145 | org_img = cv.imread(img_path) 146 | try: 147 | org_img.shape 148 | except: 149 | raise ValueError("Reading image error...") 150 | 151 | org_img_rgb = org_img[:, :, ::-1] # RGB 152 | detected = detector.detect_faces(org_img_rgb) 153 | 154 | if len(detected) != 1: 155 | print("[INFO] multi faces or no face...") 156 | continue 157 | 158 | d = detected[0]['box'] 159 | x1, y1, x2, y2, w, h = d[0], d[1], d[0] + d[2], d[1] + d[3], d[2], d[3] 160 | xw1 = int(x1 - args.margin * w) 161 | yw1 = int(y1 - args.margin * h) 162 | xw2 = int(x2 + args.margin * w) 163 | yw2 = int(y2 + args.margin * h) 164 | cropped_img = imcrop(org_img, xw1, yw1, xw2, yw2) 165 | o_h, o_w, _ = cropped_img.shape 166 | 167 | cropped_img_resize = cv.resize(cropped_img, (IMG_SIZE, IMG_SIZE)) 168 | img = cropped_img_resize[np.newaxis, :] 169 | 170 | 171 | # only subtract mean value 172 | img = pinput(img) 173 | 174 | result_map = model.predict(img)[0] 175 | mask = _result_map_toimg(result_map) 176 | mask = cv.resize(mask, (o_w, o_h)) 177 | 178 | # Face channel 179 | mask_face = mask[:, :, 1] 180 | features = color_moments(cropped_img, mask_face, color_space='ycrcb') 181 | features = np.array(features, np.float32)[np.newaxis, :] 182 | skinHue = colorHue[clf.predict(features)[0]] 183 | 184 | cv.rectangle(org_img, (x1, y1), (x2, y2), (255, 0, 0), 2) 185 | cv.putText(org_img, 'Color: {}'.format(skinHue), (x1, y1+30), 186 | cv.FONT_HERSHEY_PLAIN, 1, (0, 255, 0), 1) 187 | print(time.time() - t) # testing time 188 | cv.imshow("image", org_img) 189 | cv.waitKey(-1) 190 | 191 | -------------------------------------------------------------------------------- /model/enet.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2020/3/27 17:41 4 | # @Author : JackyLUO 5 | # @E-mail : lingluo@stumail.neu.edu.cn 6 | # @Site : 7 | # @File : enet.py 8 | # @Software: PyCharm 9 | # 10 | from keras.layers import * 11 | from keras.models import Model 12 | 13 | 14 | class Conv2DTransposeCustom(object): 15 | """Fixed output shape bug...""" 16 | 17 | def __init__(self, filters, kernel_size, strides=(1, 1), padding='same'): 18 | self.filters = filters 19 | self.kernel_size = kernel_size 20 | self.strides = strides 21 | self.padding = padding 22 | 23 | def __call__(self, layer): 24 | out = Conv2DTranspose(self.filters, self.kernel_size, strides=self.strides, padding=self.padding)(layer) 25 | if not isinstance(self.strides, tuple): 26 | self.strides = (self.strides, self.strides) 27 | out.set_shape((out.shape[0], layer.shape[1] * self.strides[0], layer.shape[2] * self.strides[1], out.shape[3])) 28 | return out 29 | 30 | 31 | def initial_block(inp, nb_filter=13, nb_row=3, nb_col=3, strides=(2, 2)): 32 | conv = Conv2D(nb_filter, (nb_row, nb_col), padding='same', strides=strides)(inp) 33 | max_pool = MaxPooling2D()(inp) 34 | merged = concatenate([conv, max_pool], axis=3) 35 | return merged 36 | 37 | 38 | def bottleneck(inp, output, internal_scale=4, asymmetric=0, dilated=0, downsample=False, dropout_rate=0.1): 39 | # main branch 40 | internal = output // internal_scale 41 | encoder = inp 42 | # 1x1 43 | input_stride = 2 if downsample else 1 # the 1st 1x1 projection is replaced with a 2x2 convolution when downsampling 44 | encoder = 
Conv2D(internal, (input_stride, input_stride), 45 | # padding='same', 46 | strides=(input_stride, input_stride), use_bias=False)(encoder) 47 | # Batch normalization + PReLU 48 | encoder = BatchNormalization(momentum=0.1)(encoder) # enet uses momentum of 0.1, keras default is 0.99 49 | encoder = PReLU(shared_axes=[1, 2])(encoder) 50 | 51 | # conv 52 | if not asymmetric and not dilated: 53 | encoder = Conv2D(internal, (3, 3), padding='same')(encoder) 54 | elif asymmetric: 55 | encoder = Conv2D(internal, (1, asymmetric), padding='same', use_bias=False)(encoder) 56 | encoder = Conv2D(internal, (asymmetric, 1), padding='same')(encoder) 57 | elif dilated: 58 | encoder = Conv2D(internal, (3, 3), dilation_rate=(dilated, dilated), padding='same')(encoder) 59 | else: 60 | raise (Exception('You shouldn\'t be here')) 61 | 62 | encoder = BatchNormalization(momentum=0.1)(encoder) # enet uses momentum of 0.1, keras default is 0.99 63 | encoder = PReLU(shared_axes=[1, 2])(encoder) 64 | 65 | # 1x1 66 | encoder = Conv2D(output, (1, 1), use_bias=False)(encoder) 67 | 68 | encoder = BatchNormalization(momentum=0.1)(encoder) # enet uses momentum of 0.1, keras default is 0.99 69 | encoder = SpatialDropout2D(dropout_rate)(encoder) 70 | 71 | other = inp 72 | # other branch 73 | if downsample: 74 | other = MaxPooling2D()(other) 75 | 76 | other = Permute((1, 3, 2))(other) 77 | pad_feature_maps = output - inp.get_shape().as_list()[3] 78 | tb_pad = (0, 0) 79 | lr_pad = (0, pad_feature_maps) 80 | other = ZeroPadding2D(padding=(tb_pad, lr_pad))(other) 81 | other = Permute((1, 3, 2))(other) 82 | 83 | encoder = add([encoder, other]) 84 | encoder = PReLU(shared_axes=[1, 2])(encoder) 85 | 86 | return encoder 87 | 88 | 89 | def en_build(inp, dropout_rate=0.01): 90 | enet = initial_block(inp) 91 | enet = BatchNormalization(momentum=0.1)(enet) # enet_unpooling uses momentum of 0.1, keras default is 0.99 92 | enet = PReLU(shared_axes=[1, 2])(enet) 93 | enet = bottleneck(enet, 64, downsample=True, dropout_rate=dropout_rate) # bottleneck 1.0 94 | for _ in range(4): 95 | enet = bottleneck(enet, 64, dropout_rate=dropout_rate) # bottleneck 1.i 96 | 97 | enet = bottleneck(enet, 128, downsample=True) # bottleneck 2.0 98 | # bottleneck 2.x and 3.x 99 | for _ in range(2): 100 | enet = bottleneck(enet, 128) # bottleneck 2.1 101 | enet = bottleneck(enet, 128, dilated=2) # bottleneck 2.2 102 | enet = bottleneck(enet, 128, asymmetric=5) # bottleneck 2.3 103 | enet = bottleneck(enet, 128, dilated=4) # bottleneck 2.4 104 | enet = bottleneck(enet, 128) # bottleneck 2.5 105 | enet = bottleneck(enet, 128, dilated=8) # bottleneck 2.6 106 | enet = bottleneck(enet, 128, asymmetric=5) # bottleneck 2.7 107 | enet = bottleneck(enet, 128, dilated=16) # bottleneck 2.8 108 | 109 | return enet 110 | 111 | 112 | # decoder 113 | def de_bottleneck(encoder, output, upsample=False, reverse_module=False): 114 | internal = output // 4 115 | 116 | x = Conv2D(internal, (1, 1), use_bias=False)(encoder) 117 | x = BatchNormalization(momentum=0.1)(x) 118 | x = Activation('relu')(x) 119 | if not upsample: 120 | x = Conv2D(internal, (3, 3), padding='same', use_bias=True)(x) 121 | else: 122 | x = Conv2DTransposeCustom(filters=internal, kernel_size=(3, 3), strides=(2, 2), padding='same')(x) 123 | x = BatchNormalization(momentum=0.1)(x) 124 | x = Activation('relu')(x) 125 | 126 | x = Conv2D(output, (1, 1), padding='same', use_bias=False)(x) 127 | 128 | other = encoder 129 | if encoder.get_shape()[-1] != output or upsample: 130 | other = Conv2D(output, (1, 1), 
padding='same', use_bias=False)(other) 131 | other = BatchNormalization(momentum=0.1)(other) 132 | if upsample and reverse_module is not False: 133 | other = UpSampling2D(size=(2, 2))(other) 134 | 135 | if upsample and reverse_module is False: 136 | decoder = x 137 | else: 138 | x = BatchNormalization(momentum=0.1)(x) 139 | decoder = add([x, other]) 140 | decoder = Activation('relu')(decoder) 141 | 142 | return decoder 143 | 144 | 145 | def de_build(encoder, nc): 146 | enet = de_bottleneck(encoder, 64, upsample=True, reverse_module=True) # bottleneck 4.0 147 | enet = de_bottleneck(enet, 64) # bottleneck 4.1 148 | enet = de_bottleneck(enet, 64) # bottleneck 4.2 149 | enet = de_bottleneck(enet, 16, upsample=True, reverse_module=True) # bottleneck 5.0 150 | enet = de_bottleneck(enet, 16) # bottleneck 5.1 151 | 152 | enet = Conv2DTransposeCustom(filters=nc, kernel_size=(2, 2), strides=(2, 2), padding='same')(enet) 153 | return enet 154 | 155 | 156 | def ENet(input_shape, cls_num=3): 157 | # Make sure the dimensions are multiples of 32 158 | assert input_shape[0] % 32 == 0 159 | assert input_shape[1] % 32 == 0 160 | img_input = Input(input_shape) 161 | enet = en_build(img_input) 162 | enet = de_build(enet, cls_num) 163 | enet = Activation('softmax')(enet) 164 | return Model(img_input, enet) 165 | 166 | 167 | if __name__ == '__main__': 168 | from flops import get_flops 169 | 170 | model = ENet(input_shape=(256, 256, 3), cls_num=3) 171 | # model.summary() 172 | 173 | get_flops(model, True) 174 | -------------------------------------------------------------------------------- /model/dfanet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2020/3/27 19:56 4 | # @Author : JackyLUO 5 | # @E-mail : lingluo@stumail.neu.edu.cn 6 | # @Site : 7 | # @File : dfanet.py 8 | # @Software: PyCharm 9 | 10 | from keras.layers import * 11 | from keras.models import Model 12 | import keras.backend as K 13 | 14 | 15 | def ConvBlock(inputs, n_filters, kernel_size=3, strides=1): 16 | """ 17 | Basic conv block for Encoder-Decoder 18 | Apply successivly Convolution, BatchNormalization, ReLU nonlinearity 19 | """ 20 | net = Conv2D(n_filters, kernel_size, strides=strides, 21 | padding='same', 22 | kernel_initializer='he_normal', 23 | use_bias=False)(inputs) 24 | 25 | net = BatchNormalization()(net) 26 | net = Activation('relu')(net) 27 | return net 28 | 29 | 30 | def separable_res_block_deep(inputs, nb_filters, filter_size=3, strides=1, dilation=1, ix=0): 31 | inputs = Activation('relu')(inputs) # , name=prefix + '_sepconv1_act' 32 | 33 | ip_nb_filter = K.get_variable_shape(inputs)[-1] 34 | if ip_nb_filter != nb_filters or strides != 1: 35 | residual = Conv2D(nb_filters, 1, strides=strides, use_bias=False)(inputs) 36 | residual = BatchNormalization()(residual) 37 | else: 38 | residual = inputs 39 | 40 | x = SeparableConv2D(nb_filters // 4, filter_size, 41 | dilation_rate=dilation, 42 | padding='same', 43 | use_bias=False, 44 | kernel_initializer='he_normal', 45 | )(inputs) 46 | x = BatchNormalization()(x) # name=prefix + '_sepconv1_bn' 47 | 48 | x = Activation('relu')(x) # , name=prefix + '_sepconv2_act' 49 | x = SeparableConv2D(nb_filters // 4, filter_size, 50 | dilation_rate=dilation, 51 | padding='same', 52 | use_bias=False, 53 | kernel_initializer='he_normal', 54 | )(x) 55 | x = BatchNormalization()(x) # name=prefix + '_sepconv2_bn' 56 | x = Activation('relu')(x) # , name=prefix + '_sepconv3_act' 57 | # if strides != 
1: 58 | x = SeparableConv2D(nb_filters, filter_size, 59 | strides=strides, 60 | dilation_rate=dilation, 61 | padding='same', 62 | use_bias=False, 63 | )(x) 64 | 65 | x = BatchNormalization()(x) # name=prefix + '_sepconv3_bn' 66 | x = add([x, residual]) 67 | return x 68 | 69 | 70 | def encoder(inputs, nb_filters, stage): 71 | rep_nums = 0 72 | if stage == 2 or stage == 4: 73 | rep_nums = 4 74 | elif stage == 3: 75 | rep_nums = 6 76 | x = separable_res_block_deep(inputs, nb_filters, strides=2) # , ix = rand_nb + stage * 10 77 | for i in range(rep_nums - 1): 78 | x = separable_res_block_deep(x, nb_filters, strides=1) # , ix = rand_nb + stage * 10 + i 79 | 80 | return x 81 | 82 | 83 | def AttentionRefinementModule(inputs): 84 | # Global average pooling 85 | nb_channels = K.get_variable_shape(inputs)[-1] 86 | net = GlobalAveragePooling2D()(inputs) 87 | 88 | net = Reshape((1, nb_channels))(net) 89 | net = Conv1D(nb_channels, kernel_size=1, 90 | kernel_initializer='he_normal', 91 | )(net) 92 | net = BatchNormalization()(net) 93 | net = Activation('relu')(net) 94 | net = Conv1D(nb_channels, kernel_size=1, 95 | kernel_initializer='he_normal', 96 | )(net) 97 | net = BatchNormalization()(net) 98 | net = Activation('sigmoid')(net) # tf.sigmoid(net) 99 | 100 | net = Multiply()([inputs, net]) 101 | 102 | return net 103 | 104 | 105 | def xception_backbone(inputs, size_factor=2): 106 | x = Conv2D(8, kernel_size=3, strides=2, 107 | padding='same', use_bias=False)(inputs) 108 | x = BatchNormalization()(x) 109 | x = Activation('relu')(x) 110 | 111 | x = encoder(x, int(16 * size_factor), 2) 112 | x = encoder(x, int(32 * size_factor), 3) 113 | x = encoder(x, int(64 * size_factor), 4) 114 | 115 | x = AttentionRefinementModule(x) 116 | return x 117 | 118 | 119 | def DFANet(input_shape, cls_num=3, size_factor=2): 120 | img_input = Input(input_shape) 121 | 122 | x = Conv2D(8, kernel_size=5, strides=2, 123 | padding='same', use_bias=False)(img_input) 124 | x = BatchNormalization()(x) 125 | levela_input = Activation('relu')(x) 126 | 127 | enc2_a = encoder(levela_input, int(16 * size_factor), 2) 128 | 129 | enc3_a = encoder(enc2_a, int(32 * size_factor), 3) 130 | 131 | enc4_a = encoder(enc3_a, int(64 * size_factor), 4) 132 | 133 | enc_attend_a = AttentionRefinementModule(enc4_a) 134 | 135 | enc_upsample_a = UpSampling2D(size=4, interpolation='bilinear')(enc_attend_a) 136 | 137 | levelb_input = Concatenate()([enc2_a, enc_upsample_a]) 138 | enc2_b = encoder(levelb_input, int(16 * size_factor), 2) 139 | 140 | enc2_b_combine = Concatenate()([enc3_a, enc2_b]) 141 | enc3_b = encoder(enc2_b_combine, int(32 * size_factor), 3) 142 | 143 | enc3_b_combine = Concatenate()([enc4_a, enc3_b]) 144 | enc4_b = encoder(enc3_b_combine, int(64 * size_factor), 4) 145 | 146 | enc_attend_b = AttentionRefinementModule(enc4_b) 147 | 148 | enc_upsample_b = UpSampling2D(size=4, interpolation='bilinear')(enc_attend_b) 149 | 150 | levelc_input = Concatenate()([enc2_b, enc_upsample_b]) 151 | enc2_c = encoder(levelc_input, int(16 * size_factor), 2) 152 | 153 | enc2_c_combine = Concatenate()([enc3_b, enc2_c]) 154 | enc3_c = encoder(enc2_c_combine, int(32 * size_factor), 3) 155 | 156 | enc3_c_combine = Concatenate()([enc4_b, enc3_c]) 157 | enc4_c = encoder(enc3_c_combine, int(64 * size_factor), 4) 158 | 159 | enc_attend_c = AttentionRefinementModule(enc4_c) 160 | 161 | enc2_a_decoder = ConvBlock(enc2_a, 32, kernel_size=1) 162 | 163 | enc2_b_decoder = ConvBlock(enc2_b, 32, kernel_size=1) 164 | enc2_b_decoder = UpSampling2D(size=2, 
interpolation='bilinear')(enc2_b_decoder) 165 | 166 | enc2_c_decoder = ConvBlock(enc2_c, 32, kernel_size=1) 167 | enc2_c_decoder = UpSampling2D(size=4, interpolation='bilinear')(enc2_c_decoder) 168 | 169 | decoder_front = Add()([enc2_a_decoder, enc2_b_decoder, enc2_c_decoder]) 170 | decoder_front = ConvBlock(decoder_front, 32, kernel_size=1) 171 | 172 | att_a_decoder = ConvBlock(enc_attend_a, 32, kernel_size=1) 173 | att_a_decoder = UpSampling2D(size=4, interpolation='bilinear')(att_a_decoder) 174 | 175 | att_b_decoder = ConvBlock(enc_attend_b, 32, kernel_size=1) 176 | att_b_decoder = UpSampling2D(size=8, interpolation='bilinear')(att_b_decoder) 177 | 178 | att_c_decoder = ConvBlock(enc_attend_c, 32, kernel_size=1) 179 | att_c_decoder = UpSampling2D(size=16, interpolation='bilinear')(att_c_decoder) 180 | 181 | decoder_combine = Add()([decoder_front, att_a_decoder, att_b_decoder, att_c_decoder]) 182 | 183 | decoder_combine = ConvBlock(decoder_combine, cls_num * 2, kernel_size=1) 184 | 185 | decoder_final = UpSampling2D(size=4, interpolation='bilinear')(decoder_combine) 186 | output = Conv2D(cls_num, (1, 1), activation='softmax')(decoder_final) 187 | 188 | return Model(img_input, output, name='DFAnet') 189 | 190 | 191 | if __name__ == '__main__': 192 | from flops import get_flops 193 | 194 | model = DFANet(input_shape=(256, 256, 3), cls_num=3, size_factor=2) 195 | model.summary() 196 | 197 | get_flops(model) 198 | -------------------------------------------------------------------------------- /model/hlnet.py: -------------------------------------------------------------------------------- 1 | # Fast-SCNN 2 | # HRNet 3 | # MobileNetv2-v3 4 | # ASPP 5 | from keras.layers import * 6 | from keras.models import Model 7 | from keras.utils import plot_model 8 | 9 | import keras.backend as K 10 | 11 | 12 | def _conv_block(inputs, filters, kernel, strides=1, padding='same', use_activation=False): 13 | """Convolution Block 14 | This function defines a 2D convolution operation with BN and relu. 15 | # Arguments 16 | inputs: Tensor, input tensor of conv layer. 17 | filters: Integer, the dimensionality of the output space. 18 | kernel: An integer or tuple/list of 2 integers, specifying the 19 | width and height of the 2D convolution window. 20 | strides: An integer or tuple/list of 2 integers, 21 | specifying the strides of the convolution along the width and height. 22 | Can be a single integer to specify the same value for 23 | all spatial dimensions. 24 | # Returns 25 | Output tensor. 26 | """ 27 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 28 | 29 | x = Conv2D(filters, kernel, padding=padding, strides=strides, 30 | use_bias=False)(inputs) 31 | x = BatchNormalization(axis=channel_axis)(x) 32 | 33 | if use_activation: 34 | x = Activation('relu')(x) 35 | 36 | return x 37 | 38 | 39 | def _bottleneck(inputs, filters, kernel, t, s, r=False): 40 | """Bottleneck 41 | This function defines a basic bottleneck structure. 42 | # Arguments 43 | inputs: Tensor, input tensor of conv layer. 44 | filters: Integer, the dimensionality of the output space. 45 | kernel: An integer or tuple/list of 2 integers, specifying the 46 | width and height of the 2D convolution window. 47 | t: Integer, expansion factor. 48 | t is always applied to the input size. 49 | s: An integer or tuple/list of 2 integers,specifying the strides 50 | of the convolution along the width and height.Can be a single 51 | integer to specify the same value for all spatial dimensions. 
52 | r: Boolean, Whether to use the residuals. 53 | # Returns 54 | Output tensor. 55 | """ 56 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 57 | tchannel = K.int_shape(inputs)[channel_axis] * t 58 | 59 | x = _conv_block(inputs, tchannel, (1, 1)) 60 | 61 | x = DepthwiseConv2D(kernel, strides=( 62 | s, s), depth_multiplier=1, padding='same')(x) 63 | x = BatchNormalization(axis=channel_axis)(x) 64 | # relu6 65 | x = ReLU(max_value=6)(x) 66 | 67 | x = Conv2D(filters, (1, 1), strides=(1, 1), padding='same')(x) 68 | x = BatchNormalization(axis=channel_axis)(x) 69 | 70 | if r: 71 | x = add([x, inputs]) 72 | return x 73 | 74 | 75 | def _inverted_residual_block(inputs, filters, kernel, t, strides, n): 76 | """Inverted Residual Block 77 | This function defines a sequence of 1 or more identical layers. 78 | # Arguments 79 | inputs: Tensor, input tensor of conv layer. 80 | filters: Integer, the dimensionality of the output space. 81 | kernel: An integer or tuple/list of 2 integers, specifying the 82 | width and height of the 2D convolution window. 83 | t: Integer, expansion factor. 84 | t is always applied to the input size. 85 | s: An integer or tuple/list of 2 integers,specifying the strides 86 | of the convolution along the width and height.Can be a single 87 | integer to specify the same value for all spatial dimensions. 88 | n: Integer, layer repeat times. 89 | # Returns 90 | Output tensor. 91 | """ 92 | x = _bottleneck(inputs, filters, kernel, t, strides) 93 | 94 | for i in range(1, n): 95 | x = _bottleneck(x, filters, kernel, t, 1, True) 96 | 97 | return x 98 | 99 | 100 | def _depthwise_separable_block(inputs, kernel, strides, padding='same', depth_multiplier=1): 101 | '''Depth separable point convolution module''' 102 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 103 | 104 | x = DepthwiseConv2D(kernel_size=kernel, strides=strides, padding=padding, 105 | depth_multiplier=depth_multiplier)(inputs) 106 | x = BatchNormalization(axis=channel_axis)(x) 107 | return Activation('relu')(x) 108 | 109 | 110 | def HLNet(input_shape, cls_num=3): 111 | """Higt-Low Resolution Information fusion Network""" 112 | # input_shape: input image shape 113 | # cls_num: output class number 114 | inputs = Input(input_shape) 115 | # Step 1: Feature dimension drops to 1/4 116 | x = _conv_block(inputs, 32, (3, 3), strides=2, use_activation=True) 117 | x = _depthwise_separable_block(x, (3, 3), strides=2, depth_multiplier=2) 118 | x = _depthwise_separable_block(x, (3, 3), strides=2) 119 | 120 | # step 2: 121 | x21 = _inverted_residual_block( 122 | x, 64, kernel=(3, 3), t=6, strides=1, n=3 123 | ) 124 | x22 = _inverted_residual_block( 125 | x, 96, kernel=(3, 3), t=6, strides=2, n=3 126 | ) 127 | x23 = _inverted_residual_block( 128 | x, 128, kernel=(3, 3), t=6, strides=4, n=3 129 | ) 130 | 131 | # step 3: 132 | x31_t1 = x21 133 | x31_t2 = UpSampling2D(interpolation='bilinear')( 134 | _conv_block(x22, 64, (1, 1), use_activation=True)) 135 | x31_t3 = UpSampling2D(size=(4, 4), interpolation='bilinear')( 136 | _conv_block(x23, 64, (1, 1), use_activation=True)) 137 | x31 = Add()([x31_t1, x31_t2, x31_t3]) 138 | 139 | x32_t1 = _conv_block(x21, 96, (1, 1), strides=2, use_activation=True) 140 | x32_t2 = _conv_block(x22, 96, (1, 1), use_activation=True) 141 | x32_t3 = UpSampling2D(interpolation='bilinear')( 142 | _conv_block(x23, 96, (1, 1), use_activation=True)) 143 | x32 = Add()([x32_t1, x32_t2, x32_t3]) 144 | 145 | x33_t1 = _conv_block(x21, 128, (1, 1), strides=4, 
use_activation=True) 146 | x33_t2 = _conv_block(x22, 128, (1, 1), strides=2, use_activation=True) 147 | x33_t3 = _conv_block(x23, 128, (1, 1), use_activation=True) 148 | x33 = Add()([x33_t1, x33_t2, x33_t3]) 149 | 150 | # step 4: 151 | x41 = _conv_block(x33, 96, (1, 1)) 152 | x42 = UpSampling2D(interpolation='bilinear')(x41) 153 | x43 = Concatenate()([x42, x32]) 154 | x44 = _conv_block(x43, 64, (1, 1)) 155 | x45 = UpSampling2D(interpolation='bilinear')(x44) 156 | x46 = Concatenate()([x45, x31]) 157 | 158 | # step 5: FFM module in BiSeNet 159 | x50 = _conv_block(x46, 64, (3, 3)) 160 | x51 = AveragePooling2D(pool_size=(1, 1))(x50) 161 | x52 = Conv2D(64, (1, 1), use_bias=False, activation='relu')(x51) 162 | x53 = Conv2D(64, (1, 1), use_bias=False, activation='sigmoid')(x52) 163 | x54 = Multiply()([x53, x50]) 164 | x55 = Add()([x50, x54]) 165 | 166 | # step6: 167 | x61 = Conv2D(32, (3, 3), padding='same', dilation_rate=2)(x55) 168 | x62 = Conv2D(32, (3, 3), padding='same', dilation_rate=4)(x55) 169 | x63 = Conv2D(32, (3, 3), padding='same', dilation_rate=8)(x55) 170 | x64 = Add()([x61, x62, x63]) 171 | # x61 = _conv_block(x62, cls_num, (1, 1), use_activation=False) 172 | x65 = UpSampling2D(size=(8, 8), interpolation='bilinear')(x64) 173 | x66 = _conv_block(x65, cls_num, (1, 1), use_activation=False) 174 | out = Activation('softmax')(x66) 175 | 176 | return Model(inputs, out) 177 | 178 | 179 | if __name__ == "__main__": 180 | from flops import get_flops 181 | 182 | # Testing network design 183 | model = HLNet(input_shape=(256, 256, 3), cls_num=3) 184 | model.summary() 185 | 186 | get_flops(model) 187 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
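The PolyDecay class in train.py implements a polynomial ("poly") learning-rate decay, lr(epoch) = initial_lr * (1 - epoch / n_epochs) ** power, rather than the exponential decay its docstring mentions. The following self-contained sketch, illustrative only and not part of the repository, reproduces that schedule with train.py's argument defaults (learning rate 2.5e-3, power 0.9, 150 epochs):

import numpy as np


class PolyDecay:
    """Polynomial ("poly") learning-rate decay, for use with Keras' LearningRateScheduler."""

    def __init__(self, initial_lr, power, n_epochs):
        self.initial_lr = initial_lr
        self.power = power
        self.n_epochs = n_epochs

    def scheduler(self, epoch):
        # Called once per epoch, e.g. LearningRateScheduler(PolyDecay(lr, 0.9, epochs).scheduler)
        return self.initial_lr * np.power(1.0 - 1.0 * epoch / self.n_epochs, self.power)


decay = PolyDecay(initial_lr=2.5e-3, power=0.9, n_epochs=150)
for epoch in (0, 50, 100, 149):
    print(epoch, decay.scheduler(epoch))

With these values the learning rate falls smoothly from 2.5e-3 at epoch 0 to roughly 3e-5 by the final epoch.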