├── .gitignore
├── README.md
├── dataset
│   └── README.md
├── requirements.txt
├── src
│   ├── check.py
│   ├── model.py
│   ├── train.py
│   └── visualize
│       ├── grad_cam.py
│       └── haarcascade_frontalface_default.xml
├── test
│   ├── angry.jpg
│   ├── guided_gradcam.jpg
│   ├── happy.jpg
│   ├── sad.jpg
│   └── surprised.jpg
└── trained
    ├── private_model_233_66.t7
    └── public_model_236_64.t7

/.gitignore:
--------------------------------------------------------------------------------
dataset/*
__pycache__
*.t7
*.onnx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-facial-expression-recognition

Lightweight facial emotion recognition in PyTorch.

* PyTorch implementation of [Mini Xception](https://arxiv.org/pdf/1710.07557.pdf), inspired by the [Keras implementation](https://github.com/oarriaga/face_classification).
* The model is about `250KB`.

## Trained Model

Trained on the FER2013 dataset.

* FER2013 PrivateTest accuracy: 66%
* FER2013 PublicTest accuracy: 64%

Here is the result for the sample images in `test/`, written to `test/guided_gradcam.jpg` by `check.py`. The columns are:

Emotion | Probability | Guided Backprop | Grad-CAM | Guided Grad-CAM

## Retrain

1. See [here](./dataset/README.md) to prepare the dataset.

2. Execute train.py (and check.py to visualize the results):
```
cd src
python train.py
python check.py  # check.py supports CPU only
```
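
## Inference

A minimal single-image sketch distilled from `src/check.py` and `src/model.py` (run it from `src/`, like `check.py`). It skips the Haar-cascade face crop and the Grad-CAM visualization that `check.py` performs, so treat it as a starting point rather than the reference pipeline.

```
import cv2
import torch
import model
from PIL import Image
from torchvision import transforms

classes = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprised', 'Neutral']

# load the trained checkpoint the same way check.py does
net = model.Model(num_classes=len(classes))
checkpoint = torch.load('../trained/private_model_233_66.t7', map_location='cpu')
net.load_state_dict(checkpoint['net'])
net.eval()

face = cv2.resize(cv2.imread('../test/happy.jpg'), (48, 48))        # BGR image -> 48x48
inputs = transforms.ToTensor()(Image.fromarray(face).convert('L'))  # 1x48x48 grayscale tensor
with torch.no_grad():
    probs = torch.softmax(net(inputs.unsqueeze(0)), dim=1)[0]
print(classes[int(probs.argmax())], '%.1f%%' % (float(probs.max()) * 100))
```
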
## Reference

* [Grad-CAM](https://github.com/kazuto1011/grad-cam-pytorch)
* [Data Augmentation / Optimizer](https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch)
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
* Place `fer2013.csv` in this directory.
* You can download the data from https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge/data
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.2.0
torchsummary
torchvision
numpy
Pillow
--------------------------------------------------------------------------------
/src/check.py:
--------------------------------------------------------------------------------
import os

import cv2
import matplotlib.cm as cm
import numpy as np
import torch
import model
from PIL import Image
from torchvision import transforms
from torchsummary import summary
from visualize.grad_cam import BackPropagation, GradCAM, GuidedBackPropagation

faceCascade = cv2.CascadeClassifier('./visualize/haarcascade_frontalface_default.xml')
shape = (48, 48)
classes = [
    'Angry',
    'Disgust',
    'Fear',
    'Happy',
    'Sad',
    'Surprised',
    'Neutral'
]


def preprocess(image_path):
    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])
    image = cv2.imread(image_path)
    faces = faceCascade.detectMultiScale(
        image,
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(1, 1),
        flags=cv2.CASCADE_SCALE_IMAGE
    )

    if len(faces) == 0:
        print('no face found')
        face = cv2.resize(image, shape)
    else:
        # use the first detected face only
        (x, y, w, h) = faces[0]
        face = image[y:y + h, x:x + w]
        face = cv2.resize(face, shape)

    img = Image.fromarray(face).convert('L')
    inputs = transform_test(img)
    return inputs, face


def get_gradient_image(gradient):
    gradient = gradient.cpu().numpy().transpose(1, 2, 0)
    gradient -= gradient.min()
    gradient /= gradient.max()
    gradient *= 255.0
    return np.uint8(gradient)


def get_gradcam_image(gcam, raw_image, paper_cmap=False):
    gcam = gcam.cpu().numpy()
    cmap = cm.jet_r(gcam)[..., :3] * 255.0
    if paper_cmap:
        alpha = gcam[..., None]
        gcam = alpha * cmap + (1 - alpha) * raw_image
    else:
        gcam = (cmap.astype(np.float32) + raw_image.astype(np.float32)) / 2
    return np.uint8(gcam)


def guided_backprop(images, model_name):

    for i, image in enumerate(images):
        target, raw_image = preprocess(image['path'])
        image['image'] = target
        image['raw_image'] = raw_image

    net = model.Model(num_classes=len(classes))
    checkpoint = torch.load(os.path.join('../trained', model_name), map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['net'])
    net.eval()
    summary(net, (1, shape[0], shape[1]))

    result_images = []
    for index, image in enumerate(images):
        img = torch.stack([image['image']])
        bp = BackPropagation(model=net)
        probs, ids = bp.forward(img)
        gcam = GradCAM(model=net)
        _ = gcam.forward(img)

        gbp = GuidedBackPropagation(model=net)
        _ = gbp.forward(img)

        # Guided Backpropagation
        actual_emotion = ids[:, 0]
        gbp.backward(ids=actual_emotion.reshape(1, 1))
        gradients = gbp.generate()

        # Grad-CAM
        gcam.backward(ids=actual_emotion.reshape(1, 1))
        regions = gcam.generate(target_layer='last_conv')

        # Build one row per image: input | label | probability | guided backprop | Grad-CAM | guided Grad-CAM
        label_image = np.zeros((shape[0], 65, 3), np.uint8)
        cv2.putText(label_image, classes[actual_emotion.data], (5, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)

        prob_image = np.zeros((shape[0], 60, 3), np.uint8)
        cv2.putText(prob_image, '%.1f%%' % (probs.data[:, 0] * 100), (5, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

        guided_bpg_image = get_gradient_image(gradients[0])
        guided_bpg_image = cv2.merge((guided_bpg_image, guided_bpg_image, guided_bpg_image))

        grad_cam_image = get_gradcam_image(gcam=regions[0, 0], raw_image=image['raw_image'])

        guided_gradcam_image = get_gradient_image(torch.mul(regions, gradients)[0])
        guided_gradcam_image = cv2.merge((guided_gradcam_image, guided_gradcam_image, guided_gradcam_image))

        img = cv2.hconcat([image['raw_image'], label_image, prob_image, guided_bpg_image, grad_cam_image, guided_gradcam_image])
        result_images.append(img)
        print(image['path'], classes[actual_emotion.data], probs.data[:, 0] * 100)

    cv2.imwrite('../test/guided_gradcam.jpg', cv2.resize(cv2.vconcat(result_images), None, fx=2, fy=2))


def main():
    guided_backprop(
        images=[
            {'path': '../test/angry.jpg'},
            {'path': '../test/happy.jpg'},
            {'path': '../test/sad.jpg'},
            {'path': '../test/surprised.jpg'},
        ],
        model_name='private_model_233_66.t7'
    )


if __name__ == "__main__":
    main()
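
The `*.onnx` entry in `.gitignore` hints that an ONNX export may be wanted at some point, although none of the repo scripts perform one. Purely as an illustrative sketch (assuming it is run from `src/`, with a checkpoint in `trained/` and the 48x48 single-channel input that `check.py` feeds the network):

```python
import torch
import model  # src/model.py

net = model.Model(num_classes=7)
checkpoint = torch.load('../trained/private_model_233_66.t7', map_location='cpu')
net.load_state_dict(checkpoint['net'])
net.eval()

# batch x channel x height x width dummy input for tracing
dummy = torch.randn(1, 1, 48, 48)
torch.onnx.export(net, dummy, '../trained/model.onnx',
                  input_names=['input'], output_names=['logits'])
```
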
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class SeparableConv2d(nn.Module):
    """Depthwise separable convolution: a depthwise conv followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
                                   groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.residual_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2,
                                       bias=False)
        self.residual_bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=1e-3)

        self.sepConv1 = SeparableConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, bias=False,
                                        padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.99, eps=1e-3)
        self.relu = nn.ReLU()

        self.sepConv2 = SeparableConv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, bias=False,
                                        padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.99, eps=1e-3)
        self.maxp = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        res = self.residual_conv(x)
        res = self.residual_bn(res)
        x = self.sepConv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.sepConv2(x)
        x = self.bn2(x)
        x = self.maxp(x)
        return res + x


class Model(nn.Module):
    """Mini Xception: two plain convolutions followed by four residual separable-conv modules."""

    def __init__(self, num_classes):
        super(Model, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(8, affine=True, momentum=0.99, eps=1e-3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(8, momentum=0.99, eps=1e-3)
        self.relu2 = nn.ReLU()

        self.module1 = ResidualBlock(in_channels=8, out_channels=16)
        self.module2 = ResidualBlock(in_channels=16, out_channels=32)
        self.module3 = ResidualBlock(in_channels=32, out_channels=64)
        self.module4 = ResidualBlock(in_channels=64, out_channels=128)

        self.last_conv = nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=3, padding=1)
        self.avgp = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, input):
        x = input
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.module1(x)
        x = self.module2(x)
        x = self.module3(x)
        x = self.module4(x)
        x = self.last_conv(x)
        x = self.avgp(x)
        x = x.view((x.shape[0], -1))
        return x
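
As a quick sanity check of the `250KB` figure quoted in the README, a small sketch (assumed to be run from `src/`, with the seven FER2013 classes) that counts parameters and estimates the float32 size:

```python
import model

net = model.Model(num_classes=7)
n_params = sum(p.numel() for p in net.parameters())
# float32 weights take 4 bytes per parameter
print('parameters: %d  (~%.0f KB as float32)' % (n_params, n_params * 4 / 1024.0))
```
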
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
import model
import csv
from PIL import Image
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

if not torch.cuda.is_available():
    from torchsummary import summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

shape = (44, 44)


class DataSetFactory:

    def __init__(self):
        images = []
        emotions = []
        private_images = []
        private_emotions = []
        public_images = []
        public_emotions = []

        with open('../dataset/fer2013.csv', 'r') as csvin:
            data = csv.reader(csvin)
            next(data)
            for row in data:
                face = [int(pixel) for pixel in row[1].split()]
                face = np.asarray(face).reshape(48, 48)
                face = face.astype('uint8')

                if row[-1] == 'Training':
                    emotions.append(int(row[0]))
                    images.append(Image.fromarray(face))
                elif row[-1] == "PrivateTest":
                    private_emotions.append(int(row[0]))
                    private_images.append(Image.fromarray(face))
                elif row[-1] == "PublicTest":
                    public_emotions.append(int(row[0]))
                    public_images.append(Image.fromarray(face))

        print('training size %d : private val size %d : public val size %d' % (
            len(images), len(private_images), len(public_images)))
        train_transform = transforms.Compose([
            transforms.RandomCrop(shape[0]),
            transforms.RandomHorizontalFlip(),
            ToTensor(),
        ])
        val_transform = transforms.Compose([
            transforms.CenterCrop(shape[0]),
            ToTensor(),
        ])

        self.training = DataSet(transform=train_transform, images=images, emotions=emotions)
        self.private = DataSet(transform=val_transform, images=private_images, emotions=private_emotions)
        self.public = DataSet(transform=val_transform, images=public_images, emotions=public_emotions)


class DataSet(torch.utils.data.Dataset):

    def __init__(self, transform=None, images=None, emotions=None):
        self.transform = transform
        self.images = images
        self.emotions = emotions

    def __getitem__(self, index):
        image = self.images[index]
        emotion = self.emotions[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, emotion

    def __len__(self):
        return len(self.images)


def main():
    # variables -------------
    batch_size = 128
    lr = 0.01
    epochs = 300
    learning_rate_decay_start = 80
    learning_rate_decay_every = 5
    learning_rate_decay_rate = 0.9
    # ------------------------

    classes = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
    network = model.Model(num_classes=len(classes)).to(device)
    if not torch.cuda.is_available():
        summary(network, (1, shape[0], shape[1]))

    optimizer = torch.optim.SGD(network.parameters(), lr=lr, momentum=0.9, weight_decay=5e-3)
    criterion = nn.CrossEntropyLoss()
    factory = DataSetFactory()

    training_loader = DataLoader(factory.training, batch_size=batch_size, shuffle=True, num_workers=1)
    validation_loader = {
        'private': DataLoader(factory.private, batch_size=batch_size, shuffle=True, num_workers=1),
        'public': DataLoader(factory.public, batch_size=batch_size, shuffle=True, num_workers=1)
    }

    min_validation_loss = {
        'private': 10000,
        'public': 10000,
    }

    for epoch in range(epochs):
        network.train()
        total = 0
        correct = 0
        total_train_loss = 0
        if epoch > learning_rate_decay_start and learning_rate_decay_start >= 0:
            frac = (epoch - learning_rate_decay_start) // learning_rate_decay_every
            decay_factor = learning_rate_decay_rate ** frac
            current_lr = lr * decay_factor
            for group in optimizer.param_groups:
                group['lr'] = current_lr
        else:
            current_lr = lr

        print('learning_rate: %s' % str(current_lr))
        for i, (x_train, y_train) in enumerate(training_loader):
            optimizer.zero_grad()
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            y_predicted = network(x_train)
            loss = criterion(y_predicted, y_train)
            loss.backward()
            optimizer.step()
            _, predicted = torch.max(y_predicted.data, 1)
            total_train_loss += loss.data
            total += y_train.size(0)
            correct += predicted.eq(y_train.data).sum()

        accuracy = 100. * float(correct) / total
        print('Epoch [%d/%d] Training Loss: %.4f, Accuracy: %.4f' % (
            epoch + 1, epochs, total_train_loss / (i + 1), accuracy))

        network.eval()
        with torch.no_grad():
            for name in ['private', 'public']:
                total = 0
                correct = 0
                total_validation_loss = 0
                for j, (x_val, y_val) in enumerate(validation_loader[name]):
                    x_val = x_val.to(device)
                    y_val = y_val.to(device)
                    y_val_predicted = network(x_val)
                    val_loss = criterion(y_val_predicted, y_val)
                    _, predicted = torch.max(y_val_predicted.data, 1)
                    total_validation_loss += val_loss.data
                    total += y_val.size(0)
                    correct += predicted.eq(y_val.data).sum()

                accuracy = 100. * float(correct) / total
                if total_validation_loss <= min_validation_loss[name]:
                    if epoch >= 10:
                        print('saving new model')
                        state = {'net': network.state_dict()}
                        torch.save(state, '../trained/%s_model_%d_%d.t7' % (name, epoch + 1, accuracy))
                    min_validation_loss[name] = total_validation_loss

                print('Epoch [%d/%d] %s validation Loss: %.4f, Accuracy: %.4f' % (
                    epoch + 1, epochs, name, total_validation_loss / (j + 1), accuracy))


if __name__ == "__main__":
    main()
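
`train.py` expects `../dataset/fer2013.csv` with the emotion label in the first column, space-separated pixels in the second, and a Usage column (`Training` / `PublicTest` / `PrivateTest`) in the last. A small sketch, using the same parsing as `DataSetFactory`, to check the split sizes and class balance before a long training run:

```python
import csv
from collections import Counter

classes = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
usages, emotions = Counter(), Counter()

with open('../dataset/fer2013.csv', 'r') as csvin:
    data = csv.reader(csvin)
    next(data)  # skip the header row, as train.py does
    for row in data:
        usages[row[-1]] += 1
        emotions[classes[int(row[0])]] += 1

print(dict(usages))
for name, count in emotions.most_common():
    print('%-8s %d' % (name, count))
```
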
--------------------------------------------------------------------------------
/src/visualize/grad_cam.py:
--------------------------------------------------------------------------------
# see https://github.com/kazuto1011/grad-cam-pytorch

from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import functional as F


class _BaseWrapper(object):
    """
    Please modify forward() and backward() according to your task.
    """

    def __init__(self, model):
        super(_BaseWrapper, self).__init__()
        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # a set of hook function handlers

    def _encode_one_hot(self, ids):
        one_hot = torch.zeros_like(self.logits).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        """
        Simple classification
        """
        self.model.zero_grad()
        self.logits = self.model(image)
        self.probs = F.softmax(self.logits, dim=1)
        return self.probs.sort(dim=1, descending=True)

    def backward(self, ids):
        """
        Class-specific backpropagation

        Either way works:
        1. self.logits.backward(gradient=one_hot, retain_graph=True)
        2. (self.logits * one_hot).sum().backward(retain_graph=True)
        """

        one_hot = self._encode_one_hot(ids)
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def generate(self):
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()


class BackPropagation(_BaseWrapper):
    def forward(self, image):
        self.image = image.requires_grad_()
        return super(BackPropagation, self).forward(self.image)

    def generate(self):
        gradient = self.image.grad.clone()
        self.image.grad.zero_()
        return gradient


class GuidedBackPropagation(BackPropagation):
    """
    "Striving for Simplicity: the All Convolutional Net"
    https://arxiv.org/pdf/1412.6806.pdf
    Look at Figure 1 on page 8.
    """

    def __init__(self, model):
        super(GuidedBackPropagation, self).__init__(model)

        def backward_hook(module, grad_in, grad_out):
            # Cut off negative gradients
            if isinstance(module, nn.ReLU):
                return (torch.clamp(grad_in[0], min=0.0),)

        for module in self.model.named_modules():
            self.handlers.append(module[1].register_backward_hook(backward_hook))

class GradCAM(_BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = OrderedDict()
        self.grad_pool = OrderedDict()
        self.candidate_layers = candidate_layers  # list

        def forward_hook(key):
            def forward_hook_(module, input, output):
                # Save featuremaps
                self.fmap_pool[key] = output.detach()

            return forward_hook_

        def backward_hook(key):
            def backward_hook_(module, grad_in, grad_out):
                # Save the gradients corresponding to the featuremaps
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook_

        # If no candidates are specified, the hooks are registered to all the layers.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(forward_hook(name)))
                self.handlers.append(module.register_backward_hook(backward_hook(name)))

    def _find(self, pool, target_layer):
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def _compute_grad_weights(self, grads):
        return F.adaptive_avg_pool2d(grads, 1)

    def forward(self, image):
        self.image_shape = image.shape[2:]
        return super(GradCAM, self).forward(image)

    def generate(self, target_layer):
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        weights = self._compute_grad_weights(grads)

        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)

        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)

        return gcam
--------------------------------------------------------------------------------
/test/angry.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/test/angry.jpg
--------------------------------------------------------------------------------
/test/guided_gradcam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/test/guided_gradcam.jpg
--------------------------------------------------------------------------------
/test/happy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/test/happy.jpg
--------------------------------------------------------------------------------
/test/sad.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/test/sad.jpg
--------------------------------------------------------------------------------
/test/surprised.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/test/surprised.jpg
--------------------------------------------------------------------------------
/trained/private_model_233_66.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/trained/private_model_233_66.t7
--------------------------------------------------------------------------------
/trained/public_model_236_64.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoshidan/pytorch-facial-expression-recognition/8a70b1ea54a790d6eef56692bbfe3dc36fa367d6/trained/public_model_236_64.t7
--------------------------------------------------------------------------------