├── README.md ├── bottleneck.PNG ├── deconv.PNG ├── load_data.py ├── network.py ├── procrustes.py ├── run_demo.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Facial Action Unit Detection 2 | 3 | Facial Action Units (FAUs) are specific regions of the face that allow one to estimate the type and intensity of the emotion a person is expressing in an image. Each facial movement in the FAU coding system is encoded by its appearance on the face. AU codes start from 0 and end at 28, with the 0th code corresponding to a *neutral face*. While the coding system defines many FAUs, most applications use only a few of them, commonly referred to as Main FAUs, to characterize emotions. For instance, the AU codes for *cheek raise* (AU6) and *lip corner pull* (AU12), when combined, give the *happiness* (AU6+12) emotion. As another example, the AU codes for *inner brow raise* (AU1), *brow low* (AU4) and *lip corner depress* (AU15) together give the *sadness* (AU1+4+15) emotion. A complete list of AU codes and their corresponding facial expressions can be found on [Wikipedia](https://en.wikipedia.org/wiki/Facial_Action_Coding_System). 4 | 5 | A slightly modified deep ResNet architecture, pretrained on the BP4D FAU dataset, is used to predict the intensity of the facial expressions. The intensity labels for each AU range from 0 to 5 on a discrete integer scale. 6 | 7 | The repository contains the following files: 8 | * `train.py`, `network.py` and `run_demo.py` - Implement the ResNet model and train it on sample data for FAU intensity prediction. 9 | * `procrustes.py` - For performing [Procrustes analysis](https://bmvc2019.org/wp-content/uploads/papers/0403-paper.pdf) (facial alignment). 10 | * [model.pth](https://drive.google.com/file/d/1XYCoLtApTHq89_s3gf6Lh6U2kMP80oCH/view?usp=sharing) - Model trained on BP4D data 11 | 12 |

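A minimal inference sketch is shown below. It assumes the trained weights have been downloaded to `./model/model.pth` and uses `face.jpg` as a placeholder for an aligned face image; the preprocessing and intensity read-out mirror `run_demo.py`.

```python
import cv2, torch
from network import ResNet
from utils import OutIntensity

# Build the network: 10 heatmap channels, i.e. 2 per AU for AU6, AU10, AU12, AU14, AU17
net = ResNet(num_maps=10)
checkpoint = torch.load('./model/model.pth', map_location='cpu')
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}  # drop any DataParallel prefix
net.load_state_dict(checkpoint, strict=False)
net.eval()

# Load one face image: BGR -> RGB, resize to 256x256, scale to [0, 1], NCHW float tensor
img = cv2.cvtColor(cv2.imread('face.jpg'), cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 256))
x = torch.from_numpy(img / 255.0).permute(2, 0, 1).unsqueeze(0).float()

with torch.no_grad():
    heatmaps = net(x)                       # (1, 10, 64, 64) predicted heatmaps
    intensities = OutIntensity()(heatmaps)  # per-channel max, divided by 255

# Each AU is represented by two heatmap channels; average the pair to get one intensity per AU
for i, au in enumerate(['AU06', 'AU10', 'AU12', 'AU14', 'AU17']):
    print(au, round(float(intensities[0][2 * i] + intensities[0][2 * i + 1]) / 2.0, 2))
```

`run_demo.py` performs the same steps for every image listed in `./test/examples.txt` and additionally saves colour-mapped heatmap overlays to `./visualize/`.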
13 | 14 | Figure (left) presents the modified ResNet architecture, consisting of bottleneck residual blocks and deconvolutional layers. The network is implemented in the `ResNet` class, which includes the `_make_deconv_layer` function. `_make_deconv_layer` constructs a deconvolutional block by omitting the residual connections and replacing convolutions with deconvolutions (`ConvTranspose2d`). Figure (right) presents a single deconvolutional block, consisting of 3 sub-blocks each with `ConvTranspose2d`, BatchNorm and ReLU layers. 15 | -------------------------------------------------------------------------------- /bottleneck.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siddhantmest/Facial-Action-Unit-Detection/691fcdbc181fa45f64726fec2bd5b7117f4ed860/bottleneck.PNG -------------------------------------------------------------------------------- /deconv.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siddhantmest/Facial-Action-Unit-Detection/691fcdbc181fa45f64726fec2bd5b7117f4ed860/deconv.PNG -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load the BP4D or the DISFA dataset. 3 | """ 4 | import os 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import inspect 9 | from torch.utils.data import Dataset 10 | 11 | class MyDatasets(Dataset): 12 | def __init__(self, sigma=2, size=256, heatmap=32, AU_positions=10, database=''): 13 | if database == 'train': 14 | txt_file = open('./data/examples.txt','r') 15 | if database == 'demo': 16 | txt_file = open('./test/examples.txt','r') 17 | lines = txt_file.readlines()[0::] 18 | names = [l.split()[0] for l in lines] 19 | coords = [l.split()[1::] for l in lines] 20 | self.database = database 21 | self.data = dict(zip(names,coords)) 22 | self.imgs = list(set(names)) 23 | self.len = len(self.imgs) 24 | self.sigma, self.size, self.heatmap, self.AU_positions = sigma, size, heatmap, AU_positions  # used by generate_target 25 | def generate_target(self, points, intensity): 26 | target = np.zeros((self.AU_positions,self.heatmap,self.heatmap),dtype=np.float32) 27 | gs_range = self.sigma * 15 28 | for point_id in range(self.AU_positions): 29 | feat_stride = self.size / self.heatmap 30 | mu_x = int(points[point_id][0] / feat_stride + 0.5) 31 | mu_y = int(points[point_id][1] / feat_stride + 0.5) 32 | ul = [int(mu_x - gs_range), int(mu_y - gs_range)] 33 | br = [int(mu_x + gs_range + 1), int(mu_y + gs_range + 1)] 34 | x = np.arange(0, 2*gs_range+1, 1, np.float32) 35 | y = x[:, np.newaxis] 36 | x_center = y_center = (2*gs_range+1) // 2 37 | g = np.exp(- ((x - x_center) ** 2 + (y - y_center) ** 2) / (2 * self.sigma ** 2)) 38 | g_x = max(0, -ul[0]), min(br[0], self.heatmap) - ul[0] 39 | g_y = max(0, -ul[1]), min(br[1], self.heatmap) - ul[1] 40 | img_x = max(0, ul[0]), min(br[0], self.heatmap) 41 | img_y = max(0, ul[1]), min(br[1], self.heatmap) 42 | target[point_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = intensity[point_id]*g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 43 | return target*255.0 44 | 45 | def fetch(self,index): 46 | path_to_img = self.imgs[index] 47 | image = cv2.cvtColor(cv2.imread(path_to_img), cv2.COLOR_BGR2RGB) 48 | if self.database == 'demo': 49 | return image, [0], [0] 50 | AUs = self.data[self.imgs[index]] 51 | AUs = np.float32(self.data[self.imgs[index]]).reshape(-1,3) 52 | AU_coords = AUs[:,:2] 53 | AU_intensity = AUs[:,2] 54 | return image, AU_coords, AU_intensity 55 | 56 | def
__getitem__(self,index): 57 | image, AU_coords, AU_intensity = self.fetch(index) 58 | nimg = len(image) 59 | sample = dict.fromkeys(['Im'], None) 60 | out = dict.fromkeys(['image','points']) 61 | image_np = torch.from_numpy((image/255.0).swapaxes(2,1).swapaxes(1,0)) 62 | out['image'] = image_np.type_as(torch.FloatTensor()) 63 | out['AU_coords'] = np.floor(AU_coords) 64 | if self.database not in ['demo','train']: 65 | target = self.generate_target(out['AU_coords'], AU_intensity) 66 | target = torch.from_numpy(target).type_as(torch.FloatTensor()) 67 | sample['target'] = target 68 | sample['pts'] = out['AU_coords'] 69 | sample['intensity'] = AU_intensity 70 | sample['Im'] = out['image'] 71 | return sample 72 | 73 | def __len__(self): 74 | return len(self.imgs) 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | 11 | BN_MOMENTUM = 0.1 # momentum parameter for BatchNorm 12 | logger = logging.getLogger(__name__) 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d( 16 | in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False 18 | ) 19 | 20 | class Bottleneck(nn.Module): 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(Bottleneck, self).__init__() 25 | self.downsample = downsample 26 | self.stride = stride 27 | self.momentum = BN_MOMENTUM # momentum parameter for BatchNorm 28 | 29 | # use for bottleneck block 30 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) 31 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 32 | self.conv2 = conv3x3(planes, planes, stride=stride) 33 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 34 | self.conv3 = nn.Conv2d(planes, planes*self.expansion, kernel_size=1, stride=1, bias=False) 35 | self.bn3 = nn.BatchNorm2d(planes*self.expansion, momentum=BN_MOMENTUM) 36 | self.relu = nn.ReLU() 37 | 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | # use for bottleneck block 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv3(out) 52 | out = self.bn3(out) 53 | 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | return out 61 | 62 | class ResNet(nn.Module): 63 | #Res-50 Bottleneck,[3,4,6,3] 64 | def __init__(self, block=Bottleneck, num_maps=10, layers=[3,4,6,3]): 65 | self.inplanes = 64 66 | self.deconv_with_bias = False 67 | 68 | super(ResNet, self).__init__() 69 | 70 | 71 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 72 | bias=False) 73 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 76 | self.layer1 = self._make_layer(block, 64, layers[0]) 77 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 80 | 81 | # use for deconv blocks 82 | self.deconv_layers_1 = 
self._make_deconv_layer(0,[256, 256, 256],[4, 4, 4]) 83 | self.deconv_layers_2 = self._make_deconv_layer(0,[256, 256, 256],[4, 4, 4]) 84 | self.deconv_layers_3 = self._make_deconv_layer(0,[256, 256, 256],[4, 4, 4]) 85 | 86 | 87 | self.final_layer = nn.Conv2d( 88 | in_channels=256, 89 | out_channels=num_maps, 90 | kernel_size=1, 91 | stride=1, 92 | padding=0 93 | ) 94 | 95 | self._initialize_weights() 96 | 97 | def _make_layer(self, block, planes, blocks, stride=1): 98 | downsample = None 99 | if stride != 1 or self.inplanes != planes * block.expansion: 100 | downsample = nn.Sequential( 101 | nn.Conv2d(self.inplanes, planes * block.expansion, 102 | kernel_size=1, stride=stride, bias=False), 103 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 104 | ) 105 | 106 | layers = [] 107 | layers.append(block(self.inplanes, planes, stride, downsample)) 108 | self.inplanes = planes * block.expansion 109 | for i in range(1, blocks): 110 | layers.append(block(self.inplanes, planes)) 111 | 112 | return nn.Sequential(*layers) 113 | 114 | def _get_deconv_cfg(self, deconv_kernel, index): 115 | if deconv_kernel == 4: 116 | padding = 1 117 | output_padding = 0 118 | elif deconv_kernel == 3: 119 | padding = 1 120 | output_padding = 1 121 | elif deconv_kernel == 2: 122 | padding = 0 123 | output_padding = 0 124 | 125 | return deconv_kernel, padding, output_padding 126 | 127 | def _make_deconv_layer(self, i , num_filters, num_kernels): 128 | 129 | layers = [] 130 | 131 | kernel, padding, output_padding = \ 132 | self._get_deconv_cfg(num_kernels[i], i) 133 | 134 | planes = num_filters[i] 135 | layers.append( 136 | nn.ConvTranspose2d( 137 | in_channels=self.inplanes, 138 | out_channels=planes, 139 | kernel_size=kernel, 140 | stride=2, 141 | padding=padding, 142 | output_padding=output_padding, 143 | bias=self.deconv_with_bias)) 144 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 145 | layers.append(nn.ReLU(inplace=True)) 146 | self.inplanes = planes 147 | 148 | return nn.Sequential(*layers) 149 | 150 | def forward(self, x): 151 | x = self.conv1(x) 152 | x = self.bn1(x) 153 | x = self.relu(x) 154 | x = self.maxpool(x) 155 | 156 | x = self.layer1(x) 157 | x = self.layer2(x) 158 | x = self.layer3(x) 159 | x = self.layer4(x) 160 | 161 | # use for deconv blocks 162 | x = self.deconv_layers_1(x) 163 | x = self.deconv_layers_2(x) 164 | x = self.deconv_layers_3(x) 165 | 166 | x = self.final_layer(x) 167 | 168 | return x 169 | 170 | def _initialize_weights(self): 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 175 | if m.bias is not None: 176 | m.bias.data.zero_() 177 | elif isinstance(m, nn.BatchNorm2d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | elif isinstance(m, nn.Linear): 181 | n = m.weight.size(1) 182 | m.weight.data.normal_(0, 0.01) 183 | m.bias.data.zero_() 184 | 185 | -------------------------------------------------------------------------------- /procrustes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dlib 3 | import matplotlib.pyplot as plt 4 | import imutils 5 | import cv2 6 | import os 7 | 8 | PREDICTOR_PATH = "./data/shape_predictor_68_face_landmarks.dat" 9 | 10 | txt_file = open('./test2/examples.txt','r') 11 | lines = txt_file.readlines() 12 | new_paths = [] 13 | for ps in lines: 14 | new_paths.append(ps.replace('\r', '').replace('\n', '')) 15 | ref_path = new_paths[1] 16 | new_paths.remove(ref_path) 17 | file_path = './test2_1' 18 | 19 | 20 | def procrustes_analysys(A, B): 21 | """Procrustes analysis 22 | Basic algorithm is 23 | 1. Recenter the points based on their mean: compute a mean and subtract it from every points in shape 24 | 2. Normalize 25 | 3. Rotate one of the shapes and find MSE 26 | Args: 27 | A: 28 | B: 29 | Returns: 30 | """ 31 | h_A, w_A = A.shape 32 | h_B, w_B = B.shape 33 | 34 | # compute mean of each A and B 35 | Amu = np.mean(A, axis=0) 36 | Bmu = np.mean(B, axis=0) 37 | 38 | # subtract a mean 39 | A_base = A - Amu 40 | B_base = B - Bmu 41 | 42 | # normalize 43 | ssum_A = (A_base**2).sum() 44 | ssum_B = (B_base**2).sum() 45 | 46 | norm_A = np.sqrt(ssum_A) 47 | norm_B = np.sqrt(ssum_B) 48 | 49 | normalized_A = A_base / norm_A 50 | normalized_B = B_base / norm_B 51 | 52 | if (w_B < w_A): 53 | normalized_B = np.concatenate((normalized_B, np.zeros(h_A, w_A - w_B)), 0) 54 | 55 | A = np.dot(normalized_A.T, normalized_B) 56 | 57 | # SVD 58 | u, s, vh = np.linalg.svd(A, full_matrices=False) 59 | v = vh.T 60 | T = np.dot(v, u.T) 61 | 62 | scale = norm_A / norm_B 63 | 64 | return T, scale 65 | 66 | def shape_to_np(shape, dtype="int"): 67 | """Take a shape object and convert it to numpy array 68 | Args: 69 | shape: an object returned by dlib face landmark detector containing the 68 (x, y)-coordinates of the facial landmark regions 70 | dtype: int 71 | Returns: 72 | coords: (68,2) numpy array 73 | """ 74 | coords = np.zeros((68, 2), dtype=dtype) 75 | 76 | for i in range(0, 68): 77 | coords[i] = (shape.part(i).x, shape.part(i).y) 78 | 79 | return coords 80 | 81 | def get_face(img1_detection): 82 | for face in img1_detection: 83 | x = face.left() 84 | y = face.top() 85 | w = face.right() - x 86 | h = face.bottom() - y 87 | 88 | # draw box over face 89 | cv2.rectangle(img1, (x,y), (x+w,y+h), (0,255,0), 2) 90 | 91 | img_height, img_width = img1.shape[:2] 92 | cv2.putText(img1, "HOG", (img_width-50,20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 93 | (0,255,0), 2) 94 | 95 | # display output image 96 | plt.imshow(img1) 97 | 98 | def plot_landmarks(): 99 | # plot facial landmarks on the image 100 | for (x, y) in img1_shape: 101 | cv2.circle(img1, (x, y), 1, (0, 0, 255), -1) 102 | 103 | plt.imshow(img1) 104 | 105 | 106 | detector = dlib.get_frontal_face_detector() 107 | predictor = dlib.shape_predictor(PREDICTOR_PATH) 108 | 109 | for path in new_paths: 110 | img1 = dlib.load_rgb_image(ref_path) 111 | img2 = dlib.load_rgb_image(path) 112 | 113 | img1_detection = detector(img1, 1) 114 | img2_detection = detector(img2, 1) 115 | 116 | img1_shape = predictor(img1, img1_detection[0]) 117 | 
img2_shape = predictor(img2, img2_detection[0]) 118 | 119 | img1_shape1 = shape_to_np(img1_shape) 120 | img2_shape1 = shape_to_np(img2_shape) 121 | 122 | M, scale = procrustes_analysys(img1_shape1, img2_shape1) 123 | theta = np.rad2deg(np.arccos(M[0][0])) 124 | #print("theta is {}".format(theta)) 125 | 126 | rotation_matrix = cv2.getRotationMatrix2D((img1.shape[1]/2, img1.shape[0]/2), theta, 1) 127 | dst = cv2.warpAffine(img2, rotation_matrix, (img2.shape[1], img2.shape[0])) 128 | 129 | 130 | img2_aligned = dlib.get_face_chip(dst, img2_shape) 131 | 132 | 133 | img2_aligned_resized = imutils.resize(img2_aligned, width = img2.shape[0], height = img2.shape[1]) 134 | 135 | 136 | img_name = path.split('/')[-1].split(".")[0] + '.jpg' 137 | cv2.imwrite(os.path.join(file_path , img_name), img2_aligned_resized) 138 | 139 | if path == new_paths[-1]: 140 | img1_aligned = dlib.get_face_chip(img1, img1_shape) 141 | img1_aligned_resized = imutils.resize(img1_aligned, width = img1.shape[0], height = img1.shape[1]) 142 | img_name = ref_path.split('/')[-1].split(".")[0] + '.jpg' 143 | cv2.imwrite(os.path.join(file_path , img_name), img1_aligned_resized) -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | import torch, numpy as np 2 | from load_data import MyDatasets 3 | from utils import * 4 | from network import ResNet 5 | from torch.utils.data import Dataset, DataLoader 6 | import os, pickle 7 | from torchvision.utils import save_image 8 | import argparse 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--K', default=10, help='Number of AU positions')#24,10 14 | parser.add_argument('--dataset', default='BP4D', type=str, help='database')#BP4D,DISFA 15 | parser.add_argument('--dataset_test', default='demo', type=str)#BP4D-val, DISFA-val 16 | parser.add_argument('--model_path', type=str,default='./model/model.pth', help='model path') #model.pth 17 | parser.add_argument('--cuda', default='10', type=str, help='cuda') 18 | parser.add_argument('--size', default=256, help='Image size') 19 | 20 | def loadnet(npoints=10,path_to_model=None): 21 | # Load the trained model. 
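# Note: the keys in the released checkpoint may carry a 'module.' prefix (added when a
# model is saved from nn.DataParallel); the rename below strips it so this plain ResNet
# instance can load the weights.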
22 | net = ResNet(num_maps=npoints) 23 | checkpoint = torch.load(path_to_model, map_location='cpu') 24 | checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()} 25 | net.load_state_dict(checkpoint,strict=False) 26 | return net.to('cpu') 27 | 28 | def predict(loader,OUT,net): 29 | preds = [] 30 | au17_all_subjects = [] 31 | x_labels = ["FAU06", "FAU10", "FAU12", "FAU14", "FAU17"] 32 | all_intensities = [] 33 | with torch.no_grad(): 34 | count = 0 35 | for sample in loader: 36 | img = sample['Im'] 37 | heatmap = net(img) 38 | out = OUT(heatmap) 39 | preds.append(out) 40 | images = None 41 | maps_AU6 = None 42 | maps_AU10 = None 43 | maps_AU12= None 44 | maps_AU14= None 45 | maps_AU17= None 46 | threshold = 0.1 47 | font = cv2.FONT_HERSHEY_SIMPLEX 48 | 49 | for (index,item) in enumerate(sample['Im'].to('cpu').detach()): 50 | img_ori = (255*item.permute(1,2,0).numpy()).astype(np.uint8).copy() 51 | AU_intensities = out[index] 52 | AU06_intensity = round((out[index][0]+out[index][1])/2.0,2) 53 | AU10_intensity = round((out[index][2]+out[index][3])/2.0,2) 54 | AU12_intensity = round((out[index][4]+out[index][5])/2.0,2) 55 | AU14_intensity = round((out[index][6]+out[index][7])/2.0,2) 56 | AU17_intensity = round((out[index][8]+out[index][9])/2.0,2) 57 | 58 | au17_all_subjects.append(AU17_intensity) 59 | """ 60 | Visualization of the predicted AU6 heatmap. 61 | """ 62 | heatmap_AU6_0 = heatmap[index][0].to('cpu').detach() 63 | heatmap_AU6_0[heatmap_AU6_0<0]=0 64 | heatmap_AU6_0[heatmap_AU6_0>255*5.0]=255*5.0 65 | heatmap_AU6_0_np = (heatmap_AU6_0.numpy()/5.0).astype(np.uint8).copy() 66 | heatmap_AU6_0_rz = cv2.resize(heatmap_AU6_0_np,(256,256)) 67 | map_AU6_0 = cv2.applyColorMap(heatmap_AU6_0_rz, cv2.COLORMAP_JET) 68 | map_AU6_0=cv2.cvtColor(map_AU6_0, cv2.COLOR_RGB2BGR) 69 | heatmap_AU6_1 = heatmap[index][1].to('cpu').detach() 70 | heatmap_AU6_1[heatmap_AU6_1<0]=0 71 | heatmap_AU6_1[heatmap_AU6_1>255*5.0]=255*5.0 72 | heatmap_AU6_1_np = (heatmap_AU6_1.numpy()/5.0).astype(np.uint8).copy() 73 | heatmap_AU6_1_rz = cv2.resize(heatmap_AU6_1_np,(256,256)) 74 | map_AU6_1 = cv2.applyColorMap(heatmap_AU6_1_rz, cv2.COLORMAP_JET) 75 | map_AU6_1=cv2.cvtColor(map_AU6_1, cv2.COLOR_RGB2BGR) 76 | map_AU6 = map_AU6_0*0.5+map_AU6_1*0.5+img_ori*0.5 77 | cv2.putText(map_AU6,"AU6: "+str(AU06_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 78 | """ 79 | Visualization of the predicted AU10 heatmap. 80 | """ 81 | heatmap_AU10_0 = heatmap[index][2].to('cpu').detach() 82 | heatmap_AU10_0[heatmap_AU10_0<0]=0 83 | heatmap_AU10_0[heatmap_AU10_0>255*5.0]=255*5.0 84 | heatmap_AU10_0_np = (heatmap_AU10_0.numpy()/5.0).astype(np.uint8).copy() 85 | heatmap_AU10_0_rz = cv2.resize(heatmap_AU10_0_np,(256,256)) 86 | map_AU10_0 = cv2.applyColorMap(heatmap_AU10_0_rz, cv2.COLORMAP_JET) 87 | map_AU10_0=cv2.cvtColor(map_AU10_0, cv2.COLOR_RGB2BGR) 88 | heatmap_AU10_1 = heatmap[index][3].to('cpu').detach() 89 | heatmap_AU10_1[heatmap_AU10_1<0]=0 90 | heatmap_AU10_1[heatmap_AU10_1>255*5.0]=255*5.0 91 | heatmap_AU10_1_np = (heatmap_AU10_1.numpy()/5.0).astype(np.uint8).copy() 92 | heatmap_AU10_1_rz = cv2.resize(heatmap_AU10_1_np,(256,256)) 93 | map_AU10_1 = cv2.applyColorMap(heatmap_AU10_1_rz, cv2.COLORMAP_JET) 94 | map_AU10_1=cv2.cvtColor(map_AU10_1, cv2.COLOR_RGB2BGR) 95 | map_AU10 = map_AU10_0*0.5+map_AU10_1*0.5+img_ori*0.5 96 | cv2.putText(map_AU10,"AU10: "+str(AU10_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 97 | """ 98 | Visualization of the predicted AU12 heatmap.
99 | """ 100 | heatmap_AU12_0 = heatmap[index][4].to('cpu').detach() 101 | heatmap_AU12_0[heatmap_AU12_0<0]=0 102 | heatmap_AU12_0[heatmap_AU12_0>255*5.0]=255*5.0 103 | heatmap_AU12_0_np = (heatmap_AU12_0.numpy()/5.0).astype(np.uint8).copy() 104 | heatmap_AU12_0_rz = cv2.resize(heatmap_AU12_0_np,(256,256)) 105 | map_AU12_0 = cv2.applyColorMap(heatmap_AU12_0_rz, cv2.COLORMAP_JET) 106 | map_AU12_0=cv2.cvtColor(map_AU12_0, cv2.COLOR_RGB2BGR) 107 | heatmap_AU12_1 = heatmap[index][5].to('cpu').detach() 108 | heatmap_AU12_1[heatmap_AU12_1<0]=0 109 | heatmap_AU12_1[heatmap_AU12_1>255*5.0]=255*5.0 110 | heatmap_AU12_1_np = (heatmap_AU12_1.numpy()/5.0).astype(np.uint8).copy() 111 | heatmap_AU12_1_rz = cv2.resize(heatmap_AU12_1_np,(256,256)) 112 | map_AU12_1 = cv2.applyColorMap(heatmap_AU12_1_rz, cv2.COLORMAP_JET) 113 | map_AU12_1=cv2.cvtColor(map_AU12_1, cv2.COLOR_RGB2BGR) 114 | map_AU12 = map_AU12_0*0.5+map_AU12_1*0.5+img_ori*0.5 115 | cv2.putText(map_AU12,"AU12: "+str(AU12_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 116 | """ 117 | Visualization of the predicted AU14 heatmap. 118 | """ 119 | heatmap_AU14_0 = heatmap[index][6].to('cpu').detach() 120 | heatmap_AU14_0[heatmap_AU14_0<0]=0 121 | heatmap_AU14_0[heatmap_AU14_0>255*5.0]=255*5.0 122 | heatmap_AU14_0_np = (heatmap_AU14_0.numpy()/5.0).astype(np.uint8).copy() 123 | heatmap_AU14_0_rz = cv2.resize(heatmap_AU14_0_np,(256,256)) 124 | map_AU14_0 = cv2.applyColorMap(heatmap_AU14_0_rz, cv2.COLORMAP_JET) 125 | map_AU14_0=cv2.cvtColor(map_AU14_0, cv2.COLOR_RGB2BGR) 126 | heatmap_AU14_1 = heatmap[index][7].to('cpu').detach() 127 | heatmap_AU14_1[heatmap_AU14_1<0]=0 128 | heatmap_AU14_1[heatmap_AU14_1>255*5.0]=255*5.0 129 | heatmap_AU14_1_np = (heatmap_AU14_1.numpy()/5.0).astype(np.uint8).copy() 130 | heatmap_AU14_1_rz = cv2.resize(heatmap_AU14_1_np,(256,256)) 131 | map_AU14_1 = cv2.applyColorMap(heatmap_AU14_1_rz, cv2.COLORMAP_JET) 132 | map_AU14_1=cv2.cvtColor(map_AU14_1, cv2.COLOR_RGB2BGR) 133 | map_AU14 = map_AU14_0*0.5+map_AU14_1*0.5+img_ori*0.5 134 | cv2.putText(map_AU14,"AU14: "+str(AU14_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 135 | """ 136 | Visualization of the predicted AU17 heatmap.
137 | """ 138 | heatmap_AU17_0 = heatmap[index][8].to('cpu').detach() 139 | heatmap_AU17_0[heatmap_AU17_0<0]=0 140 | heatmap_AU17_0[heatmap_AU17_0>255*5.0]=255*5.0 141 | heatmap_AU17_0_np = (heatmap_AU17_0.numpy()/5.0).astype(np.uint8).copy() 142 | heatmap_AU17_0_rz = cv2.resize(heatmap_AU17_0_np,(256,256)) 143 | map_AU17_0 = cv2.applyColorMap(heatmap_AU17_0_rz, cv2.COLORMAP_JET) 144 | map_AU17_0=cv2.cvtColor(map_AU17_0, cv2.COLOR_RGB2BGR) 145 | heatmap_AU17_1 = heatmap[index][9].to('cpu').detach() 146 | heatmap_AU17_1[heatmap_AU17_1<0]=0 147 | heatmap_AU17_1[heatmap_AU17_1>255*5.0]=255*5.0 148 | heatmap_AU17_1_np = (heatmap_AU17_1.numpy()/5.0).astype(np.uint8).copy() 149 | heatmap_AU17_1_rz = cv2.resize(heatmap_AU17_1_np,(256,256)) 150 | map_AU17_1 = cv2.applyColorMap(heatmap_AU17_1_rz, cv2.COLORMAP_JET) 151 | map_AU17_1=cv2.cvtColor(map_AU17_1, cv2.COLOR_RGB2BGR) 152 | map_AU17 = map_AU17_0*0.5+map_AU17_1*0.5+img_ori*0.5 153 | cv2.putText(map_AU17,"AU17: "+str(AU17_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 154 | 155 | if images is None: 156 | images = np.expand_dims(img_ori,axis=0) 157 | else: 158 | images = np.concatenate((images, np.expand_dims(img_ori,axis=0))) 159 | 160 | if maps_AU6 is None: 161 | maps_AU6 = np.expand_dims(map_AU6,axis=0) 162 | else: 163 | maps_AU6 = np.concatenate((maps_AU6, np.expand_dims(map_AU6,axis=0))) 164 | if maps_AU10 is None: 165 | maps_AU10 = np.expand_dims(map_AU10,axis=0) 166 | else: 167 | maps_AU10 = np.concatenate((maps_AU10, np.expand_dims(map_AU10,axis=0))) 168 | if maps_AU12 is None: 169 | maps_AU12 = np.expand_dims(map_AU12,axis=0) 170 | else: 171 | maps_AU12 = np.concatenate((maps_AU12, np.expand_dims(map_AU12,axis=0))) 172 | if maps_AU14 is None: 173 | maps_AU14 = np.expand_dims(map_AU14,axis=0) 174 | else: 175 | maps_AU14 = np.concatenate((maps_AU14, np.expand_dims(map_AU14,axis=0))) 176 | if maps_AU17 is None: 177 | maps_AU17 = np.expand_dims(map_AU17,axis=0) 178 | else: 179 | maps_AU17 = np.concatenate((maps_AU17, np.expand_dims(map_AU17,axis=0))) 180 | 181 | # Save the visualized AU heatmaps in path "./visualize/" 182 | if not os.path.exists('./visualize/'): 183 | os.makedirs('./visualize/') 184 | 185 | save_AU6 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU6/255.0).permute(0,3,1,2),scale_factor=0.5) 186 | save_image(save_AU6, './visualize/Subject{}_AU06.png'.format(count)) 187 | save_AU10 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU10/255.0).permute(0,3,1,2),scale_factor=0.5) 188 | save_image(save_AU10, './visualize/Subject{}_AU10.png'.format(count)) 189 | save_AU12 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU12/255.0).permute(0,3,1,2),scale_factor=0.5) 190 | save_image(save_AU12, './visualize/Subject{}_AU12.png'.format(count)) 191 | save_AU14 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU14/255.0).permute(0,3,1,2),scale_factor=0.5) 192 | save_image(save_AU14, './visualize/Subject{}_AU14.png'.format(count)) 193 | save_AU17 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU17/255.0).permute(0,3,1,2),scale_factor=0.5) 194 | save_image(save_AU17, './visualize/Subject{}_AU17.png'.format(count)) 195 | count += 1 196 | 197 | y_labels = [AU06_intensity, AU10_intensity, AU12_intensity, AU14_intensity, AU17_intensity] 198 | all_intensities.append(y_labels) 199 | plt.plot(x_labels, y_labels, label = "Subject"+str(count-1)) 200 | plt.ylabel("FAU intensity") 201 | plt.title("Intensity variation across FAUs") 202 | plt.legend() 203 | plt.show() 204 | print("FAU17 intensities for all subjects: ", au17_all_subjects) 205 | all_intensities = np.array(all_intensities)
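# all_intensities now holds one [AU06, AU10, AU12, AU14, AU17] row per subject; plot() below draws each AU's intensity curve across subjects.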
206 | plot(all_intensities, x_labels) 207 | 208 | return np.concatenate(preds) 209 | 210 | def plot(all_intensities, x_labels): 211 | 212 | x = ["Subject0", "Subject1", "Subject2", "Subject3", "Subject4"] 213 | for i in range(5): 214 | plt.plot(x, all_intensities[:,i], label = x_labels[i]) 215 | plt.ylabel("FAU intensity") 216 | plt.title("FAU intensity variation across subjects") 217 | plt.legend() 218 | plt.show() 219 | 220 | return 221 | 222 | 223 | def test_epoch( dataset_test, model_path,size, npoints): 224 | net = loadnet(npoints,model_path) 225 | OUT = OutIntensity().to('cpu') 226 | # Load data 227 | database = MyDatasets(size=size,database=dataset_test) 228 | dbloader = DataLoader(database, batch_size=1, shuffle=False, num_workers=0, pin_memory=False) 229 | pred = predict(dbloader,OUT,net) 230 | 231 | def main(): 232 | global args 233 | args = parser.parse_args() 234 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda) 235 | test_epoch(dataset_test=args.dataset_test,model_path=args.model_path,size=args.size,npoints=args.K) 236 | 237 | if __name__ == '__main__': 238 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch, numpy as np 2 | from load_data import MyDatasets 3 | from utils import * 4 | from tqdm import tqdm 5 | from network import ResNet 6 | from torch.utils.data import Dataset, DataLoader 7 | import os, pickle 8 | import time 9 | from torchvision.utils import save_image 10 | import argparse 11 | from PIL import Image 12 | import matplotlib.pyplot as plt 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--K', default=10, help='Number of AU positions')#24,10 16 | parser.add_argument('--dataset', default='BP4D', type=str, help='database')#BP4D,DISFA 17 | parser.add_argument('--dataset_test', default='train', type=str)#BP4D-val, DISFA-val 18 | parser.add_argument('--model_path', type=str,default='./model/model.pth', help='model path') 19 | parser.add_argument('--cuda', default='5', type=str, help='cuda') 20 | parser.add_argument('--size', default=256, help='Image size') 21 | parser.add_argument('--lr', default=1e-5, help='Learning rate') 22 | parser.add_argument('--epochs', default=10, help='Epochs') 23 | 24 | def get_true_map(heatmap): 25 | true_map = heatmap.detach() / 256 #10x64x64 26 | label = torch.zeros(true_map.shape) #10x64x64 27 | for j in range(0, true_map.shape[0]): 28 | temp = cv2.applyColorMap((np.float32(true_map[j,:,:])).astype(np.uint8), cv2.COLORMAP_JET) #64x64 29 | label[j, :, :] = torch.FloatTensor(temp).mean(-1).unsqueeze(0) #1x64x64 30 | return label #10x64x64 31 | 32 | def loadnet(npoints=10,path_to_model=None): 33 | # Load the trained model. 
34 | net = ResNet(num_maps=npoints) 35 | checkpoint = torch.load(path_to_model, map_location='cpu') 36 | checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()} 37 | net.load_state_dict(checkpoint,strict=False) 38 | return net.to('cpu') 39 | 40 | def plot_loss(loss): 41 | xax = np.arange(1,len(loss)+1) 42 | plt.plot(xax, loss) 43 | plt.xlabel("epochs") 44 | plt.ylabel("Loss") 45 | plt.title("Loss vs Number of epochs") 46 | plt.xticks(np.arange(1,len(loss)+1)) 47 | plt.show() 48 | return 49 | 50 | 51 | 52 | def train(loader,OUT,net, sample_len): 53 | lossy = [] 54 | mean_loss = torch.zeros((sample_len,1)) 55 | opt = torch.optim.Adam(net.parameters(), lr=args.lr) 56 | layer = 0 57 | for child in net.children(): 58 | layer += 1 59 | if layer < 9: 60 | for param in child.parameters(): 61 | param.requires_grad = True 62 | net.train() 63 | for i in tqdm(range(args.epochs)): 64 | print('\tEpochs: ', i,'/', args.epochs,'\tTime:', time.time() - start_time, 'sec') 65 | for idx, sample in enumerate(loader): 66 | opt.zero_grad() 67 | img = sample['Im'] 68 | heatmap = net(img).squeeze(0) 69 | true_label = get_true_map(heatmap) 70 | loss = (heatmap - true_label).pow(2).mean() 71 | loss.backward() 72 | opt.step() 73 | mean_loss[idx] = loss.item() 74 | print('Loss- ', (torch.mean(mean_loss)).item()) 75 | lossy.append(torch.mean(mean_loss).item()) 76 | plot_loss(lossy) 77 | torch.save(net.state_dict(), './model/tuned_model.pth') 78 | 79 | def test_epoch( dataset_test, model_path,size, npoints): 80 | net = loadnet(npoints,model_path) 81 | OUT = OutIntensity().to('cpu') 82 | # Load data 83 | database = MyDatasets(size=size,database=dataset_test) 84 | dbloader = DataLoader(database, batch_size=1, shuffle=False, num_workers=0, pin_memory=False) 85 | train(dbloader,OUT,net, len(database)) 86 | 87 | def main(): 88 | global args 89 | global start_time 90 | start_time = time.time() 91 | args = parser.parse_args() 92 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda) 93 | test_epoch(dataset_test=args.dataset_test,model_path=args.model_path,size=args.size,npoints=args.K) 94 | 95 | if __name__ == '__main__': 96 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, sys, torch, math 2 | import cv2 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class OutIntensity(torch.nn.Module): 8 | """Infer AU intensity from a heatmap: :(x, y) = argmax H """ 9 | def __init__(self): 10 | super(OutIntensity,self).__init__() 11 | 12 | def forward(self,x): 13 | batch_size = x.shape[0] 14 | num_points = x.shape[1] 15 | width = x.shape[2] 16 | x_ = x.to('cpu').detach().numpy().astype(np.float32).copy() 17 | heatmaps_reshaped = x_.reshape((batch_size, num_points, -1)) 18 | intensity = heatmaps_reshaped.max(axis=2) 19 | return intensity/255.0 20 | --------------------------------------------------------------------------------
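For reference, a self-contained sketch of how `MyDatasets.generate_target` builds one channel of the training heatmaps: a 2D Gaussian with standard deviation `sigma` and support `sigma * 15` pixels is centred at the AU location mapped onto the `heatmap`-sized grid, then scaled by the annotated intensity and by 255. The defaults below (`sigma=2`, `size=256`, `heatmap=32`) mirror the dataset class; the example point and intensity are made up for illustration.

```python
import numpy as np

def gaussian_target(point, intensity, sigma=2, size=256, heatmap=32):
    """One channel of the target tensor produced by MyDatasets.generate_target."""
    target = np.zeros((heatmap, heatmap), dtype=np.float32)
    gs_range = sigma * 15
    stride = size / heatmap                     # image grid -> heatmap grid
    mu_x = int(point[0] / stride + 0.5)
    mu_y = int(point[1] / stride + 0.5)
    ul = [mu_x - gs_range, mu_y - gs_range]     # bounding box of the Gaussian patch
    br = [mu_x + gs_range + 1, mu_y + gs_range + 1]
    x = np.arange(0, 2 * gs_range + 1, 1, np.float32)
    y = x[:, np.newaxis]
    c = (2 * gs_range + 1) // 2
    g = np.exp(-((x - c) ** 2 + (y - c) ** 2) / (2 * sigma ** 2))
    # intersect the Gaussian patch with the heatmap grid before pasting it in
    g_x = max(0, -ul[0]), min(br[0], heatmap) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], heatmap) - ul[1]
    img_x = max(0, ul[0]), min(br[0], heatmap)
    img_y = max(0, ul[1]), min(br[1], heatmap)
    target[img_y[0]:img_y[1], img_x[0]:img_x[1]] = intensity * g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return target * 255.0

# Example: an AU annotated at pixel (128, 96) with intensity 3 on the 0-5 scale
t = gaussian_target((128, 96), 3.0)
print(t.shape, t.max())   # (32, 32), peak = 3 * 255
```

Because each target channel peaks at `intensity * 255`, `OutIntensity` in `utils.py` can read the predicted intensity back at test time as the per-channel maximum divided by 255, and `run_demo.py` averages the two channels belonging to each AU to report a single value.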