├── .gitignore
├── README.md
├── app.py
├── dataset_helper
│   ├── IO_helper.py
│   ├── TFRecords_helper.py
│   ├── __init__.py
│   ├── convert_csv2gray.py
│   ├── convert_fer2013.py
│   └── count_result.py
├── images
│   ├── origin.png
│   ├── result.png
│   └── test_result.png
├── inference
│   ├── __init__.py
│   └── main_inference.py
├── labels.py
├── requirements
├── res
│   └── haarcascade_frontalface_default.xml
├── result.xlsx
├── test
│   ├── __init__.py
│   └── main_test.py
└── train
    ├── __init__.py
    └── main_train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
.gitignore
.idea/FaceEmotionalRecognition.iml
.idea/inspectionProfiles/
.idea/misc.xml
.idea/modules.xml
.idea/workspace.xml
datasets/
models/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FaceEmotionalRecognition

A facial emotion recognition model based on deep learning.

## Overview

This system is a facial expression recognition model implemented with `tensorflow`. It recognizes the following expressions (trained on the fer2013 dataset):

```
angry
disgust
fear
happy
sad
surprise
neutral
```

## Directory structure

+ dataset_helper

  Dataset processing scripts, including the conversion of fer2013 to tfrecords

+ images

  Images used by this readme

+ inference

  Network architecture

+ models

  Trained models

+ res

  The OpenCV face detector

+ test

  Model testing

+ train

  Model training

+ app.py

  System entry point

+ labels.py

  Label definitions

## Main dependencies

+ python3.5
+ tensorflow or tensorflow-gpu
+ opencv-python

## Launch commands

Run app.py (the pre-trained model must be downloaded first). Note that every mode goes through the `run` subcommand of app.py:

```
# start real-time video recognition
python app.py run -v
# recognize a single image
python app.py run -p <path-to-image>
# train
python app.py run -train
# test
python app.py run -test
```

To train directly, run train/main_train.py (the training set must be downloaded first).

To test directly, run test/main_test.py (the test set must be downloaded first).

## Obtaining the model and dataset

Baidu Netdisk link: https://pan.baidu.com/s/1FVna89oPvi4PiY-voMwEEA (extraction code: knf4)

After downloading, extract the archive directly in the project root directory.

## Test results

See `result.xlsx`.

![](https://github.com/XingToMax/FaceEmotionalRecognition/blob/master/images/test_result.png)
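
After extracting, the loaders in `dataset_helper/` and the paths hard-coded in `app.py` suggest roughly the following layout (inferred from the code, not documented by the author; adjust if your archive differs):

```
datasets/
    train_enhance.tfrecords      # used by `run -train`
    val.tfrecords                # used by `run -test`
    train/  test/  val/          # one subfolder per label index 0-6, 48x48 jpgs
models/
    checkpoint
    model.ckpt.data-00000-of-00001
    model.ckpt.index
    model.ckpt.meta
```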
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import os
import sys
import cv2
import numpy as np
import requests
from test.main_test import test, app_test, recognize_single_image
from train.main_train import train

def setup():
    # os.system('pip install -r requirements.txt')
    os.system('mkdir models')
    print('create dir models')
    os.system('mkdir datasets')
    print('create dir datasets')
    # # download model
    # print('start download model')
    # print('start download checkpoint')
    # response = requests.get('http://tomax.xin/models/checkpoint')
    # with open('models/checkpoint', 'wb') as f:
    #     f.write(response.content)
    # print('success download checkpoint')
    # print('start download model.ckpt.data-00000-of-00001')
    # response = requests.get('http://tomax.xin/models/model.ckpt.data-00000-of-00001')
    # with open('models/model.ckpt.data-00000-of-00001', 'wb') as f:
    #     f.write(response.content)
    # print('success download model.ckpt.data-00000-of-00001')
    # print('start download model.ckpt.index')
    # response = requests.get('http://tomax.xin/models/model.ckpt.index')
    # with open('models/model.ckpt.index', 'wb') as f:
    #     f.write(response.content)
    # print('success download model.ckpt.index')
    # print('start download model.ckpt.meta')
    # response = requests.get('http://tomax.xin/models/model.ckpt.meta')
    # with open('models/model.ckpt.meta', 'wb') as f:
    #     f.write(response.content)
    # print('success download model.ckpt.meta')
    # print('success download model')
    # # download datasets
    # print('start download dataset')
    # print('start download train.tfrecords')
    # response = requests.get('http://tomax.xin/datasets/train.tfrecords')
    # with open('models/train.tfrecords', 'wb') as f:
    #     f.write(response.content)
    # print('success download train.tfrecords')
    # print('start download train_enhance.tfrecords')
    # response = requests.get('http://tomax.xin/models/train_enhance.tfrecords')
    # with open('models/train_enhance.tfrecords', 'wb') as f:
    #     f.write(response.content)
    # print('success download train_enhance.tfrecords')
    # print('start download test.tfrecords')
    # response = requests.get('http://tomax.xin/models/test_tfrecords')
    # with open('models/test_tfrecords', 'wb') as f:
    #     f.write(response.content)
    # print('success download test.tfrecords')
    # print('start download val.tfrecords')
    # response = requests.get('http://tomax.xin/models/val_tfrecords')
    # with open('models/val_tfrecords', 'wb') as f:
    #     f.write(response.content)
    # print('success download val.tfrecords')

def run_as_video():
    app_test()

def recognition_image(path):
    img = cv2.imread(path, 1)
    # cv2.imread returns None (it does not raise) when the path is invalid
    if img is None:
        print('invalid picture')
        print(path)
        return
    recognize_single_image(img)


if __name__ == '__main__':
    if len(sys.argv) == 1:
        print('setup -- initialize the environment, download the model and dataset')
        print('run -v -- run the app on live video')
        print('run -p path -- recognize a single picture at the given path')
        print('run -train -- train the model')
        print('run -test -- test the model')
    elif sys.argv[1] == 'setup':
        # setup()
        pass
    elif sys.argv[1] == 'run':
        if len(sys.argv) == 3 and sys.argv[2] == '-v':
            run_as_video()
        elif len(sys.argv) == 3 and sys.argv[2] == '-train':
            train(os.getcwd() + '/datasets/train_enhance.tfrecords')
        elif len(sys.argv) == 3 and sys.argv[2] == '-test':
            test(os.getcwd() + '/datasets/val.tfrecords')
        elif len(sys.argv) >= 3 and sys.argv[2] == '-p':
            if len(sys.argv) < 4:
                print('please input the path of the picture')
            else:
                recognition_image(sys.argv[3])
        else:
            print('invalid command')
    else:
        print('invalid command')
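
# Editor's sketch: the same single-image path can be driven from Python
# instead of the CLI (assumes the trained checkpoint is already in models/;
# 'images/origin.png' is just an illustrative input file):
#
#   import cv2
#   from test.main_test import recognize_single_image
#
#   img = cv2.imread('images/origin.png', 1)
#   if img is not None:
#       recognize_single_image(img)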
--------------------------------------------------------------------------------
/dataset_helper/IO_helper.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import os
import random
import copy
from labels import *

# Root directory of the image resources; adjust to the local path when deploying
date_resources_path = '../datasets/'
# Training-set image directory
train_path = 'train/'
# Test-set image directory
test_path = 'test/'
# Validation-set image directory
val_path = 'val/'
# OpenCV face detection
face_patterns = cv2.CascadeClassifier(
    '../res/haarcascade_frontalface_default.xml')

# Face detection based on OpenCV
# Returns the face image and its coordinates
# Default size is 42*42
# (not implemented yet)
def detectFace(img, size = (42, 42)):
    pass

class ImageObject:
    def __init__(self, data, label = -1):
        self.data = data
        self.label = label

    # Normalize the image to the [0,1] range
    def encode_image(self, size = image_shape):
        # self.data = cv2.resize(self.data, image_shape, interpolation=cv2.INTER_CUBIC)
        # dividing by 255 normalizes pixel values to [0,1]
        self.data = self.data / 255.
        self.data = self.data.astype(np.float32)
        # self.data = np.reshape(self.data, (1, image_shape[0]*image_shape[1]))

    @staticmethod
    def encode(data, size=image_shape):
        img = cv2.resize(data, image_shape_2, interpolation=cv2.INTER_CUBIC)
        # img = data
        # dividing by 255 normalizes pixel values to [0,1]
        img = img / 255.
        img = img.astype(np.float32)
        return img
        # self.data = np.reshape(self.data, (1, image_shape[0]*image_shape[1]))

    # Restore the image (not implemented yet)
    def decode_image(self, size = image_shape):
        pass

    # @staticmethod
    # def get_face(image):

    # Mirror the image horizontally (equivalent to cv2.flip(face, 1))
    @staticmethod
    def reverse_face(face):
        h, w = face.shape[0 : 2]
        res = copy.deepcopy(face)
        for i in range(h):
            for j in range(w):
                res[i, j] = face[i][w - j - 1]
        return res

    # Shift the image by cropping each corner
    @staticmethod
    def crop_face(face, win = (36, 36)):
        # note: rows are the height axis, columns the width axis
        h, w = face.shape[0 : 2]
        face_lt = cv2.resize(face[0 : win[1], 0 : win[0]], (42, 42))
        face_rt = cv2.resize(face[0 : win[1], w - win[0] : w], (42, 42))
        face_lb = cv2.resize(face[h - win[1] : h, 0 : win[0]], (42, 42))
        face_rb = cv2.resize(face[h - win[1] : h, w - win[0] : w], (42, 42))
        # face_lt = face[0: win[1], 0: win[0]]
        # face_rt = face[0: win[1], w - win[0]: w]
        # face_lb = face[h - win[1]: h, 0: win[0]]
        # face_rb = face[h - win[1]: h, w - win[0]: w]
        return [face, face_lt, face_rt, face_lb, face_rb]

    # Add Gaussian noise to the image (Box-Muller transform, two pixels per draw)
    @staticmethod
    def noiseing(img):
        # img = cv2.cvtColor(rgbimg, cv2.COLOR_BGR2GRAY)
        param = 30
        grayscale = 256
        w = img.shape[1]
        h = img.shape[0]
        newimg = np.zeros((h, w, 3), np.uint8)
        # row and col
        for x in range(0, h):
            for y in range(0, w, 2):  # step 2: each iteration fills columns y and y+1
                r1 = np.random.random_sample()
                r2 = np.random.random_sample()
                z1 = param * np.cos(2 * np.pi * r2) * np.sqrt((-2) * np.log(r1))
                z2 = param * np.sin(2 * np.pi * r2) * np.sqrt((-2) * np.log(r1))
                # clamp every channel of f(x,y) and f(x,y+1) to [0, grayscale - 1]
                for c in range(3):
                    newimg[x, y, c] = np.clip(int(img[x, y, c] + z1), 0, grayscale - 1)
                    newimg[x, y + 1, c] = np.clip(int(img[x, y + 1, c] + z2), 0, grayscale - 1)
        return newimg

    # Enhance the images: expand each input into multiple variants
    @staticmethod
    def enhance_image(images):
        res = []
        for image in images:
            target = ImageObject.crop_face(image)
            for img in target:
                res.append(img)
                res.append(ImageObject.reverse_face(img))
        return res
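
# Editor's sketch of how the augmentation multiplies the data: enhance_image
# turns one face into 10 variants (the input plus four corner crops resized
# to 42x42, each also mirrored). noiseing() exists but is not wired into the
# pipeline here. The path below is illustrative only:
#
#   face = cv2.imread('../datasets/train/0/00000.jpg', 0)
#   variants = ImageObject.enhance_image([face])
#   assert len(variants) == 10   # 5 crops * 2 orientations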
class ImageDataResource:
    def __init__(self):
        # total number of images
        self.image_sum = 0
        # number of images per class
        self.kind_count = []
        # shape of a single image
        self.image_shape = []
        # image list: a 2-D list indexed by class label 0-6
        self.data = []
        # shuffled data flattened into a single list of ImageObject instances
        self.un_seq_data = []

    # Record the shape of the images
    def init_shape(self):
        self.image_shape = np.shape(self.data[0][0])

    # Shuffle the data
    def shuffle_data(self):
        for i in range(7):
            for j in range(self.kind_count[i]):
                image = ImageObject(self.data[i][j], i)
                self.un_seq_data.append(image)
        random.shuffle(self.un_seq_data)


# Read all images under the given directory, grouped by class
# Returns a single ImageDataResource
def read_images(path):
    resource = ImageDataResource()
    # traverse one folder per expression class
    for kind in range(7):
        root_path = date_resources_path + path + str(kind)
        kind_num = 0
        images = []
        for image in os.listdir(root_path):
            resource.image_sum = resource.image_sum + 1
            kind_num = kind_num + 1
            images.append(cv2.imread(root_path + '/' + image, 0))
        resource.kind_count.append(kind_num)
        resource.data.append(images)
    resource.init_shape()
    return resource

# Same as read_images, with augmentation (each image plus its horizontal mirror)
def read_images_enhance(path):
    resource = ImageDataResource()
    # traverse one folder per expression class
    for kind in range(7):
        root_path = date_resources_path + path + str(kind)
        kind_num = 0
        images = []
        for image in os.listdir(root_path):
            resource.image_sum = resource.image_sum + 2
            kind_num = kind_num + 2
            face = cv2.imread(root_path + '/' + image, 0)
            face_reverse = ImageObject.reverse_face(face)
            # faces = ImageObject.enhance_image([face])
            images.append(face)
            images.append(face_reverse)
        resource.kind_count.append(kind_num)
        resource.data.append(images)
    resource.init_shape()
    return resource

# Same as read_images_enhance, additionally merging the test set into the result
def read_images_enhance_with_test(path, path_test):
    resource = ImageDataResource()
    # traverse one folder per expression class
    for kind in range(7):
        root_path = date_resources_path + path + str(kind)
        kind_num = 0
        images = []
        for image in os.listdir(root_path):
            resource.image_sum = resource.image_sum + 2
            kind_num = kind_num + 2
            face = cv2.imread(root_path + '/' + image, 0)
            face_reverse = ImageObject.reverse_face(face)
            # faces = ImageObject.enhance_image([face])
            images.append(face)
            images.append(face_reverse)
        resource.kind_count.append(kind_num)
        resource.data.append(images)

    # traverse the test-set folders the same way
    for kind in range(7):
        root_path = date_resources_path + path_test + str(kind)
        kind_num = 0
        for image in os.listdir(root_path):
            resource.image_sum = resource.image_sum + 2
            kind_num = kind_num + 2
            face = cv2.imread(root_path + '/' + image, 0)
            face_reverse = ImageObject.reverse_face(face)
            # faces = ImageObject.enhance_image([face])
            resource.data[kind].append(face)
            resource.data[kind].append(face_reverse)
        resource.kind_count[kind] = resource.kind_count[kind] + kind_num
    resource.init_shape()
    return resource
# Read training images; return value as in read_images()
def read_train_images():
    return read_images(train_path)

def read_enhance_train_images():
    return read_images_enhance(train_path)

def read_enhace_train_with_test_images():
    return read_images_enhance_with_test(train_path, test_path)

# Read test images; return value as in read_images()
def read_test_images():
    return read_images(test_path)


# Read validation images; return value as in read_images()
def read_val_images():
    return read_images(val_path)

if __name__ == '__main__':
    # pass
    # resource = read_train_images()
    # resource.shuffle_data()
    # print(resource.image_sum)
    # print(len(resource.un_seq_data))
    # print(np.shape(resource.data[0]))
    img = cv2.imread('E:/res/img/31.jpg', 1)
    newimg = ImageObject.reverse_face(img)
    cv2.imwrite('E:/res/img/32.jpg', newimg)
    # for i in range(len(resource.un_seq_data)):
    #     print(resource.un_seq_data[i].label)
    #
    # cv2.imshow('1', resource.un_seq_data[0].data)
    # cv2.imshow('2', resource.un_seq_data[len(resource.un_seq_data) - 1].data)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()

    # image = cv2.imread('E:/tmp/emote-recognition/datasets/train/0/00000.jpg', 0)
    # data = ImageObject(image, 0)
    # data.encode_image()
    # print(data.data)
    # print(np.shape(data.data))
    # cv2.imshow('data', np.reshape(data.data, (48, 48)))
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    # images = ImageObject.enhance_image([image])
    # for i in range(10):
    #     cv2.imshow(str(i), images[i])
    #     cv2.resizeWindow(str(i), 640, 480)
    #
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
--------------------------------------------------------------------------------
/dataset_helper/TFRecords_helper.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import tensorflow as tf
import sys

from labels import *
from dataset_helper.IO_helper import *

face_patterns = cv2.CascadeClassifier(
    os.getcwd() + '/res/haarcascade_frontalface_default.xml')

# file paths of the recorders
root_recorder_path = '../datasets/'
train_recorder_path = root_recorder_path + 'train.tfrecords'
val_recorder_path = root_recorder_path + 'val_tfrecords'
test_recorder_path = root_recorder_path + 'test_tfrecords'
train_recorder_enhance_path = root_recorder_path + 'train_enhance.tfrecords'
train_recorder_enhance_with_test_path = root_recorder_path + 'train.enhance_with_test.tfrecords'


# Convert raw values into the corresponding tf.train.Feature types
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


# Create the training set
def create_train_tf_records():
    resource = read_train_images()
    resource.shuffle_data()
    create_tf_records(resource, train_recorder_path)
    print("train num : ", len(resource.un_seq_data))

def create_train_tf_records_enhance():
    resource = read_enhance_train_images()
    resource.shuffle_data()
    create_tf_records(resource, train_recorder_enhance_path)
    print("train num : ", len(resource.un_seq_data))

def create_train_tf_records_enhance_with_test():
    resource = read_enhace_train_with_test_images()
    resource.shuffle_data()
    create_tf_records(resource, train_recorder_enhance_with_test_path)
    print("train num : ", len(resource.un_seq_data))

# Create the test set
def create_test_tf_records():
    resource = read_test_images()
    resource.shuffle_data()
    print("test num : ", len(resource.un_seq_data))
    create_tf_records(resource, test_recorder_path)


# Create the validation set
def create_val_tf_records():
    resource = read_val_images()
    resource.shuffle_data()
    print("val num : ", len(resource.un_seq_data))
    create_tf_records(resource, val_recorder_path)


# Build tfrecords from an ImageDataResource, writing to recorder_path
def create_tf_records(resource, recorder_path):
    writer = tf.python_io.TFRecordWriter(recorder_path)
    for i in range(len(resource.un_seq_data)):
        if not i % 1000:
            print('data: {}/{}'.format(i, len(resource.un_seq_data)))
            sys.stdout.flush()
        img = resource.un_seq_data[i]
        img.encode_image()
        # build the feature dict
        feature = {'label': _int64_feature(img.label),
                   'image': _bytes_feature(tf.compat.as_bytes(img.data.tostring()))}

        # build an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # write the example protocol buffer to the file
        writer.write(example.SerializeToString())
    writer.close()
    sys.stdout.flush()
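
# Editor's note on the record schema: each serialized Example stores an int64
# 'label' and a bytes 'image' holding the raw float32 pixels of the normalized
# face. The reader below must therefore decode with tf.float32 and reshape to
# the shape the image had when written; since encode_image does not resize,
# the fer2013 faces stay 48x48, which is why the callers in train/ and test/
# pass shape=image_shape_2_ ([48, 48]) instead of the [42, 42] default.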

# Read data back from the tfrecords
def decode_from_tf_records(filename_queue, is_batch, batch_size = 50, shape = image_shape_):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # returns file name and contents
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'image': tf.FixedLenFeature([], tf.string),
                                       })  # extract the feature object holding image and label
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, shape)
    label = tf.cast(features['label'], tf.int32)

    if is_batch:
        min_after_dequeue = 50
        capacity = min_after_dequeue + 3 * batch_size
        image, label = tf.train.shuffle_batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=3,
                                              capacity=capacity,
                                              min_after_dequeue=min_after_dequeue)
    return image, label


# Simple dataset helper, tightly coupled to this project
def create_one_hot(num, total):
    label = []
    for i in range(total):
        if i == num:
            label.append(1)
        else:
            label.append(0)
    return label

# Simple dataset helper, tightly coupled to this project
# default shape [1764] = 42*42; the callers pass [2304] = 48*48
def modify_size(input_image, input_label, shape = [1764]):
    images = []
    labels = []
    for img in input_image:
        images.append(np.reshape(img, shape))
    for label in input_label:
        labels.append(create_one_hot(label, 7))
    return images, labels


# Update the per-round test results
# pre : predictions
# label : ground-truth labels
# acc_record : number of correct predictions per expression
# record : total number of samples per expression
def check_accuracy(pre, label, acc_record, record):
    for i in range(len(pre)):
        if pre[i] == label[i]:
            acc_record[label[i]] = acc_record[label[i]] + 1
        record[label[i]] = record[label[i]] + 1

    return acc_record, record

# Face detection
def face_detect(image):
    faces = face_patterns.detectMultiScale(image, scaleFactor=1.1, minNeighbors=5, minSize=(100, 100))
    return faces

# Draw face boxes and emotion labels on the image
def mark_human_emote(img, coors, emotes):
    x_limit, y_limit = img.shape[0 : 2]
    for i in range(len(coors)):
        x, y, w, h = coors[i]
        img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 5)
        text_beg_x = x
        text_beg_y = y
        img = cv2.putText(img, emotes[i], (text_beg_x, text_beg_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
    return img
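
# Editor's sketch of the helpers above:
#
#   create_one_hot(3, 7)  ->  [0, 0, 0, 1, 0, 0, 0]   # index 3 is 'happy'
#
#   modify_size(images, labels, shape=[2304]) flattens each 48x48 batch image
#   into a 2304-vector and one-hot encodes the integer labels, matching the
#   x/y_ placeholders built in train/main_train.py and test/main_test.py.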

if __name__ == '__main__':
    # create_train_tf_records()
    # create_test_tf_records()
    # create_val_tf_records()
    # make train.tfrecord
    create_train_tf_records_enhance_with_test()
    # image = cv2.imread('F:/ToMax/study/AI/datasets/extended-cohn-kanade-images/image/S005/001/S005_001_00000010.png', 1)
    # faces = face_detect(image)
    # image = mark_human_emote(image, faces, ['happy'])
    # cv2.imshow('face', image)
    # cv2.resizeWindow('face', 800, 600)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    # filename_queue = tf.train.string_input_producer([test_recorder_path], num_epochs=None)  # read from the input stream
    # train_image, train_label = decode_from_tf_records(filename_queue, is_batch=True)
    # with tf.Session() as sess:  # start a session
    #     init_op = tf.global_variables_initializer()
    #     sess.run(init_op)
    #     coord = tf.train.Coordinator()
    #     threads = tf.train.start_queue_runners(coord=coord)
    #
    #     try:
    #         # while not coord.should_stop():
    #         for i in range(200):
    #             example, l = sess.run([train_image, train_label])  # fetch image and label in the session
    #             print('train:')
    #             print(np.shape(example))
    #             print(example)
    #             print(l)
    #             cv2.imshow('1', example[0])
    #             cv2.waitKey(0)
    #     except tf.errors.OutOfRangeError:
    #         print('Done reading')
    #     finally:
    #         coord.request_stop()
    #
    #     coord.request_stop()
    #     coord.join(threads)
--------------------------------------------------------------------------------
/dataset_helper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/dataset_helper/__init__.py
--------------------------------------------------------------------------------
/dataset_helper/convert_csv2gray.py:
--------------------------------------------------------------------------------
import csv
import os
from PIL import Image
import numpy as np


datasets_path = r'.\datasets'
train_csv = os.path.join(datasets_path, 'train.csv')
val_csv = os.path.join(datasets_path, 'val.csv')
test_csv = os.path.join(datasets_path, 'test.csv')

train_set = os.path.join(datasets_path, 'train')
val_set = os.path.join(datasets_path, 'val')
test_set = os.path.join(datasets_path, 'test')

for save_path, csv_file in [(train_set, train_csv), (val_set, val_csv), (test_set, test_csv)]:
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    num = 1
    with open(csv_file) as f:
        csvr = csv.reader(f)
        header = next(csvr)
        for i, (label, pixel) in enumerate(csvr):
            # each row holds a label and 48*48 space-separated pixel values
            pixel = np.asarray([float(p) for p in pixel.split()]).reshape(48, 48)
            subfolder = os.path.join(save_path, label)
            if not os.path.exists(subfolder):
                os.makedirs(subfolder)
            im = Image.fromarray(pixel).convert('L')
            image_name = os.path.join(subfolder, '{:05d}.jpg'.format(i))
            print(image_name)
            im.save(image_name)
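
# Editor's note: the split CSVs produced by convert_fer2013.py contain rows
# shaped roughly like
#
#   emotion,pixels
#   0,70 80 82 72 58 ... (2304 values, one per pixel of the 48x48 face)
#
# so the loop above rebuilds each face with reshape(48, 48) and writes it to
# datasets/<split>/<label>/<index>.jpg, the folder layout IO_helper expects.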
--------------------------------------------------------------------------------
/dataset_helper/convert_fer2013.py:
--------------------------------------------------------------------------------
import csv
import os

database_path = r'E:\tmp\fer2013\fer2013'
datasets_path = r'.\datasets'
csv_file = os.path.join(database_path, 'fer2013.csv')
train_csv = os.path.join(datasets_path, 'train.csv')
val_csv = os.path.join(datasets_path, 'val.csv')
test_csv = os.path.join(datasets_path, 'test.csv')


with open(csv_file) as f:
    csvr = csv.reader(f)
    header = next(csvr)
    rows = [row for row in csvr]

trn = [row[:-1] for row in rows if row[-1] == 'Training']
with open(train_csv, 'w+') as f:
    csv.writer(f, lineterminator='\n').writerows([header[:-1]] + trn)
print(len(trn))

val = [row[:-1] for row in rows if row[-1] == 'PublicTest']
with open(val_csv, 'w+') as f:
    csv.writer(f, lineterminator='\n').writerows([header[:-1]] + val)
print(len(val))

tst = [row[:-1] for row in rows if row[-1] == 'PrivateTest']
with open(test_csv, 'w+') as f:
    csv.writer(f, lineterminator='\n').writerows([header[:-1]] + tst)
print(len(tst))
--------------------------------------------------------------------------------
/dataset_helper/count_result.py:
--------------------------------------------------------------------------------
from labels import *

class TestResult:
    def __init__(self):
        # number of test samples per expression
        self.sample_num_count = TestResult.init_list()
        # number of correct predictions per expression
        self.real_num_count = TestResult.init_list()
        # total number of recognitions
        self.sum = 0
        # total number of correct recognitions
        self.correct = 0
        # confusion counts: recognition_list[true][predicted]
        self.recognition_list = [TestResult.init_list() for x in range(emote_kind_num)]

    def update(self, label_test, label_res):
        for i in range(len(label_test)):
            self.sample_num_count[label_test[i]] = self.sample_num_count[label_test[i]] + 1
            self.recognition_list[label_test[i]][label_res[i]] = self.recognition_list[label_test[i]][label_res[i]] + 1
            self.sum = self.sum + 1
            if label_test[i] == label_res[i]:
                self.real_num_count[label_test[i]] = self.real_num_count[label_test[i]] + 1
                self.correct = self.correct + 1

    def display(self):
        for i in range(emote_kind_num):
            print(emote_labels[i], 'total :', self.sample_num_count[i], 'accuracy :', TestResult.modify_to_percent(self.real_num_count[i] / self.sample_num_count[i]))
        table_names = '\t'
        for i in range(emote_kind_num):
            table_names += emote_labels[i]
            table_names += ' '
        print(table_names)
        for i in range(emote_kind_num):
            col_val = ''
            for j in range(emote_kind_num):
                col_val += TestResult.modify_to_percent(self.recognition_list[i][j] / self.sample_num_count[i])
                col_val += ' '
            print(emote_labels[i], col_val)

    @staticmethod
    def init_list():
        counts = []
        for i in range(emote_kind_num):
            counts.append(0)
        return counts

    @staticmethod
    def modify_to_percent(val):
        return format(val, '.2%')
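
# Editor's sketch: TestResult accumulates a per-class confusion table from
# batches of (true, predicted) integer labels, e.g.
#
#   result = TestResult()
#   result.update([3, 3, 5], [3, 0, 5])   # two hits; 'happy' once mistaken for 'angry'
#   result.display()                      # per-class accuracy plus the full table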
--------------------------------------------------------------------------------
/images/origin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/images/origin.png
--------------------------------------------------------------------------------
/images/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/images/result.png
--------------------------------------------------------------------------------
/images/test_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/images/test_result.png
--------------------------------------------------------------------------------
/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/inference/__init__.py
--------------------------------------------------------------------------------
/inference/main_inference.py:
--------------------------------------------------------------------------------
import tensorflow as tf

INPUT_NODE = 2304
OUTPUT_NODE = 7

IMAGE_SIZE = 48
NUM_CHANNELS = 1
NUM_LABELS = 7

KEEP_PROP = 0.3

KEEP_PROP_NEXT = 0.3

# Forward pass of the convolutional network.
# input_tensor : input training data
# train : whether this is the training pass; dropout is applied only when training
def inference(input_tensor, train):
    # First convolutional layer: 5*5 kernel, stride 1, SAME padding
    # Activation: relu
    # Input: 48 * 48 * 1
    # Output: 48 * 48 * 32
    with tf.variable_scope('layer1-conv1'):
        layer1_weight = tf.get_variable(
            "weight", [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.1)
        )
        layer1_bias = tf.get_variable(
            "bias", [32], initializer=tf.constant_initializer(0.1))
        layer1_conv = tf.nn.conv2d(
            input_tensor, layer1_weight, strides=[1, 1, 1, 1], padding='SAME')
        layer1_relu = tf.nn.relu(tf.nn.bias_add(layer1_conv, layer1_bias))

    print('conv1 output size', layer1_relu.get_shape().as_list())

    # First pooling layer: 3*3 window, stride 2, SAME padding
    # Output: 24 * 24 * 32
    with tf.variable_scope('layer2-pool1'):
        layer2_pool = tf.nn.max_pool(layer1_relu, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

    print('pool1 output size', layer2_pool.get_shape().as_list())

    # Second convolutional layer
    # Activation: relu
    # Output: 24 * 24 * 32
    with tf.variable_scope('layer3-conv2'):
        layer3_weight = tf.get_variable(
            "weight", [4, 4, 32, 32], initializer=tf.truncated_normal_initializer(stddev=0.1)
        )
        layer3_bias = tf.get_variable(
            "bias", [32], initializer=tf.constant_initializer(0.1))
        layer3_conv = tf.nn.conv2d(
            layer2_pool, layer3_weight, strides=[1, 1, 1, 1], padding='SAME')
        layer3_relu = tf.nn.relu(tf.nn.bias_add(layer3_conv, layer3_bias))

    print('conv2 output size', layer3_relu.get_shape().as_list())

    # Second pooling layer
    # Output: 12 * 12 * 32
    with tf.variable_scope('layer4-pool2'):
        layer4_pool = tf.nn.max_pool(layer3_relu, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

    print('pool2 output size', layer4_pool.get_shape().as_list())

    # Third convolutional layer
    # Activation: relu
    # Output: 12 * 12 * 64
    with tf.variable_scope('layer5-conv3'):
        layer5_weight = tf.get_variable(
            "weight", [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.1)
        )
        layer5_bias = tf.get_variable(
            "bias", [64], initializer=tf.constant_initializer(0.1))
        layer5_conv = tf.nn.conv2d(
            layer4_pool, layer5_weight, strides=[1, 1, 1, 1], padding='SAME')
        layer5_relu = tf.nn.relu(tf.nn.bias_add(layer5_conv, layer5_bias))

    print('conv3 output size', layer5_relu.get_shape().as_list())

    # Third pooling layer
    # Output: 6 * 6 * 64
    with tf.variable_scope('layer6-pool3'):
        layer6_pool = tf.nn.max_pool(layer5_relu, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

    print('pool3 output size', layer6_pool.get_shape().as_list())

    # Flatten before entering the fully connected layers
    # This yields 6 * 6 * 64 = 2304 features per sample
    # pool_shape = layer6_pool.get_shape().as_list()
    # nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
    # reshaped = tf.reshape(layer6_pool, [pool_shape[0], nodes])
    reshaped = tf.reshape(layer6_pool, [-1, 2304])
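    # Editor's note on the shape arithmetic above: each stride-2 SAME pooling
    # halves the spatial size, so 48 -> 24 -> 12 -> 6, and flattening gives
    # 6 * 6 * 64 = 2304 features, the width of the fc1 weights below (and,
    # coincidentally, the same number as the 48 * 48 input pixels).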

    # First fully connected layer
    # Output: 2048
    with tf.variable_scope('layer7-fc1'):
        layer7_weight = tf.get_variable(
            "weight", [2304, 2048], initializer=tf.truncated_normal_initializer(stddev=0.1))
        layer7_bias = tf.get_variable("bias", [2048], initializer=tf.constant_initializer(0.1))
        layer7_fc = tf.nn.relu(tf.matmul(reshaped, layer7_weight) + layer7_bias)
        if train:
            layer7_fc = tf.nn.dropout(layer7_fc, KEEP_PROP, name='layer7')

    # Second fully connected layer
    # Output: 1024
    with tf.variable_scope('layer8-fc2'):
        layer8_weight = tf.get_variable(
            "weight", [2048, 1024], initializer=tf.truncated_normal_initializer(stddev=0.1))
        layer8_bias = tf.get_variable("bias", [1024], initializer=tf.constant_initializer(0.1))
        layer8_fc = tf.nn.relu(tf.matmul(layer7_fc, layer8_weight) + layer8_bias)
        if train:
            layer8_fc = tf.nn.dropout(layer8_fc, KEEP_PROP_NEXT, name='dropout')

    # Third fully connected layer
    # Output: 7 (one logit per expression class)
    with tf.variable_scope('layer9-fc3'):
        layer9_weight = tf.get_variable(
            "weight", [1024, 7], initializer=tf.truncated_normal_initializer(stddev=0.1))
        layer9_bias = tf.get_variable("bias", [7], initializer=tf.constant_initializer(0.1))
        logit = tf.matmul(layer8_fc, layer9_weight, name='logits') + layer9_bias
    return logit
--------------------------------------------------------------------------------
/labels.py:
--------------------------------------------------------------------------------
# Expression class labels
emote_labels = [
    'angry',
    'disgust',
    'fear',
    'happy',
    'sad',
    'surprise',
    'neutral'
]

emote_kind_num = 7

image_shape = (42, 42)
image_shape_ = [42, 42]

image_shape_2 = (48, 48)
image_shape_2_ = [48, 48]
--------------------------------------------------------------------------------
/requirements:
--------------------------------------------------------------------------------
opencv-python
numpy
tensorflow-gpu
--------------------------------------------------------------------------------
/result.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/result.xlsx
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/test/__init__.py
--------------------------------------------------------------------------------
/test/main_test.py:
--------------------------------------------------------------------------------
from inference.main_inference import inference
from dataset_helper.TFRecords_helper import *
from dataset_helper.count_result import *
import tensorflow as tf
import time

BATCH_SIZE = 50

MODEL_SAVE_PATH = os.getcwd() + "/models/model.ckpt"

def test(path = val_recorder_path):
    filename_queue = tf.train.string_input_producer([path], num_epochs=None)  # read from the input stream

    test_image, test_label = decode_from_tf_records(filename_queue, is_batch=True, shape=image_shape_2_)

    x = tf.placeholder(tf.float32, [None, 2304], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, 7], name="y-input")
    x_image = tf.reshape(x, [-1, 48, 48, 1])
    y = inference(x_image, False)
    prediction = tf.argmax(y, 1)
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    test_result = TestResult()
    with tf.control_dependencies([accuracy]):
        test_op = tf.no_op(name="test")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver.restore(sess, MODEL_SAVE_PATH)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        acc_sum = 0
        count_sum = 0
        try:
            for i in range(1000):
                xs, ys = sess.run([test_image, test_label])
                xs_, ys_ = modify_size(xs, ys, shape=[2304])
                _, acc, pred = sess.run([test_op, accuracy, prediction], feed_dict={x: xs_, y_: ys_})
                # print(ys)
                # print(pred)
                test_result.update(ys, pred)
                acc_sum = acc_sum + BATCH_SIZE * acc
                count_sum = count_sum + BATCH_SIZE
                print("step %d, testing accuracy %g" % (i, acc))
            print("accuracy : ", acc_sum / count_sum)
            test_result.display()
            # test_image, test_label = decode_from_tf_records(filename_queue, is_batch=True)
            # for i in range(100):
            #     xs, ys = sess.run([test_image, test_label])
            #     xs_, ys_ = modify_size(xs, ys)
            #     acc = sess.run(accuracy, feed_dict={x: xs_, y_: ys_})
            #     print("test step by %d, testing accuracy %g" % (i, acc))

        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()

        coord.request_stop()
        coord.join(threads)
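
# Editor's note: the loop above evaluates 1000 batches of BATCH_SIZE = 50, so
# the reported figure averages 50,000 draws. Because shuffle_batch samples
# from a continuously refilled queue, examples can repeat across batches; the
# number is an estimate of validation accuracy rather than exactly one pass
# over the validation set.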

def app_test():
    x_ = tf.placeholder(tf.float32, [None, 2304], name="x-input")
    x_image = tf.reshape(x_, [-1, 48, 48, 1])
    y_ = inference(x_image, False)
    prediction = tf.argmax(y_, 1)
    with tf.control_dependencies([prediction]):
        test_op = tf.no_op(name="test")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver.restore(sess, MODEL_SAVE_PATH)
        capture = cv2.VideoCapture(0)
        while True:
            ret, frame = capture.read()
            gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            coors = face_detect(gray_frame)
            if len(coors) == 0:
                continue
            faces = []
            for coor in coors:
                x, y, w, h = coor
                faces.append(ImageObject.encode(gray_frame[y : y + h, x : x + w], image_shape_2))
            # face = ImageObject(gray_frame)
            # face.encode_image(size=image_shape_2)
            face_target, _ = modify_size(faces, [], [2304])
            # infer, pred = sess.run([y, prediction], feed_dict={face_target})
            _, infer, pred = sess.run([test_op, y_, prediction], feed_dict={x_: face_target})
            print(infer, pred)
            emotes = []
            for res in pred:
                emotes.append(emote_labels[res])
            frame = mark_human_emote(frame, coors, emotes)
            cv2.imshow("capture", frame)
            if cv2.waitKey(31) & 0xFF == ord('q'):
                break
            time.sleep(0.1)
        cv2.destroyAllWindows()

def recognize_single_image(image):
    x_ = tf.placeholder(tf.float32, [None, 2304], name="x-input")
    x_image = tf.reshape(x_, [-1, 48, 48, 1])
    inference_y = inference(x_image, False)
    prediction = tf.argmax(inference_y, 1)
    with tf.control_dependencies([prediction]):
        test_op = tf.no_op(name="test")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver.restore(sess, MODEL_SAVE_PATH)
        coors = face_detect(image)
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        faces = []
        for coor in coors:
            x, y, w, h = coor
            faces.append(ImageObject.encode(gray_image[y: y + h, x: x + w], image_shape_2))
        data, _ = modify_size(faces, [], [2304])
        print(np.shape(data))
        _, infer, pred = sess.run([test_op, inference_y, prediction], feed_dict={x_: data})
        print(infer, pred)
        emotes = []
        for res in pred:
            emotes.append(emote_labels[res])
        frame = mark_human_emote(image, coors, emotes)
        cv2.imshow("display", frame)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

def main(args = None):
    # app_test()
    test()

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingToMax/FaceEmotionalRecognition/5101c3a9f489d6462f019eea9eaa9612c88cbddb/train/__init__.py
--------------------------------------------------------------------------------
/train/main_train.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf
import numpy as np
from inference.main_inference import inference
from dataset_helper.TFRecords_helper import *

BATCH_SIZE = 50
LEARNING_RATE = 0.0001
TRAINING_STEPS = 100000

# relative path; resolves differently when launched via app.py from the project root
MODEL_SAVE_PATH = "../models/model.ckpt"


def train(path = train_recorder_enhance_path):
    filename_queue = tf.train.string_input_producer([path], num_epochs=None)  # read from the input stream

    train_image, train_label = decode_from_tf_records(filename_queue, is_batch=True, shape=image_shape_2_)

    x = tf.placeholder(tf.float32, [None, 2304], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, 7], name="y-input")
    x_image = tf.reshape(x, [-1, 48, 48, 1])
    y = inference(x_image, True)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y, labels=y_))
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(cost)
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.control_dependencies([optimizer, accuracy]):
        train_op = tf.no_op(name="train")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            for i in range(TRAINING_STEPS):
                xs, ys = sess.run([train_image, train_label])
                xs_, ys_ = modify_size(xs, ys, shape=[2304])
                if i % 50 == 0:
                    acc = sess.run(accuracy, feed_dict={x: xs_, y_: ys_})
                    print("step %d, training accuracy %g" % (i, acc))

                _, cost_val, optimizer_val = sess.run([train_op, cost, optimizer], feed_dict={x: xs_, y_: ys_})
                # print(cost_val)

            # test_image, test_label = decode_from_tf_records(filename_queue, is_batch=True)
            # for i in range(100):
            #     xs, ys = sess.run([test_image, test_label])
            #     xs_, ys_ = modify_size(xs, ys)
            #     acc = sess.run(accuracy, feed_dict={x: xs_, y_: ys_})
            #     print("test step by %d, testing accuracy %g" % (i, acc))
            saver.save(sess, MODEL_SAVE_PATH)

        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()

        coord.request_stop()
        coord.join(threads)
    # writer = tf.summary.FileWriter("log/train.log", tf.get_default_graph())
    # writer.close()

def main(args = None):
    train()

if __name__ == '__main__':
    tf.app.run()
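
# Editor's note: the commented FileWriter lines above already hint at graph
# logging. A minimal sketch for also tracking the loss in TensorBoard
# (assumed log directory 'log/'; not part of the original training loop):
#
#   cost_summary = tf.summary.scalar('cost', cost)
#   writer = tf.summary.FileWriter('log/', tf.get_default_graph())
#   ...
#   summary = sess.run(cost_summary, feed_dict={x: xs_, y_: ys_})
#   writer.add_summary(summary, i)
--------------------------------------------------------------------------------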