├── README.md
├── bottleneck.PNG
├── deconv.PNG
├── load_data.py
├── network.py
├── procrustes.py
├── run_demo.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Facial Action Unit Detection
2 |
3 | Facial Action Units (FAUs) are codes for specific facial muscle movements that allow one to estimate the type and intensity of the emotion a person is expressing in an image. Each facial movement in the FAU coding system is encoded by its appearance on the face. AU codes range from 0 to 28, with the 0th code representing a *neutral face*. While the coding system defines many FAUs, most applications use only a few of them, commonly referred to as Main FAUs, to characterize emotions. For instance, *cheek raise* (AU6) and *lip corner pull* (AU12) combined give the *happiness* (AU6+12) emotion. As another example, *inner brow raise* (AU1), *brow low* (AU4) and *lip corner depress* (AU15) together give the *sadness* (AU1+4+15) emotion. A complete list of AU codes and their corresponding facial expressions can be found on [Wikipedia](https://en.wikipedia.org/wiki/Facial_Action_Coding_System).
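
As a purely illustrative sketch (not part of this repository), the two Main-FAU combinations above can be written as a small lookup table:

```python
# Illustrative only: map combinations of Main FAU codes to emotion labels,
# using just the two examples mentioned above (not a full FACS table).
EMOTION_FROM_AUS = {
    frozenset({6, 12}): "happiness",    # cheek raise + lip corner pull
    frozenset({1, 4, 15}): "sadness",   # inner brow raise + brow low + lip corner depress
}

def emotion_from_aus(active_aus):
    """Return the first emotion whose AU combination is fully contained in active_aus."""
    for combo, emotion in EMOTION_FROM_AUS.items():
        if combo <= set(active_aus):
            return emotion
    return "unknown"

print(emotion_from_aus({6, 12}))      # happiness
print(emotion_from_aus({1, 4, 15}))   # sadness
```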
4 |
5 | A slightly modified deep ResNet architecture, pretrained on the BP4D FAU dataset, is used to predict the intensity of the facial expressions. The intensity labels for each AU range from 0 to 5 on a discrete integer scale.
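
Internally, each AU intensity is encoded as the peak of a Gaussian heatmap centred on the AU location: `load_data.py` scales the Gaussian by `intensity * 255`, and `OutIntensity` in `utils.py` recovers the predicted intensity by taking each heatmap's maximum and dividing by 255.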
6 |
7 | The repository contains the following files:
8 | * `train.py`, `network.py` and `run_demo.py` - Implement the ResNet model, fine-tune it on sample data and run FAU intensity prediction (see the example commands below).
9 | * `load_data.py` and `utils.py` - Load the images, build the Gaussian heatmap targets and read AU intensities from the predicted heatmaps (`OutIntensity`).
10 | * `procrustes.py` - Performs [Procrustes analysis](https://bmvc2019.org/wp-content/uploads/papers/0403-paper.pdf) for facial alignment.
11 | * [model.pth](https://drive.google.com/file/d/1XYCoLtApTHq89_s3gf6Lh6U2kMP80oCH/view?usp=sharing) - Model trained on the BP4D data.
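
Assuming the sample image lists `./data/examples.txt` and `./test/examples.txt` are in place and the downloaded `model.pth` sits under `./model/`, the demo can be run from the repository root with `python run_demo.py --dataset_test demo --model_path ./model/model.pth`, and fine-tuning with `python train.py --dataset_test train --epochs 10 --lr 1e-5` (the fine-tuned weights are saved to `./model/tuned_model.pth`).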
12 | ![Modified ResNet architecture](bottleneck.PNG) ![Deconvolutional block](deconv.PNG)
13 |
14 | Figure (left) presents the modified ResNet architecture, consisting of bottleneck residual blocks followed by deconvolutional layers. The model is implemented in the `ResNet` class; each call to its `_make_deconv_layer` function constructs one deconvolutional block by dropping the residual connection and replacing the convolution with a deconvolution (`ConvTranspose2d`). Figure (right) presents the deconvolutional stage, consisting of 3 such blocks, each containing a `ConvTranspose2d`, BatchNorm and ReLU layer.
15 |
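For orientation, here is a rough sketch of how the pieces fit together, mirroring `run_demo.py` (the checkpoint path is a placeholder and the input is a random stand-in for a normalized 256x256 face crop): the network produces ten 64x64 heatmaps, one per AU position, and `OutIntensity` reads each AU's intensity as the peak of its heatmap divided by 255.

```python
import torch
from network import ResNet
from utils import OutIntensity

net = ResNet(num_maps=10)   # ResNet-50 backbone + 3 deconvolutional blocks + 1x1 output conv
# Placeholder path; skip these two lines to run with random weights.
state = torch.load('./model/model.pth', map_location='cpu')
net.load_state_dict({k.replace('module.', ''): v for k, v in state.items()}, strict=False)
net.eval()

img = torch.rand(1, 3, 256, 256)              # stand-in for a face image scaled to [0, 1]
with torch.no_grad():
    heatmaps = net(img)                       # shape (1, 10, 64, 64)
    intensities = OutIntensity()(heatmaps)    # shape (1, 10), values roughly in the 0-5 range
print(heatmaps.shape, intensities)
```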
--------------------------------------------------------------------------------
/bottleneck.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Siddhantmest/Facial-Action-Unit-Detection/691fcdbc181fa45f64726fec2bd5b7117f4ed860/bottleneck.PNG
--------------------------------------------------------------------------------
/deconv.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Siddhantmest/Facial-Action-Unit-Detection/691fcdbc181fa45f64726fec2bd5b7117f4ed860/deconv.PNG
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Load the BP4D or the DISFA dataset.
3 | """
4 | import os
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import inspect
9 | from torch.utils.data import Dataset
10 |
11 | class MyDatasets(Dataset):
12 | def __init__(self, sigma=2, size=256, heatmap=32,AU_positions=10, database=''):
13 | if database == 'train':
14 | txt_file = open('./data/examples.txt','r')
15 | if database == 'demo':
16 | txt_file = open('./test/examples.txt','r')
17 | lines = txt_file.readlines()[0::]
18 | names = [l.split()[0] for l in lines]
19 | coords = [l.split()[1::] for l in lines]
20 | self.database = database; self.sigma = sigma; self.size = size; self.heatmap = heatmap; self.AU_positions = AU_positions  # generate_target relies on these attributes
21 | self.data = dict(zip(names,coords))
22 | self.imgs = list(set(names))
23 | self.len = len(self.imgs)
24 |
25 | def generate_target(self, points, intensity):
26 | target = np.zeros((self.AU_positions,self.heatmap,self.heatmap),dtype=np.float32)
27 | gs_range = self.sigma * 15
28 | for point_id in range(self.AU_positions):
29 | feat_stride = self.size / self.heatmap
30 | mu_x = int(points[point_id][0] / feat_stride + 0.5)
31 | mu_y = int(points[point_id][1] / feat_stride + 0.5)
32 | ul = [int(mu_x - gs_range), int(mu_y - gs_range)]
33 | br = [int(mu_x + gs_range + 1), int(mu_y + gs_range + 1)]
34 | x = np.arange(0, 2*gs_range+1, 1, np.float32)
35 | y = x[:, np.newaxis]
36 | x_center = y_center = (2*gs_range+1) // 2
37 | g = np.exp(- ((x - x_center) ** 2 + (y - y_center) ** 2) / (2 * self.sigma ** 2))
38 | g_x = max(0, -ul[0]), min(br[0], self.heatmap) - ul[0]
39 | g_y = max(0, -ul[1]), min(br[1], self.heatmap) - ul[1]
40 | img_x = max(0, ul[0]), min(br[0], self.heatmap)
41 | img_y = max(0, ul[1]), min(br[1], self.heatmap)
42 | target[point_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = intensity[point_id]*g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
43 | return target*255.0
44 |
45 | def fetch(self,index):
46 | path_to_img = self.imgs[index]
47 | image = cv2.cvtColor(cv2.imread(path_to_img), cv2.COLOR_BGR2RGB)
48 | if self.database == 'demo':
49 | return image, [0], [0]
50 | AUs = self.data[self.imgs[index]]
51 | AUs = np.float32(self.data[self.imgs[index]]).reshape(-1,3)
52 | AU_coords = AUs[:,:2]
53 | AU_intensity = AUs[:,2]
54 | return image, AU_coords, AU_intensity
55 |
56 | def __getitem__(self,index):
57 | image, AU_coords, AU_intensity = self.fetch(index)
58 | nimg = len(image)
59 | sample = dict.fromkeys(['Im'], None)
60 | out = dict.fromkeys(['image','points'])
61 | image_np = torch.from_numpy((image/255.0).swapaxes(2,1).swapaxes(1,0))
62 | out['image'] = image_np.type_as(torch.FloatTensor())
63 | out['AU_coords'] = np.floor(AU_coords)
64 | if self.database not in ['demo','train']:
65 | target = self.generate_target(out['AU_coords'], AU_intensity)
66 | target = torch.from_numpy(target).type_as(torch.FloatTensor())
67 | sample['target'] = target
68 | sample['pts'] = out['AU_coords']
69 | sample['intensity'] = AU_intensity
70 | sample['Im'] = out['image']
71 | return sample
72 |
73 | def __len__(self):
74 | return len(self.imgs)
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import logging
7 | import math
8 | import torch
9 | import torch.nn as nn
10 |
11 | BN_MOMENTUM = 0.1 # momentum parameter for BatchNorm
12 | logger = logging.getLogger(__name__)
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | return nn.Conv2d(
16 | in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False
18 | )
19 |
20 | class Bottleneck(nn.Module):
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None):
24 | super(Bottleneck, self).__init__()
25 | self.downsample = downsample
26 | self.stride = stride
27 | self.momentum = BN_MOMENTUM # momentum parameter for BatchNorm
28 |
29 | # use for bottleneck block
30 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
31 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
32 | self.conv2 = conv3x3(planes, planes, stride=stride)
33 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
34 | self.conv3 = nn.Conv2d(planes, planes*self.expansion, kernel_size=1, stride=1, bias=False)
35 | self.bn3 = nn.BatchNorm2d(planes*self.expansion, momentum=BN_MOMENTUM)
36 | self.relu = nn.ReLU()
37 |
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | # use for bottleneck block
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv3(out)
52 | out = self.bn3(out)
53 |
54 |
55 | if self.downsample is not None:
56 | residual = self.downsample(x)
57 |
58 | out += residual
59 | out = self.relu(out)
60 | return out
61 |
62 | class ResNet(nn.Module):
63 | #Res-50 Bottleneck,[3,4,6,3]
64 | def __init__(self, block=Bottleneck, num_maps=10, layers=[3,4,6,3]):
65 | self.inplanes = 64
66 | self.deconv_with_bias = False
67 |
68 | super(ResNet, self).__init__()
69 |
70 |
71 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
72 | bias=False)
73 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
76 | self.layer1 = self._make_layer(block, 64, layers[0])
77 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
78 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
79 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
80 |
81 | # use for deconv blocks
82 | self.deconv_layers_1 = self._make_deconv_layer(0,[256, 256, 256],[4, 4, 4])
83 | self.deconv_layers_2 = self._make_deconv_layer(0,[256, 256, 256],[4, 4, 4])
84 | self.deconv_layers_3 = self._make_deconv_layer(0,[256, 256, 256],[4, 4, 4])
85 |
86 |
87 | self.final_layer = nn.Conv2d(
88 | in_channels=256,
89 | out_channels=num_maps,
90 | kernel_size=1,
91 | stride=1,
92 | padding=0
93 | )
94 |
95 | self._initialize_weights()
96 |
97 | def _make_layer(self, block, planes, blocks, stride=1):
98 | downsample = None
99 | if stride != 1 or self.inplanes != planes * block.expansion:
100 | downsample = nn.Sequential(
101 | nn.Conv2d(self.inplanes, planes * block.expansion,
102 | kernel_size=1, stride=stride, bias=False),
103 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
104 | )
105 |
106 | layers = []
107 | layers.append(block(self.inplanes, planes, stride, downsample))
108 | self.inplanes = planes * block.expansion
109 | for i in range(1, blocks):
110 | layers.append(block(self.inplanes, planes))
111 |
112 | return nn.Sequential(*layers)
113 |
114 | def _get_deconv_cfg(self, deconv_kernel, index):
115 | if deconv_kernel == 4:
116 | padding = 1
117 | output_padding = 0
118 | elif deconv_kernel == 3:
119 | padding = 1
120 | output_padding = 1
121 | elif deconv_kernel == 2:
122 | padding = 0
123 | output_padding = 0
124 |
125 | return deconv_kernel, padding, output_padding
126 |
127 | def _make_deconv_layer(self, i , num_filters, num_kernels):
128 |
129 | layers = []
130 |
131 | kernel, padding, output_padding = \
132 | self._get_deconv_cfg(num_kernels[i], i)
133 |
134 | planes = num_filters[i]
135 | layers.append(
136 | nn.ConvTranspose2d(
137 | in_channels=self.inplanes,
138 | out_channels=planes,
139 | kernel_size=kernel,
140 | stride=2,
141 | padding=padding,
142 | output_padding=output_padding,
143 | bias=self.deconv_with_bias))
144 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
145 | layers.append(nn.ReLU(inplace=True))
146 | self.inplanes = planes
147 |
148 | return nn.Sequential(*layers)
149 |
150 | def forward(self, x):
151 | x = self.conv1(x)
152 | x = self.bn1(x)
153 | x = self.relu(x)
154 | x = self.maxpool(x)
155 |
156 | x = self.layer1(x)
157 | x = self.layer2(x)
158 | x = self.layer3(x)
159 | x = self.layer4(x)
160 |
161 | # use for deconv blocks
162 | x = self.deconv_layers_1(x)
163 | x = self.deconv_layers_2(x)
164 | x = self.deconv_layers_3(x)
165 |
166 | x = self.final_layer(x)
167 |
168 | return x
169 |
170 | def _initialize_weights(self):
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | if m.bias is not None:
176 | m.bias.data.zero_()
177 | elif isinstance(m, nn.BatchNorm2d):
178 | m.weight.data.fill_(1)
179 | m.bias.data.zero_()
180 | elif isinstance(m, nn.Linear):
181 | n = m.weight.size(1)
182 | m.weight.data.normal_(0, 0.01)
183 | m.bias.data.zero_()
184 |
185 |
--------------------------------------------------------------------------------
/procrustes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dlib
3 | import matplotlib.pyplot as plt
4 | import imutils
5 | import cv2
6 | import os
7 |
8 | PREDICTOR_PATH = "./data/shape_predictor_68_face_landmarks.dat"
9 |
10 | txt_file = open('./test2/examples.txt','r')
11 | lines = txt_file.readlines()
12 | new_paths = []
13 | for ps in lines:
14 | new_paths.append(ps.replace('\r', '').replace('\n', ''))
15 | ref_path = new_paths[1]
16 | new_paths.remove(ref_path)
17 | file_path = './test2_1'
18 |
19 |
20 | def procrustes_analysis(A, B):
21 | """Procrustes analysis
22 | Basic algorithm is
23 | 1. Recenter the points: compute each shape's mean and subtract it from every point in that shape
24 | 2. Normalize each centred shape by its Frobenius norm
25 | 3. Find the rotation that best maps one normalized shape onto the other (via SVD)
26 | Args:
27 | A: reference landmark array, shape (N, 2)
28 | B: landmark array, shape (N, 2), to be aligned with A
29 | Returns: the rotation matrix T and the relative scale norm_A / norm_B
30 | """
31 | h_A, w_A = A.shape
32 | h_B, w_B = B.shape
33 |
34 | # compute mean of each A and B
35 | Amu = np.mean(A, axis=0)
36 | Bmu = np.mean(B, axis=0)
37 |
38 | # subtract a mean
39 | A_base = A - Amu
40 | B_base = B - Bmu
41 |
42 | # normalize
43 | ssum_A = (A_base**2).sum()
44 | ssum_B = (B_base**2).sum()
45 |
46 | norm_A = np.sqrt(ssum_A)
47 | norm_B = np.sqrt(ssum_B)
48 |
49 | normalized_A = A_base / norm_A
50 | normalized_B = B_base / norm_B
51 |
52 | if (w_B < w_A):
53 | normalized_B = np.concatenate((normalized_B, np.zeros((h_B, w_A - w_B))), 1)  # pad B with zero columns to match A's width
54 |
55 | A = np.dot(normalized_A.T, normalized_B)
56 |
57 | # SVD
58 | u, s, vh = np.linalg.svd(A, full_matrices=False)
59 | v = vh.T
60 | T = np.dot(v, u.T)
61 |
62 | scale = norm_A / norm_B
63 |
64 | return T, scale
65 |
66 | def shape_to_np(shape, dtype="int"):
67 | """Take a shape object and convert it to numpy array
68 | Args:
69 | shape: an object returned by dlib face landmark detector containing the 68 (x, y)-coordinates of the facial landmark regions
70 | dtype: int
71 | Returns:
72 | coords: (68,2) numpy array
73 | """
74 | coords = np.zeros((68, 2), dtype=dtype)
75 |
76 | for i in range(0, 68):
77 | coords[i] = (shape.part(i).x, shape.part(i).y)
78 |
79 | return coords
80 |
81 | def get_face(img1_detection):
82 | for face in img1_detection:
83 | x = face.left()
84 | y = face.top()
85 | w = face.right() - x
86 | h = face.bottom() - y
87 |
88 | # draw box over face
89 | cv2.rectangle(img1, (x,y), (x+w,y+h), (0,255,0), 2)
90 |
91 | img_height, img_width = img1.shape[:2]
92 | cv2.putText(img1, "HOG", (img_width-50,20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
93 | (0,255,0), 2)
94 |
95 | # display output image
96 | plt.imshow(img1)
97 |
98 | def plot_landmarks():
99 | # plot facial landmarks on the image
100 | for (x, y) in img1_shape:
101 | cv2.circle(img1, (x, y), 1, (0, 0, 255), -1)
102 |
103 | plt.imshow(img1)
104 |
105 |
106 | detector = dlib.get_frontal_face_detector()
107 | predictor = dlib.shape_predictor(PREDICTOR_PATH)
108 |
109 | for path in new_paths:
110 | img1 = dlib.load_rgb_image(ref_path)
111 | img2 = dlib.load_rgb_image(path)
112 |
113 | img1_detection = detector(img1, 1)
114 | img2_detection = detector(img2, 1)
115 |
116 | img1_shape = predictor(img1, img1_detection[0])
117 | img2_shape = predictor(img2, img2_detection[0])
118 |
119 | img1_shape1 = shape_to_np(img1_shape)
120 | img2_shape1 = shape_to_np(img2_shape)
121 |
122 | M, scale = procrustes_analysis(img1_shape1, img2_shape1)
123 | theta = np.rad2deg(np.arccos(M[0][0]))
124 | #print("theta is {}".format(theta))
125 |
126 | rotation_matrix = cv2.getRotationMatrix2D((img1.shape[1]/2, img1.shape[0]/2), theta, 1)
127 | dst = cv2.warpAffine(img2, rotation_matrix, (img2.shape[1], img2.shape[0]))
128 |
129 |
130 | img2_aligned = dlib.get_face_chip(dst, img2_shape)
131 |
132 |
133 | img2_aligned_resized = imutils.resize(img2_aligned, width = img2.shape[1], height = img2.shape[0])
134 |
135 |
136 | img_name = path.split('/')[-1].split(".")[0] + '.jpg'
137 | cv2.imwrite(os.path.join(file_path , img_name), img2_aligned_resized)
138 |
139 | if path == new_paths[-1]:
140 | img1_aligned = dlib.get_face_chip(img1, img1_shape)
141 | img1_aligned_resized = imutils.resize(img1_aligned, width = img1.shape[1], height = img1.shape[0])
142 | img_name = ref_path.split('/')[-1].split(".")[0] + '.jpg'
143 | cv2.imwrite(os.path.join(file_path , img_name), img1_aligned_resized)
--------------------------------------------------------------------------------
/run_demo.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np
2 | from load_data import MyDatasets
3 | from utils import *
4 | from network import ResNet
5 | from torch.utils.data import Dataset, DataLoader
6 | import os, pickle
7 | from torchvision.utils import save_image
8 | import argparse
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--K', default=10, type=int, help='Number of AU positions')#24,10
14 | parser.add_argument('--dataset', default='BP4D', type=str, help='database')#BP4D,DISFA
15 | parser.add_argument('--dataset_test', default='demo', type=str)#BP4D-val, DISFA-val
16 | parser.add_argument('--model_path', type=str,default='./model/model.pth', help='model path') #model.pth
17 | parser.add_argument('--cuda', default='10', type=str, help='cuda')
18 | parser.add_argument('--size', default=256, type=int, help='Image size')
19 |
20 | def loadnet(npoints=10,path_to_model=None):
21 | # Load the trained model.
22 | net = ResNet(num_maps=npoints)
23 | checkpoint = torch.load(path_to_model, map_location='cpu')
24 | checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()}
25 | net.load_state_dict(checkpoint,strict=False)
26 | return net.to('cpu')
27 |
28 | def predict(loader,OUT,net):
29 | preds = []
30 | au17_all_subjects = []
31 | x_labels = ["FAU06", "FAU10", "FAU12", "FAU14", "FAU17"]
32 | all_intensities = []
33 | with torch.no_grad():
34 | count = 0
35 | for sample in loader:
36 | img = sample['Im']
37 | heatmap = net(img)
38 | out = OUT(heatmap)
39 | preds.append(out)
40 | images = None
41 | maps_AU6 = None
42 | maps_AU10 = None
43 | maps_AU12= None
44 | maps_AU14= None
45 | maps_AU17= None
46 | threshold = 0.1
47 | font = cv2.FONT_HERSHEY_SIMPLEX
48 |
49 | for (index,item) in enumerate(sample['Im'].to('cpu').detach()):
50 | img_ori = (255*item.permute(1,2,0).numpy()).astype(np.uint8).copy()
51 | AU_intensities = out[index]
52 | AU06_intensity = round((out[index][0]+out[index][1])/2.0,2)
53 | AU10_intensity = round((out[index][2]+out[index][3])/2.0,2)
54 | AU12_intensity = round((out[index][4]+out[index][5])/2.0,2)
55 | AU14_intensity = round((out[index][6]+out[index][7])/2.0,2)
56 | AU17_intensity = round((out[index][8]+out[index][9])/2.0,2)
57 |
58 | au17_all_subjects.append(AU17_intensity)
59 | """
60 | Visualization of the predicted AU6 heatmap.
61 | """
62 | heatmap_AU6_0 = heatmap[index][0].to('cpu').detach()
63 | heatmap_AU6_0[heatmap_AU6_0<0]=0
64 | heatmap_AU6_0[heatmap_AU6_0>255*5.0]=255*5.0
65 | heatmap_AU6_0_np = (heatmap_AU6_0.numpy()/5.0).astype(np.uint8).copy()
66 | heatmap_AU6_0_rz = cv2.resize(heatmap_AU6_0_np,(256,256))
67 | map_AU6_0 = cv2.applyColorMap(heatmap_AU6_0_rz, cv2.COLORMAP_JET)
68 | map_AU6_0=cv2.cvtColor(map_AU6_0, cv2.COLOR_RGB2BGR)
69 | heatmap_AU6_1 = heatmap[index][1].to('cpu').detach()
70 | heatmap_AU6_1[heatmap_AU6_1<0]=0
71 | heatmap_AU6_1[heatmap_AU6_1>255*5.0]=255*5.0
72 | heatmap_AU6_1_np = (heatmap_AU6_1.numpy()/5.0).astype(np.uint8).copy()
73 | heatmap_AU6_1_rz = cv2.resize(heatmap_AU6_1_np,(256,256))
74 | map_AU6_1 = cv2.applyColorMap(heatmap_AU6_1_rz, cv2.COLORMAP_JET)
75 | map_AU6_1=cv2.cvtColor(map_AU6_1, cv2.COLOR_RGB2BGR)
76 | map_AU6 = map_AU6_0*0.5+map_AU6_1*0.5+img_ori*0.5
77 | cv2.putText(map_AU6,"AU6: "+str(AU06_intensity), (5, 25), font, 0.85, (255, 255, 255), 2)
78 | """
79 | Visualization of the predicted AU10 heatmap.
80 | """
81 | heatmap_AU10_0 = heatmap[index][2].to('cpu').detach()
82 | heatmap_AU10_0[heatmap_AU10_0<0]=0
83 | heatmap_AU10_0[heatmap_AU10_0>255*5.0]=255*5.0
84 | heatmap_AU10_0_np = (heatmap_AU10_0.numpy()/5.0).astype(np.uint8).copy()
85 | heatmap_AU10_0_rz = cv2.resize(heatmap_AU10_0_np,(256,256))
86 | map_AU10_0 = cv2.applyColorMap(heatmap_AU10_0_rz, cv2.COLORMAP_JET)
87 | map_AU10_0=cv2.cvtColor(map_AU10_0, cv2.COLOR_RGB2BGR)
88 | heatmap_AU10_1 = heatmap[index][3].to('cpu').detach()
89 | heatmap_AU10_1[heatmap_AU10_1<0]=0
90 | heatmap_AU10_1[heatmap_AU10_1>255*5.0]=255*5.0
91 | heatmap_AU10_1_np = (heatmap_AU10_1.numpy()/5.0).astype(np.uint8).copy()
92 | heatmap_AU10_1_rz = cv2.resize(heatmap_AU10_1_np,(256,256))
93 | map_AU10_1 = cv2.applyColorMap(heatmap_AU10_1_rz, cv2.COLORMAP_JET)
94 | map_AU10_1=cv2.cvtColor(map_AU10_1, cv2.COLOR_RGB2BGR)
95 | map_AU10 = map_AU10_0*0.5+map_AU10_1*0.5+img_ori*0.5
96 | cv2.putText(map_AU10,"AU10: "+str(AU10_intensity), (5, 25), font, 0.85, (255, 255, 255), 2)
97 | """
98 | Visualization of the predicted AU12 heatmap.
99 | """
100 | heatmap_AU12_0 = heatmap[index][4].to('cpu').detach()
101 | heatmap_AU12_0[heatmap_AU12_0<0]=0
102 | heatmap_AU12_0[heatmap_AU12_0>255*5.0]=255*5.0
103 | heatmap_AU12_0_np = (heatmap_AU12_0.numpy()/5.0).astype(np.uint8).copy()
104 | heatmap_AU12_0_rz = cv2.resize(heatmap_AU12_0_np,(256,256))
105 | map_AU12_0 = cv2.applyColorMap(heatmap_AU12_0_rz, cv2.COLORMAP_JET)
106 | map_AU12_0=cv2.cvtColor(map_AU12_0, cv2.COLOR_RGB2BGR)
107 | heatmap_AU12_1 = heatmap[index][5].to('cpu').detach()
108 | heatmap_AU12_1[heatmap_AU12_1<0]=0
109 | heatmap_AU12_1[heatmap_AU12_1>255*5.0]=255*5.0
110 | heatmap_AU12_1_np = (heatmap_AU12_1.numpy()/5.0).astype(np.uint8).copy()
111 | heatmap_AU12_1_rz = cv2.resize(heatmap_AU12_1_np,(256,256))
112 | map_AU12_1 = cv2.applyColorMap(heatmap_AU12_1_rz, cv2.COLORMAP_JET)
113 | map_AU12_1=cv2.cvtColor(map_AU12_1, cv2.COLOR_RGB2BGR)
114 | map_AU12 = map_AU12_0*0.5+map_AU12_1*0.5+img_ori*0.5
115 | cv2.putText(map_AU12,"AU12: "+str(AU12_intensity), (5, 25), font, 0.85, (255, 255, 255), 2)
116 | """
117 | Visualization of the predicted AU14 heatmap.
118 | """
119 | heatmap_AU14_0 = heatmap[index][6].to('cpu').detach()
120 | heatmap_AU14_0[heatmap_AU14_0<0]=0
121 | heatmap_AU14_0[heatmap_AU14_0>255*5.0]=255*5.0
122 | heatmap_AU14_0_np = (heatmap_AU14_0.numpy()/5.0).astype(np.uint8).copy()
123 | heatmap_AU14_0_rz = cv2.resize(heatmap_AU14_0_np,(256,256))
124 | map_AU14_0 = cv2.applyColorMap(heatmap_AU14_0_rz, cv2.COLORMAP_JET)
125 | map_AU14_0=cv2.cvtColor(map_AU14_0, cv2.COLOR_RGB2BGR)
126 | heatmap_AU14_1 = heatmap[index][7].to('cpu').detach()
127 | heatmap_AU14_1[heatmap_AU14_1<0]=0
128 | heatmap_AU14_1[heatmap_AU14_1>255*5.0]=255*5.0
129 | heatmap_AU14_1_np = (heatmap_AU14_1.numpy()/5.0).astype(np.uint8).copy()
130 | heatmap_AU14_1_rz = cv2.resize(heatmap_AU14_1_np,(256,256))
131 | map_AU14_1 = cv2.applyColorMap(heatmap_AU14_1_rz, cv2.COLORMAP_JET)
132 | map_AU14_1=cv2.cvtColor(map_AU14_1, cv2.COLOR_RGB2BGR)
133 | map_AU14 = map_AU14_0*0.5+map_AU14_1*0.5+img_ori*0.5
134 | cv2.putText(map_AU14,"AU14: "+str(AU14_intensity), (5, 25), font, 0.85, (255, 255, 255), 2)
135 | """
136 | Visualization of the predicted AU17 heatmap.
137 | """
138 | heatmap_AU17_0 = heatmap[index][8].to('cpu').detach()
139 | heatmap_AU17_0[heatmap_AU17_0<0]=0
140 | heatmap_AU17_0[heatmap_AU17_0>255*5.0]=255*5.0
141 | heatmap_AU17_0_np = (heatmap_AU17_0.numpy()/5.0).astype(np.uint8).copy()
142 | heatmap_AU17_0_rz = cv2.resize(heatmap_AU17_0_np,(256,256))
143 | map_AU17_0 = cv2.applyColorMap(heatmap_AU17_0_rz, cv2.COLORMAP_JET)
144 | map_AU17_0=cv2.cvtColor(map_AU17_0, cv2.COLOR_RGB2BGR)
145 | heatmap_AU17_1 = heatmap[index][9].to('cpu').detach()
146 | heatmap_AU17_1[heatmap_AU17_1<0]=0
147 | heatmap_AU17_1[heatmap_AU17_1>255*5.0]=255*5.0
148 | heatmap_AU17_1_np = (heatmap_AU17_1.numpy()/5.0).astype(np.uint8).copy()
149 | heatmap_AU17_1_rz = cv2.resize(heatmap_AU17_1_np,(256,256))
150 | map_AU17_1 = cv2.applyColorMap(heatmap_AU17_1_rz, cv2.COLORMAP_JET)
151 | map_AU17_1=cv2.cvtColor(map_AU17_1, cv2.COLOR_RGB2BGR)
152 | map_AU17 = map_AU17_0*0.5+map_AU17_1*0.5+img_ori*0.5
153 | cv2.putText(map_AU17,"AU17: "+str(AU17_intensity), (5, 25), font, 0.85, (255, 255, 255), 2)
154 |
155 | if images is None:
156 | images = np.expand_dims(img_ori,axis=0)
157 | else:
158 | images = np.concatenate((images, np.expand_dims(img_ori,axis=0)))
159 |
160 | if maps_AU6 is None:
161 | maps_AU6 = np.expand_dims(map_AU6,axis=0)
162 | else:
163 | maps_AU6 = np.concatenate((maps_AU6, np.expand_dims(map_AU6,axis=0)))
164 | if maps_AU10 is None:
165 | maps_AU10 = np.expand_dims(map_AU10,axis=0)
166 | else:
167 | maps_AU10 = np.concatenate((maps_AU10, np.expand_dims(map_AU10,axis=0)))
168 | if maps_AU12 is None:
169 | maps_AU12 = np.expand_dims(map_AU12,axis=0)
170 | else:
171 | maps_AU12 = np.concatenate((maps_AU12, np.expand_dims(map_AU12,axis=0)))
172 | if maps_AU14 is None:
173 | maps_AU14 = np.expand_dims(map_AU14,axis=0)
174 | else:
175 | maps_AU14 = np.concatenate((maps_AU14, np.expand_dims(map_AU14,axis=0)))
176 | if maps_AU17 is None:
177 | maps_AU17 = np.expand_dims(map_AU17,axis=0)
178 | else:
179 | maps_AU17 = np.concatenate((maps_AU17, np.expand_dims(map_AU17,axis=0)))
180 |
181 | # Save the visualized AU heatmaps in path "./visualize/"
182 | if not os.path.exists('./visualize/'):
183 | os.makedirs('./visualize/')
184 |
185 | save_AU6 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU6/255.0).permute(0,3,1,2),scale_factor=0.5)
186 | save_image(save_AU6, './visualize/Subject{}_AU06.png'.format(count))
187 | save_AU10 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU10/255.0).permute(0,3,1,2),scale_factor=0.5)
188 | save_image(save_AU10, './visualize/Subject{}_AU10.png'.format(count))
189 | save_AU12 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU12/255.0).permute(0,3,1,2),scale_factor=0.5)
190 | save_image(save_AU12, './visualize/Subject{}_AU12.png'.format(count))
191 | save_AU14 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU14/255.0).permute(0,3,1,2),scale_factor=0.5)
192 | save_image(save_AU14, './visualize/Subject{}_AU14.png'.format(count))
193 | save_AU17 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU17/255.0).permute(0,3,1,2),scale_factor=0.5)
194 | save_image(save_AU17, './visualize/Subject{}_AU17.png'.format(count))
195 | count += 1
196 |
197 | y_labels = [AU06_intensity, AU10_intensity, AU12_intensity, AU14_intensity, AU17_intensity]
198 | all_intensities.append(y_labels)
199 | plt.plot(x_labels, y_labels, label = "Subject"+str(count-1))
200 | plt.ylabel("FAU intensity")
201 | plt.title("Intensity variation across FAUs")
202 | plt.legend()
203 | plt.show()
204 | print("FAU17 intensities for all subjects: ", au17_all_subjects)
205 | all_intensities = np.array(all_intensities)
206 | plot(all_intensities, x_labels)
207 |
208 | return np.concatenate(preds)
209 |
210 | def plot(all_intensities, x_labels):
211 |
212 | x = ["Subject0", "Subject1", "Subject2", "Subject3", "Subject4"]
213 | for i in range(5):
214 | plt.plot(x, all_intensities[:,i], label = x_labels[i])
215 | plt.ylabel("FAU intensity")
216 | plt.title("FAU intensity variation across subjects")
217 | plt.legend()
218 | plt.show()
219 |
220 | return
221 |
222 |
223 | def test_epoch( dataset_test, model_path,size, npoints):
224 | net = loadnet(npoints,model_path)
225 | OUT = OutIntensity().to('cpu')
226 | # Load data
227 | database = MyDatasets(size=size,database=dataset_test)
228 | dbloader = DataLoader(database, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
229 | pred = predict(dbloader,OUT,net)
230 |
231 | def main():
232 | global args
233 | args = parser.parse_args()
234 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)
235 | test_epoch(dataset_test=args.dataset_test,model_path=args.model_path,size=args.size,npoints=args.K)
236 |
237 | if __name__ == '__main__':
238 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np
2 | from load_data import MyDatasets
3 | from utils import *
4 | from tqdm import tqdm
5 | from network import ResNet
6 | from torch.utils.data import Dataset, DataLoader
7 | import os, pickle
8 | import time
9 | from torchvision.utils import save_image
10 | import argparse
11 | from PIL import Image
12 | import matplotlib.pyplot as plt
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--K', default=10, type=int, help='Number of AU positions')#24,10
16 | parser.add_argument('--dataset', default='BP4D', type=str, help='database')#BP4D,DISFA
17 | parser.add_argument('--dataset_test', default='train', type=str)#BP4D-val, DISFA-val
18 | parser.add_argument('--model_path', type=str,default='./model/model.pth', help='model path')
19 | parser.add_argument('--cuda', default='5', type=str, help='cuda')
20 | parser.add_argument('--size', default=256, type=int, help='Image size')
21 | parser.add_argument('--lr', default=1e-5, type=float, help='Learning rate')
22 | parser.add_argument('--epochs', default=10, type=int, help='Epochs')
23 |
24 | def get_true_map(heatmap):
25 | true_map = heatmap.detach() / 256 #10x64x64
26 | label = torch.zeros(true_map.shape) #10x64x64
27 | for j in range(0, true_map.shape[0]):
28 | temp = cv2.applyColorMap((np.float32(true_map[j,:,:])).astype(np.uint8), cv2.COLORMAP_JET) #64x64
29 | label[j, :, :] = torch.FloatTensor(temp).mean(-1).unsqueeze(0) #1x64x64
30 | return label #10x64x64
31 |
32 | def loadnet(npoints=10,path_to_model=None):
33 | # Load the trained model.
34 | net = ResNet(num_maps=npoints)
35 | checkpoint = torch.load(path_to_model, map_location='cpu')
36 | checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()}
37 | net.load_state_dict(checkpoint,strict=False)
38 | return net.to('cpu')
39 |
40 | def plot_loss(loss):
41 | xax = np.arange(1,len(loss)+1)
42 | plt.plot(xax, loss)
43 | plt.xlabel("epochs")
44 | plt.ylabel("Loss")
45 | plt.title("Loss vs Number of epochs")
46 | plt.xticks(np.arange(1,len(loss)+1))
47 | plt.show()
48 | return
49 |
50 |
51 |
52 | def train(loader,OUT,net, sample_len):
53 | lossy = []
54 | mean_loss = torch.zeros((sample_len,1))
55 | opt = torch.optim.Adam(net.parameters(), lr=args.lr)
56 | layer = 0
57 | for child in net.children():
58 | layer += 1
59 | if layer < 9:
60 | for param in child.parameters():
61 | param.requires_grad = True
62 | net.train()
63 | for i in tqdm(range(args.epochs)):
64 | print('\tEpochs: ', i,'/', args.epochs,'\tTime:', time.time() - start_time, 'sec')
65 | for idx, sample in enumerate(loader):
66 | opt.zero_grad()
67 | img = sample['Im']
68 | heatmap = net(img).squeeze(0)
69 | true_label = get_true_map(heatmap)
70 | loss = (heatmap - true_label).pow(2).mean()
71 | loss.backward()
72 | opt.step()
73 | mean_loss[idx] = loss.item()
74 | print('Loss- ', (torch.mean(mean_loss)).item())
75 | lossy.append(torch.mean(mean_loss).item())
76 | plot_loss(lossy)
77 | torch.save(net.state_dict(), './model/tuned_model.pth')
78 |
79 | def test_epoch( dataset_test, model_path,size, npoints):
80 | net = loadnet(npoints,model_path)
81 | OUT = OutIntensity().to('cpu')
82 | # Load data
83 | database = MyDatasets(size=size,database=dataset_test)
84 | dbloader = DataLoader(database, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
85 | train(dbloader,OUT,net, len(database))
86 |
87 | def main():
88 | global args
89 | global start_time
90 | start_time = time.time()
91 | args = parser.parse_args()
92 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)
93 | test_epoch(dataset_test=args.dataset_test,model_path=args.model_path,size=args.size,npoints=args.K)
94 |
95 | if __name__ == '__main__':
96 | main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys, torch, math
2 | import cv2
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 |
7 | class OutIntensity(torch.nn.Module):
8 | """Infer AU intensity from a heatmap: :(x, y) = argmax H """
9 | def __init__(self):
10 | super(OutIntensity,self).__init__()
11 |
12 | def forward(self,x):
13 | batch_size = x.shape[0]
14 | num_points = x.shape[1]
15 | width = x.shape[2]
16 | x_ = x.to('cpu').detach().numpy().astype(np.float32).copy()
17 | heatmaps_reshaped = x_.reshape((batch_size, num_points, -1))
18 | intensity = heatmaps_reshaped.max(axis=2)
19 | return intensity/255.0
20 |
--------------------------------------------------------------------------------