├── FER+vsFER.png
├── LICENSE.md
├── src
│   ├── models.py
│   ├── generate_training_data.py
│   ├── img_util.py
│   ├── rect_util.py
│   ├── train.py
│   └── ferplus.py
└── README.md

/FER+vsFER.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leriomaggio/FERPlus/master/FER+vsFER.png
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | FER+
2 | 
3 | Copyright (c) Microsoft Corporation
4 | 
5 | All rights reserved.
6 | 
7 | MIT License
8 | 
9 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
11 | 
12 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
13 | 
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Microsoft. All rights reserved.
3 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
4 | #
5 | 
6 | import os
7 | import sys
8 | import math
9 | import numpy as np
10 | import cntk as ct
11 | 
12 | def build_model(num_classes, model_name):
13 |     '''
14 |     Factory function to instantiate the model.
15 |     '''
16 |     model = getattr(sys.modules[__name__], model_name)
17 |     return model(num_classes)
18 | 
19 | class VGG13(object):
20 |     '''
21 |     A VGG13 like model (https://arxiv.org/pdf/1409.1556.pdf) tweaked for emotion data.
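    Input is a single-channel 64x64 face crop; the final Dense layer emits one un-normalized
    score (logit) per emotion class, and the training script applies the softmax itself (see train.py).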
22 |     '''
23 |     @property
24 |     def learning_rate(self):
25 |         return 0.05
26 | 
27 |     @property
28 |     def input_width(self):
29 |         return 64
30 | 
31 |     @property
32 |     def input_height(self):
33 |         return 64
34 | 
35 |     @property
36 |     def input_channels(self):
37 |         return 1
38 | 
39 |     @property
40 |     def model(self):
41 |         return self._model
42 | 
43 |     def __init__(self, num_classes):
44 |         self._model = self._create_model(num_classes)
45 | 
46 |     def _create_model(self, num_classes):
47 |         with ct.default_options(activation=ct.relu, init=ct.glorot_uniform()):
48 |             model = ct.layers.Sequential([
49 |                 ct.layers.For(range(2), lambda i: [
50 |                     ct.layers.Convolution((3,3), [64,128][i], pad=True, name='conv{}-1'.format(i+1)),
51 |                     ct.layers.Convolution((3,3), [64,128][i], pad=True, name='conv{}-2'.format(i+1)),
52 |                     ct.layers.MaxPooling((2,2), strides=(2,2), name='pool{}-1'.format(i+1)),
53 |                     ct.layers.Dropout(0.25, name='drop{}-1'.format(i+1))
54 |                 ]),
55 |                 ct.layers.For(range(2), lambda i: [
56 |                     ct.layers.Convolution((3,3), [256,256][i], pad=True, name='conv{}-1'.format(i+3)),
57 |                     ct.layers.Convolution((3,3), [256,256][i], pad=True, name='conv{}-2'.format(i+3)),
58 |                     ct.layers.Convolution((3,3), [256,256][i], pad=True, name='conv{}-3'.format(i+3)),
59 |                     ct.layers.MaxPooling((2,2), strides=(2,2), name='pool{}-1'.format(i+3)),
60 |                     ct.layers.Dropout(0.25, name='drop{}-1'.format(i+3))
61 |                 ]),
62 |                 ct.layers.For(range(2), lambda i: [
63 |                     ct.layers.Dense(1024, activation=None, name='fc{}'.format(i+5)),
64 |                     ct.layers.Activation(activation=ct.relu, name='relu{}'.format(i+5)),
65 |                     ct.layers.Dropout(0.5, name='drop{}'.format(i+5))
66 |                 ]),
67 |                 ct.layers.Dense(num_classes, activation=None, name='output')
68 |             ])
69 |         return model
70 | 
--------------------------------------------------------------------------------
/src/generate_training_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Microsoft. All rights reserved.
3 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
4 | #
5 | 
6 | import os
7 | import csv
8 | import argparse
9 | import numpy as np
10 | from itertools import islice
11 | from PIL import Image
12 | 
13 | # List of folders for training, validation and test.
14 | folder_names = {'Training'   : 'FER2013Train',
15 |                 'PublicTest' : 'FER2013Valid',
16 |                 'PrivateTest': 'FER2013Test'}
17 | 
18 | def str_to_image(image_blob):
19 |     ''' Convert a string blob to an image object. '''
20 |     image_string = image_blob.split(' ')
21 |     image_data = np.asarray(image_string, dtype=np.uint8).reshape(48,48)
22 |     return Image.fromarray(image_data)
23 | 
24 | def main(base_folder, fer_path, ferplus_path):
25 |     '''
26 |     Generate PNG image files from the combined fer2013.csv and fer2013new.csv file. The generated files
27 |     are stored in their corresponding folder for the trainer to use.
28 | 
29 |     Args:
30 |         base_folder(str): The base folder that contains 'FER2013Train', 'FER2013Valid' and 'FER2013Test'
31 |                           subfolder.
32 |         fer_path(str): The full path of fer2013.csv file.
33 |         ferplus_path(str): The full path of fer2013new.csv file.
34 | ''' 35 | 36 | print("Start generating ferplus images.") 37 | 38 | for key, value in folder_names.items(): 39 | folder_path = os.path.join(base_folder, value) 40 | if not os.path.exists(folder_path): 41 | os.makedirs(folder_path) 42 | 43 | ferplus_entries = [] 44 | with open(ferplus_path,'r') as csvfile: 45 | ferplus_rows = csv.reader(csvfile, delimiter=',') 46 | for row in islice(ferplus_rows, 1, None): 47 | ferplus_entries.append(row) 48 | 49 | index = 0 50 | with open(fer_path,'r') as csvfile: 51 | fer_rows = csv.reader(csvfile, delimiter=',') 52 | for row in islice(fer_rows, 1, None): 53 | ferplus_row = ferplus_entries[index] 54 | file_name = ferplus_row[1].strip() 55 | if len(file_name) > 0: 56 | image = str_to_image(row[1]) 57 | image_path = os.path.join(base_folder, folder_names[row[2]], file_name) 58 | image.save(image_path, compress_level=0) 59 | index += 1 60 | 61 | print("Done...") 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("-d", 66 | "--base_folder", 67 | type = str, 68 | help = "Base folder containing the training, validation and testing folder.", 69 | required = True) 70 | parser.add_argument("-fer", 71 | "--fer_path", 72 | type = str, 73 | help = "Path to the original fer2013.csv file.", 74 | required = True) 75 | 76 | parser.add_argument("-ferplus", 77 | "--ferplus_path", 78 | type = str, 79 | help = "Path to the new fer2013new.csv file.", 80 | required = True) 81 | 82 | args = parser.parse_args() 83 | main(args.base_folder, args.fer_path, args.ferplus_path) -------------------------------------------------------------------------------- /src/img_util.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Microsoft. All rights reserved. 3 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
4 | # 5 | 6 | import numpy as np 7 | import random as rnd 8 | from PIL import Image 9 | from scipy import ndimage 10 | from rect_util import Rect 11 | 12 | def compute_norm_mat(base_width, base_height): 13 | # normalization matrix used in image pre-processing 14 | x = np.arange(base_width) 15 | y = np.arange(base_height) 16 | X, Y = np.meshgrid(x, y) 17 | X = X.flatten() 18 | Y = Y.flatten() 19 | A = np.array([X*0+1, X, Y]).T 20 | A_pinv = np.linalg.pinv(A) 21 | return A, A_pinv 22 | 23 | def preproc_img(img, A, A_pinv): 24 | # compute image histogram 25 | img_flat = img.flatten() 26 | img_hist = np.bincount(img_flat, minlength = 256) 27 | 28 | # cumulative distribution function 29 | cdf = img_hist.cumsum() 30 | cdf = cdf * (2.0 / cdf[-1]) - 1.0 # normalize 31 | 32 | # histogram equalization 33 | img_eq = cdf[img_flat] 34 | 35 | diff = img_eq - np.dot(A, np.dot(A_pinv, img_eq)) 36 | 37 | # after plane fitting, the mean of diff is already 0 38 | std = np.sqrt(np.dot(diff,diff)/diff.size) 39 | if std > 1e-6: 40 | diff = diff/std 41 | return diff.reshape(img.shape) 42 | 43 | def distort_img(img, roi, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True): 44 | shift_y = out_height*max_shift*rnd.uniform(-1.0,1.0) 45 | shift_x = out_width*max_shift*rnd.uniform(-1.0,1.0) 46 | 47 | # rotation angle 48 | angle = max_angle*rnd.uniform(-1.0,1.0) 49 | 50 | #skew 51 | sk_y = max_skew*rnd.uniform(-1.0, 1.0) 52 | sk_x = max_skew*rnd.uniform(-1.0, 1.0) 53 | 54 | # scale 55 | scale_y = rnd.uniform(1.0, max_scale) 56 | if rnd.choice([True, False]): 57 | scale_y = 1.0/scale_y 58 | scale_x = rnd.uniform(1.0, max_scale) 59 | if rnd.choice([True, False]): 60 | scale_x = 1.0/scale_x 61 | T_im = crop_img(img, roi, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y) 62 | if flip and rnd.choice([True, False]): 63 | T_im = np.fliplr(T_im) 64 | return T_im 65 | 66 | def crop_img(img, roi, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y): 67 | # current face center 68 | ctr_in = np.array((roi.center().y, roi.center().x)) 69 | ctr_out = np.array((crop_height/2.0+shift_y, crop_width/2.0+shift_x)) 70 | out_shape = (crop_height, crop_width) 71 | s_y = scale_y*(roi.height()-1)*1.0/(crop_height-1) 72 | s_x = scale_x*(roi.width()-1)*1.0/(crop_width-1) 73 | 74 | # rotation and scale 75 | ang = angle*np.pi/180.0 76 | transform = np.array([[np.cos(ang), -np.sin(ang)], [np.sin(ang), np.cos(ang)]]) 77 | transform = transform.dot(np.array([[1.0, skew_y], [0.0, 1.0]])) 78 | transform = transform.dot(np.array([[1.0, 0.0], [skew_x, 1.0]])) 79 | transform = transform.dot(np.diag([s_y, s_x])) 80 | offset = ctr_in-ctr_out.dot(transform) 81 | 82 | # each point p in the output image is transformed to pT+s, where T is the matrix and s is the offset 83 | T_im = ndimage.interpolation.affine_transform(input = img, 84 | matrix = np.transpose(transform), 85 | offset = offset, 86 | output_shape = out_shape, 87 | order = 1, # bilinear interpolation 88 | mode = 'reflect', 89 | prefilter = False) 90 | return T_im 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FER+ 2 | The FER+ annotations provide a set of new labels for the standard Emotion FER dataset. In FER+, each image has been labeled by 10 crowd-sourced taggers, which provide better quality ground truth for still image emotion than the original FER labels. 
Having 10 taggers for each image enables researchers to estimate an emotion probability distribution per face. This allows constructing algorithms that produce statistical distributions or multi-label outputs instead of the conventional single-label output, as described in: https://arxiv.org/abs/1608.01041
3 | 
4 | Here are some examples of the FER vs FER+ labels extracted from the abovementioned paper (FER top, FER+ bottom):
5 | 
6 | ![FER vs FER+ example](https://raw.githubusercontent.com/Microsoft/FERPlus/master/FER+vsFER.png)
7 | 
8 | The new label file is named **_fer2013new.csv_** and contains the same number of rows as the original **_fer2013.csv_** label file, in the same order, so that you can infer which emotion tag belongs to which image. Since we can't host the actual image content, please find the original FER data set here: https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge/data
9 | 
10 | The format of the CSV file is as follows: usage, neutral, happiness, surprise, sadness, anger, disgust, fear, contempt, unknown, NF. The "usage" column is the same as in the original FER labels and differentiates between the training, public test and private test sets. The other columns are the vote counts for each emotion, with the addition of unknown and NF (Not a Face).
11 | 
12 | ## Training
13 | We also provide training code that implements all the training modes (majority, probability, cross entropy and multi-label) described in https://arxiv.org/abs/1608.01041. The training code uses the MS Cognitive Toolkit (formerly CNTK), available at: https://github.com/Microsoft/CNTK.
14 | 
15 | After installing Cognitive Toolkit and downloading the dataset (we will discuss the dataset layout next), you can simply run the following to start the training:
16 | 
17 | #### For majority voting mode
18 | ```
19 | python train.py -d <dataset base folder> -m majority
20 | ```
21 | 
22 | #### For probability mode
23 | ```
24 | python train.py -d <dataset base folder> -m probability
25 | ```
26 | 
27 | #### For cross entropy mode
28 | ```
29 | python train.py -d <dataset base folder> -m crossentropy
30 | ```
31 | 
32 | #### For multi-target mode
33 | ```
34 | python train.py -d <dataset base folder> -m multi_target
35 | ```
36 | 
37 | ## FER+ layout for Training
38 | There is a folder named data that has the following layout:
39 | 
40 | ```
41 | /data
42 |   /FER2013Test
43 |     label.csv
44 |   /FER2013Train
45 |     label.csv
46 |   /FER2013Valid
47 |     label.csv
48 | ```
49 | *label.csv* in each folder contains the actual label for each image; the image name has the following format: ferXXXXXXXX.png, where XXXXXXXX is the row index of the original FER csv file. So here are the names of the first few images:
50 | 
51 | ```
52 | fer0000000.png
53 | fer0000001.png
54 | fer0000002.png
55 | fer0000003.png
56 | ```
57 | The folders don't contain the actual images; you will need to download them from https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge/data, then extract the images from the FER csv file in such a way that all images corresponding to "Training" go to the FER2013Train folder, all images corresponding to "PublicTest" go to the FER2013Valid folder and all images corresponding to "PrivateTest" go to the FER2013Test folder. Or you can use the `generate_training_data.py` script to do all of the above for you, as described in the next section.
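For illustration only, here is a minimal sketch of that split (the `data` output path and the zero-padded file names are assumptions chosen to match the layout above; the `generate_training_data.py` script described in the next section does this properly, taking the file names from **_fer2013new.csv_** and skipping rows that have no FER+ label):

```
import csv, os
from itertools import islice

import numpy as np
from PIL import Image

folder_names = {'Training': 'FER2013Train', 'PublicTest': 'FER2013Valid', 'PrivateTest': 'FER2013Test'}

with open('fer2013.csv') as csvfile:
    # fer2013.csv columns are: emotion, pixels, Usage; islice skips the header row.
    for index, row in enumerate(islice(csv.reader(csvfile), 1, None)):
        pixels, usage = row[1], row[2]
        image = Image.fromarray(np.asarray(pixels.split(' '), dtype=np.uint8).reshape(48, 48))
        folder = os.path.join('data', folder_names[usage])
        os.makedirs(folder, exist_ok=True)
        image.save(os.path.join(folder, 'fer{:07d}.png'.format(index)))
```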
58 | 
59 | ### Training data
60 | We provide a simple script `generate_training_data.py` in python that takes **_fer2013.csv_** and **_fer2013new.csv_** as inputs, merges both CSV files and exports all the images as PNG files for the trainer to process.
61 | 
62 | ```
63 | python generate_training_data.py -d <dataset base folder> -fer <fer2013.csv path> -ferplus <fer2013new.csv path>
64 | ```
65 | 
66 | # Citation
67 | If you use the new FER+ label or the sample code or part of it in your research, please cite the following:
68 | 
69 | **@inproceedings{BarsoumICMI2016,
70 |     title={Training Deep Networks for Facial Expression Recognition with Crowd-Sourced Label Distribution},
71 |     author={Barsoum, Emad and Zhang, Cha and Canton Ferrer, Cristian and Zhang, Zhengyou},
72 |     booktitle={ACM International Conference on Multimodal Interaction (ICMI)},
73 |     year={2016}
74 | }**
75 | 
--------------------------------------------------------------------------------
/src/rect_util.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Microsoft. All rights reserved.
3 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
4 | #
5 | 
6 | import math
7 | 
8 | class Point(object):
9 |     def __init__(self, x=0.0, y=0.0):
10 |         self.x = x
11 |         self.y = y
12 | 
13 |     def __add__(self, p):
14 |         """Point(x1+x2, y1+y2)"""
15 |         return Point(self.x+p.x, self.y+p.y)
16 | 
17 |     def __sub__(self, p):
18 |         """Point(x1-x2, y1-y2)"""
19 |         return Point(self.x-p.x, self.y-p.y)
20 | 
21 |     def __mul__( self, scalar ):
22 |         """Point(x1*scalar, y1*scalar)"""
23 |         return Point(self.x*scalar, self.y*scalar)
24 | 
25 |     def __div__(self, scalar):
26 |         """Point(x1/scalar, y1/scalar)"""
27 |         return Point(self.x/scalar, self.y/scalar)
28 | 
29 |     def __str__(self):
30 |         return "(%s, %s)" % (self.x, self.y)
31 | 
32 |     def length(self):
33 |         return math.sqrt(self.x**2 + self.y**2)
34 | 
35 |     def distance_to(self, p):
36 |         """Calculate the distance between two points."""
37 |         return (self - p).length()
38 | 
39 |     def as_tuple(self):
40 |         """(x, y)"""
41 |         return (self.x, self.y)
42 | 
43 |     def clone(self):
44 |         """Return a full copy of this point."""
45 |         return Point(self.x, self.y)
46 | 
47 |     def integerize(self):
48 |         """Convert co-ordinate values to integers."""
49 |         self.x = int(self.x+0.5)
50 |         self.y = int(self.y+0.5)
51 | 
52 |     def floatize(self):
53 |         """Convert co-ordinate values to floats."""
54 |         self.x = float(self.x)
55 |         self.y = float(self.y)
56 | 
57 |     def reset(self, x, y):
58 |         """Reset x & y coordinates."""
59 |         self.x = x
60 |         self.y = y
61 | 
62 |     def shift(self, pt):
63 |         """Move to new (x+pt.x,y+pt.y)."""
64 |         self.x = self.x + pt.x
65 |         self.y = self.y + pt.y
66 | 
67 |     def shift_xy(self, dx, dy):
68 |         """Move to new (x+dx,y+dy)."""
69 |         self.x = self.x + dx
70 |         self.y = self.y + dy
71 | 
72 |     def rotate(self, rad):
73 |         """Rotate counter-clockwise by rad radians.
74 |         Positive y goes *up,* as in traditional mathematics.
75 |         The new position is returned as a new Point.
76 |         """
77 |         s, c = [f(rad) for f in (math.sin, math.cos)]
78 |         x, y = (c*self.x - s*self.y, s*self.x + c*self.y)
79 |         return Point(x,y)
80 | 
81 |     def rotate_about(self, p, theta):
82 |         """Rotate counter-clockwise around a point, by theta radians.
83 |         Positive y goes *up,* as in traditional mathematics.
84 |         The new position is returned as a new Point.
85 |         """
86 |         result = self.clone()
87 |         result.shift_xy(-p.x, -p.y)
88 |         result = result.rotate(theta)
89 |         result.shift_xy(p.x, p.y)
90 |         return result
91 | 
92 | class Rect(object):
93 |     """The rectangle stores left, top, right, and bottom values.
94 |     Coordinates are based on screen coordinates.
95 |         origin                               top
96 |            +-----> x increases                |
97 |            |                          left  -+-  right
98 |            v                                  |
99 |         y increases                        bottom
100 |     """
101 | 
102 |     def __init__(self, box):
103 |         """Initialize a rectangle from two points."""
104 |         self.left = box[0]
105 |         self.top = box[1]
106 |         self.right = box[2]
107 |         self.bottom = box[3]
108 | 
109 |     def as_tuple(self):
110 |         """(left, top, right, bottom)"""
111 |         return (self.left, self.top, self.right, self.bottom)
112 | 
113 |     def width(self):
114 |         """Width"""
115 |         return (self.right - self.left)
116 | 
117 |     def height(self):
118 |         """Height"""
119 |         return (self.bottom - self.top)
120 | 
121 |     def contains(self, pt):
122 |         """Return true if a point is inside the rectangle."""
123 |         x,y = pt.as_tuple()
124 |         return (self.left <= x <= self.right and
125 |                 self.top <= y <= self.bottom)
126 | 
127 |     def shift(self, pt):
128 |         """Shift by pt.x and pt.y."""
129 |         self.left = self.left + pt.x
130 |         self.right = self.right + pt.x
131 |         self.top = self.top + pt.y
132 |         self.bottom = self.bottom + pt.y
133 | 
134 |     def shift_xy(self, dx, dy):
135 |         """Shift by dx and dy."""
136 |         self.left = self.left + dx
137 |         self.right = self.right + dx
138 |         self.top = self.top + dy
139 |         self.bottom = self.bottom + dy
140 | 
141 |     def equal(self, other):
142 |         """Return true if a rectangle is identical to this rectangle."""
143 |         return (self.right == other.right and self.left == other.left and
144 |                 self.top == other.top and self.bottom == other.bottom)
145 | 
146 |     def overlaps(self, other):
147 |         """Return true if a rectangle overlaps this rectangle."""
148 |         return (self.right > other.left and self.left < other.right and
149 |                 self.top < other.bottom and self.bottom > other.top)
150 | 
151 |     def intersect(self, other):
152 |         """Return the intersect rectangle.
153 |         Note we don't check here whether the intersection is valid
154 |         If needed, call overlaps() first to check
155 |         """
156 |         return Rect((max(self.left, other.left),
157 |                      max(self.top, other.top),
158 |                      min(self.right, other.right),
159 |                      min(self.bottom, other.bottom)))
160 | 
161 |     def clamp(self, xmin, ymin, xmax, ymax):
162 |         """Return clamped rectangle based on the other rectangle.
163 |         Note we don't check here whether the output is valid
164 |         If needed, call overlaps() first to check
165 |         """
166 |         self.left = max(self.left, xmin)
167 |         self.right = min(self.right, xmax)
168 |         self.top = max(self.top, ymin)
169 |         self.bottom = min(self.bottom, ymax)
170 | 
171 |     def top_left(self):
172 |         """Return the top-left corner as a Point."""
173 |         return Point(self.left, self.top)
174 | 
175 |     def bottom_right(self):
176 |         """Return the bottom-right corner as a Point."""
177 |         return Point(self.right, self.bottom)
178 | 
179 |     def center(self):
180 |         """Return the center as a Point."""
181 |         return Point((self.left+self.right)/2.0, (self.top+self.bottom)/2.0)
182 | 
183 |     def mult(self, xmul, ymul):
184 |         """Return a rectangle with all coordinates multiplied by a number."""
185 |         return Rect((self.left*xmul, self.top*ymul, self.right*xmul, self.bottom*ymul))
186 | 
187 |     def scale(self, scale):
188 |         """Return a scaled rectangle with identical center."""
189 |         xctr = (self.left + self.right)/2.0
190 |         yctr = (self.top + self.bottom)/2.0
191 |         width = self.width()*scale
192 |         height = self.height()*scale
193 |         xstart = xctr-width/2.0
194 |         ystart = yctr-height/2.0
195 |         return Rect((xstart, ystart, xstart+width, ystart+height))
196 | 
197 |     def cocenter(self, new_width, new_height):
198 |         """Return a new rectangle with identical center."""
199 |         xctr = (self.left + self.right)/2.0
200 |         yctr = (self.top + self.bottom)/2.0
201 |         xstart = xctr - new_width/2.0
202 |         ystart = yctr - new_height/2.0
203 |         return Rect((xstart, ystart, xstart+new_width, ystart+new_height))
204 | 
205 |     def integerize(self):
206 |         """Convert co-ordinate values to integers."""
207 |         self.left = int(self.left+0.5)
208 |         self.right = int(self.right+0.5)
209 |         self.top = int(self.top+0.5)
210 |         self.bottom = int(self.bottom+0.5)
211 | 
212 |     def floatize(self):
213 |         """Convert co-ordinate values to floats."""
214 |         self.left = float(self.left)
215 |         self.right = float(self.right)
216 |         self.top = float(self.top)
217 |         self.bottom = float(self.bottom)
218 | 
219 |     def __str__( self ):
220 |         return "<Rect (%s,%s)-(%s,%s)>" % (self.left, self.top,
221 |                                            self.right, self.bottom)
222 | 
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Microsoft. All rights reserved.
3 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
4 | #
5 | 
6 | import sys
7 | import time
8 | import os
9 | import math
10 | import csv
11 | import argparse
12 | import numpy as np
13 | import logging
14 | 
15 | from models import *
16 | from ferplus import *
17 | 
18 | import cntk as ct
19 | 
20 | emotion_table = {'neutral'  : 0,
21 |                  'happiness': 1,
22 |                  'surprise' : 2,
23 |                  'sadness'  : 3,
24 |                  'anger'    : 4,
25 |                  'disgust'  : 5,
26 |                  'fear'     : 6,
27 |                  'contempt' : 7}
28 | 
29 | # List of folders for training, validation and test.
30 | train_folders = ['FER2013Train']
31 | valid_folders = ['FER2013Valid']
32 | test_folders  = ['FER2013Test']
33 | 
34 | def cost_func(training_mode, prediction, target):
35 |     '''
36 |     We use cross entropy in most modes, except for the multi-label mode, which requires treating
37 |     multiple labels exactly the same.
38 |     '''
39 |     train_loss = None
40 |     if training_mode == 'majority' or training_mode == 'probability' or training_mode == 'crossentropy':
41 |         # Cross Entropy.
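        # With a target distribution t (the FER+ votes turned into probabilities) and softmax
        # output p, the expression below computes -sum_k t[k]*log(p[k]), i.e. cross entropy
        # against a soft label rather than a single hard class index.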
42 | train_loss = ct.negate(ct.reduce_sum(ct.element_times(target, ct.log(prediction)), axis=-1)) 43 | elif training_mode == 'multi_target': 44 | train_loss = ct.negate(ct.log(ct.reduce_max(ct.element_times(target, prediction), axis=-1))) 45 | 46 | return train_loss 47 | 48 | def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs = 100): 49 | 50 | # create needed folders. 51 | output_model_path = os.path.join(base_folder, R'models') 52 | output_model_folder = os.path.join(output_model_path, model_name + '_' + training_mode) 53 | if not os.path.exists(output_model_folder): 54 | os.makedirs(output_model_folder) 55 | 56 | # creating logging file 57 | logging.basicConfig(filename = os.path.join(output_model_folder, "train.log"), filemode = 'w', level = logging.INFO) 58 | logging.getLogger().addHandler(logging.StreamHandler()) 59 | 60 | logging.info("Starting with training mode {} using {} model and max epochs {}.".format(training_mode, model_name, max_epochs)) 61 | 62 | # create the model 63 | num_classes = len(emotion_table) 64 | model = build_model(num_classes, model_name) 65 | 66 | # set the input variables. 67 | input_var = ct.input((1, model.input_height, model.input_width), np.float32) 68 | label_var = ct.input((num_classes), np.float32) 69 | 70 | # read FER+ dataset. 71 | logging.info("Loading data...") 72 | train_params = FERPlusParameters(num_classes, model.input_height, model.input_width, training_mode, False) 73 | test_and_val_params = FERPlusParameters(num_classes, model.input_height, model.input_width, "majority", True) 74 | 75 | train_data_reader = FERPlusReader.create(base_folder, train_folders, "label.csv", train_params) 76 | val_data_reader = FERPlusReader.create(base_folder, valid_folders, "label.csv", test_and_val_params) 77 | test_data_reader = FERPlusReader.create(base_folder, test_folders, "label.csv", test_and_val_params) 78 | 79 | # print summary of the data. 80 | display_summary(train_data_reader, val_data_reader, test_data_reader) 81 | 82 | # get the probalistic output of the model. 
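    # ('pred' below is the softmax of the raw network output 'z'; the custom losses in cost_func
    #  consume the softmax, while classification_error works directly on 'z', since it only
    #  compares arg-max indices.)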
83 | z = model.model(input_var) 84 | pred = ct.softmax(z) 85 | 86 | epoch_size = train_data_reader.size() 87 | minibatch_size = 32 88 | 89 | # Training config 90 | lr_per_minibatch = [model.learning_rate]*20 + [model.learning_rate / 2.0]*20 + [model.learning_rate / 10.0] 91 | mm_time_constant = -minibatch_size/np.log(0.9) 92 | lr_schedule = ct.learning_rate_schedule(lr_per_minibatch, unit=ct.UnitType.minibatch, epoch_size=epoch_size) 93 | mm_schedule = ct.momentum_as_time_constant_schedule(mm_time_constant) 94 | 95 | # loss and error cost 96 | train_loss = cost_func(training_mode, pred, label_var) 97 | pe = ct.classification_error(z, label_var) 98 | 99 | # construct the trainer 100 | learner = ct.momentum_sgd(z.parameters, lr_schedule, mm_schedule) 101 | trainer = ct.Trainer(z, (train_loss, pe), learner) 102 | 103 | # Get minibatches of images to train with and perform model training 104 | max_val_accuracy = 0.0 105 | final_test_accuracy = 0.0 106 | best_test_accuracy = 0.0 107 | 108 | logging.info("Start training...") 109 | epoch = 0 110 | best_epoch = 0 111 | while epoch < max_epochs: 112 | train_data_reader.reset() 113 | val_data_reader.reset() 114 | test_data_reader.reset() 115 | 116 | # Training 117 | start_time = time.time() 118 | training_loss = 0 119 | training_accuracy = 0 120 | while train_data_reader.has_more(): 121 | images, labels, current_batch_size = train_data_reader.next_minibatch(minibatch_size) 122 | 123 | # Specify the mapping of input variables in the model to actual minibatch data to be trained with 124 | trainer.train_minibatch({input_var : images, label_var : labels}) 125 | 126 | # keep track of statistics. 127 | training_loss += trainer.previous_minibatch_loss_average * current_batch_size 128 | training_accuracy += trainer.previous_minibatch_evaluation_average * current_batch_size 129 | 130 | training_accuracy /= train_data_reader.size() 131 | training_accuracy = 1.0 - training_accuracy 132 | 133 | # Validation 134 | val_accuracy = 0 135 | while val_data_reader.has_more(): 136 | images, labels, current_batch_size = val_data_reader.next_minibatch(minibatch_size) 137 | val_accuracy += trainer.test_minibatch({input_var : images, label_var : labels}) * current_batch_size 138 | 139 | val_accuracy /= val_data_reader.size() 140 | val_accuracy = 1.0 - val_accuracy 141 | 142 | # if validation accuracy goes higher, we compute test accuracy 143 | test_run = False 144 | if val_accuracy > max_val_accuracy: 145 | best_epoch = epoch 146 | max_val_accuracy = val_accuracy 147 | 148 | trainer.save_checkpoint(os.path.join(output_model_folder, "model_{}".format(best_epoch))) 149 | 150 | test_run = True 151 | test_accuracy = 0 152 | while test_data_reader.has_more(): 153 | images, labels, current_batch_size = test_data_reader.next_minibatch(minibatch_size) 154 | test_accuracy += trainer.test_minibatch({input_var : images, label_var : labels}) * current_batch_size 155 | 156 | test_accuracy /= test_data_reader.size() 157 | test_accuracy = 1.0 - test_accuracy 158 | final_test_accuracy = test_accuracy 159 | if final_test_accuracy > best_test_accuracy: 160 | best_test_accuracy = final_test_accuracy 161 | 162 | logging.info("Epoch {}: took {:.3f}s".format(epoch, time.time() - start_time)) 163 | logging.info(" training loss:\t{:e}".format(training_loss)) 164 | logging.info(" training accuracy:\t\t{:.2f} %".format(training_accuracy * 100)) 165 | logging.info(" validation accuracy:\t\t{:.2f} %".format(val_accuracy * 100)) 166 | if test_run: 167 | logging.info(" test accuracy:\t\t{:.2f} 
%".format(test_accuracy * 100)) 168 | 169 | epoch += 1 170 | 171 | logging.info("") 172 | logging.info("Best validation accuracy:\t\t{:.2f} %, epoch {}".format(max_val_accuracy * 100, best_epoch)) 173 | logging.info("Test accuracy corresponding to best validation:\t\t{:.2f} %".format(final_test_accuracy * 100)) 174 | logging.info("Best test accuracy:\t\t{:.2f} %".format(best_test_accuracy * 100)) 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument("-d", 179 | "--base_folder", 180 | type = str, 181 | help = "Base folder containing the training, validation and testing data.", 182 | required = True) 183 | parser.add_argument("-m", 184 | "--training_mode", 185 | type = str, 186 | default='majority', 187 | help = "Specify the training mode: majority, probability, crossentropy or multi_target.") 188 | 189 | args = parser.parse_args() 190 | main(args.base_folder, args.training_mode) -------------------------------------------------------------------------------- /src/ferplus.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Microsoft. All rights reserved. 3 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 4 | # 5 | 6 | import sys 7 | import os 8 | import csv 9 | import numpy as np 10 | import logging 11 | import random as rnd 12 | from collections import namedtuple 13 | 14 | from PIL import Image 15 | from rect_util import Rect 16 | import img_util as imgu 17 | import matplotlib.pyplot as plt 18 | 19 | def display_summary(train_data_reader, val_data_reader, test_data_reader): 20 | ''' 21 | Summarize the data in a tabular format. 22 | ''' 23 | emotion_count = train_data_reader.emotion_count 24 | emotin_header = ['neutral', 'happiness', 'surprise', 'sadness', 'anger', 'disgust', 'fear', 'contempt'] 25 | 26 | logging.info("{0}\t{1}\t{2}\t{3}".format("".ljust(10), "Train", "Val", "Test")) 27 | for index in range(emotion_count): 28 | logging.info("{0}\t{1}\t{2}\t{3}".format(emotin_header[index].ljust(10), 29 | train_data_reader.per_emotion_count[index], 30 | val_data_reader.per_emotion_count[index], 31 | test_data_reader.per_emotion_count[index])) 32 | 33 | class FERPlusParameters(): 34 | ''' 35 | FER+ reader parameters 36 | ''' 37 | def __init__(self, target_size, width, height, training_mode = "majority", determinisitc = False, shuffle = True): 38 | self.target_size = target_size 39 | self.width = width 40 | self.height = height 41 | self.training_mode = training_mode 42 | self.determinisitc = determinisitc 43 | self.shuffle = shuffle 44 | 45 | class FERPlusReader(object): 46 | ''' 47 | A custom reader for FER+ dataset that support multiple modes as described in: 48 | https://arxiv.org/abs/1608.01041 49 | ''' 50 | @classmethod 51 | def create(cls, base_folder, sub_folders, label_file_name, parameters): 52 | ''' 53 | Factory function that create an instance of FERPlusReader and load the data form disk. 54 | ''' 55 | reader = cls(base_folder, sub_folders, label_file_name, parameters) 56 | reader.load_folders(parameters.training_mode) 57 | return reader 58 | 59 | def __init__(self, base_folder, sub_folders, label_file_name, parameters): 60 | ''' 61 | Each sub_folder contains the image files and a csv file for the corresponding label. The read iterate through 62 | all the sub_folders and aggregate all the images and their corresponding labels. 
63 | ''' 64 | self.base_folder = base_folder 65 | self.sub_folders = sub_folders 66 | self.label_file_name = label_file_name 67 | self.emotion_count = parameters.target_size 68 | self.width = parameters.width 69 | self.height = parameters.height 70 | self.shuffle = parameters.shuffle 71 | self.training_mode = parameters.training_mode 72 | 73 | # data augmentation parameters.determinisitc 74 | if parameters.determinisitc: 75 | self.max_shift = 0.0 76 | self.max_scale = 1.0 77 | self.max_angle = 0.0 78 | self.max_skew = 0.0 79 | self.do_flip = False 80 | else: 81 | self.max_shift = 0.08 82 | self.max_scale = 1.05 83 | self.max_angle = 20.0 84 | self.max_skew = 0.05 85 | self.do_flip = True 86 | 87 | self.data = None 88 | self.per_emotion_count = None 89 | self.batch_start = 0 90 | self.indices = 0 91 | 92 | self.A, self.A_pinv = imgu.compute_norm_mat(self.width, self.height) 93 | 94 | def has_more(self): 95 | ''' 96 | Return True if there is more min-batches. 97 | ''' 98 | if self.batch_start < len(self.data): 99 | return True 100 | return False 101 | 102 | def reset(self): 103 | ''' 104 | Start from beginning for the new epoch. 105 | ''' 106 | self.batch_start = 0 107 | 108 | def size(self): 109 | ''' 110 | Return the number of images read by this reader. 111 | ''' 112 | return len(self.data) 113 | 114 | def next_minibatch(self, batch_size): 115 | ''' 116 | Return the next mini-batch, we do data augmentation during constructing each mini-batch. 117 | ''' 118 | data_size = len(self.data) 119 | batch_end = min(self.batch_start + batch_size, data_size) 120 | current_batch_size = batch_end - self.batch_start 121 | if current_batch_size < 0: 122 | raise Exception('Reach the end of the training data.') 123 | 124 | inputs = np.empty(shape=(current_batch_size, 1, self.width, self.height), dtype=np.float32) 125 | targets = np.empty(shape=(current_batch_size, self.emotion_count), dtype=np.float32) 126 | for idx in range(self.batch_start, batch_end): 127 | index = self.indices[idx] 128 | distorted_image = imgu.distort_img(self.data[index][1], 129 | self.data[index][3], 130 | self.width, 131 | self.height, 132 | self.max_shift, 133 | self.max_scale, 134 | self.max_angle, 135 | self.max_skew, 136 | self.do_flip) 137 | final_image = imgu.preproc_img(distorted_image, A=self.A, A_pinv=self.A_pinv) 138 | 139 | inputs[idx-self.batch_start] = final_image 140 | targets[idx-self.batch_start,:] = self._process_target(self.data[index][2]) 141 | 142 | self.batch_start += current_batch_size 143 | return inputs, targets, current_batch_size 144 | 145 | def load_folders(self, mode): 146 | ''' 147 | Load the actual images from disk. While loading, we normalize the input data. 
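        Each row of label.csv is expected to hold the image file name, the face rectangle as
        "(left, top, right, bottom)", and then the ten FER+ vote counts (eight emotions plus
        unknown and not-a-face).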
148 | ''' 149 | self.reset() 150 | self.data = [] 151 | self.per_emotion_count = np.zeros(self.emotion_count, dtype=np.int) 152 | 153 | for folder_name in self.sub_folders: 154 | logging.info("Loading %s" % (os.path.join(self.base_folder, folder_name))) 155 | folder_path = os.path.join(self.base_folder, folder_name) 156 | in_label_path = os.path.join(folder_path, self.label_file_name) 157 | with open(in_label_path) as csvfile: 158 | emotion_label = csv.reader(csvfile) 159 | for row in emotion_label: 160 | # load the image 161 | image_path = os.path.join(folder_path, row[0]) 162 | image_data = Image.open(image_path) 163 | image_data.load() 164 | 165 | # face rectangle 166 | box = list(map(int, row[1][1:-1].split(','))) 167 | face_rc = Rect(box) 168 | 169 | emotion_raw = list(map(float, row[2:len(row)])) 170 | emotion = self._process_data(emotion_raw, mode) 171 | idx = np.argmax(emotion) 172 | if idx < self.emotion_count: # not unknown or non-face 173 | emotion = emotion[:-2] 174 | emotion = [float(i)/sum(emotion) for i in emotion] 175 | self.data.append((image_path, image_data, emotion, face_rc)) 176 | self.per_emotion_count[idx] += 1 177 | 178 | self.indices = np.arange(len(self.data)) 179 | if self.shuffle: 180 | np.random.shuffle(self.indices) 181 | 182 | def _process_target(self, target): 183 | ''' 184 | Based on https://arxiv.org/abs/1608.01041 the target depend on the training mode. 185 | 186 | Majority or crossentropy: return the probability distribution generated by "_process_data" 187 | Probability: pick one emotion based on the probability distribtuion. 188 | Multi-target: 189 | ''' 190 | if self.training_mode == 'majority' or self.training_mode == 'crossentropy': 191 | return target 192 | elif self.training_mode == 'probability': 193 | idx = np.random.choice(len(target), p=target) 194 | new_target = np.zeros_like(target) 195 | new_target[idx] = 1.0 196 | return new_target 197 | elif self.training_mode == 'multi_target': 198 | new_target = np.array(target) 199 | new_target[new_target>0] = 1.0 200 | epsilon = 0.001 # add small epsilon in order to avoid ill-conditioned computation 201 | return (1-epsilon)*new_target + epsilon*np.ones_like(target) 202 | 203 | def _process_data(self, emotion_raw, mode): 204 | ''' 205 | Based on https://arxiv.org/abs/1608.01041, we process the data differently depend on the training mode: 206 | 207 | Majority: return the emotion that has the majority vote, or unknown if the count is too little. 208 | Probability or Crossentropty: convert the count into probability distribution.abs 209 | Multi-target: treat all emotion with 30% or more votes as equal. 
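        In every mode the returned list is the renormalized vote vector over all the entries,
        including the trailing unknown and not-a-face slots; load_folders later drops those two
        slots and renormalizes again whenever the winning entry is a real emotion.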
210 | ''' 211 | size = len(emotion_raw) 212 | emotion_unknown = [0.0] * size 213 | emotion_unknown[-2] = 1.0 214 | 215 | # remove emotions with a single vote (outlier removal) 216 | for i in range(size): 217 | if emotion_raw[i] < 1.0 + sys.float_info.epsilon: 218 | emotion_raw[i] = 0.0 219 | 220 | sum_list = sum(emotion_raw) 221 | emotion = [0.0] * size 222 | 223 | if mode == 'majority': 224 | # find the peak value of the emo_raw list 225 | maxval = max(emotion_raw) 226 | if maxval > 0.5*sum_list: 227 | emotion[np.argmax(emotion_raw)] = maxval 228 | else: 229 | emotion = emotion_unknown # force setting as unknown 230 | elif (mode == 'probability') or (mode == 'crossentropy'): 231 | sum_part = 0 232 | count = 0 233 | valid_emotion = True 234 | while sum_part < 0.75*sum_list and count < 3 and valid_emotion: 235 | maxval = max(emotion_raw) 236 | for i in range(size): 237 | if emotion_raw[i] == maxval: 238 | emotion[i] = maxval 239 | emotion_raw[i] = 0 240 | sum_part += emotion[i] 241 | count += 1 242 | if i >= 8: # unknown or non-face share same number of max votes 243 | valid_emotion = False 244 | if sum(emotion) > maxval: # there have been other emotions ahead of unknown or non-face 245 | emotion[i] = 0 246 | count -= 1 247 | break 248 | if sum(emotion) <= 0.5*sum_list or count > 3: # less than 50% of the votes are integrated, or there are too many emotions, we'd better discard this example 249 | emotion = emotion_unknown # force setting as unknown 250 | elif mode == 'multi_target': 251 | threshold = 0.3 252 | for i in range(size): 253 | if emotion_raw[i] >= threshold*sum_list: 254 | emotion[i] = emotion_raw[i] 255 | if sum(emotion) <= 0.5 * sum_list: # less than 50% of the votes are integrated, we discard this example 256 | emotion = emotion_unknown # set as unknown 257 | 258 | return [float(i)/sum(emotion) for i in emotion] --------------------------------------------------------------------------------