├── README.md
├── img
│   ├── overview.pdf
│   ├── overview.png
│   └── temporal.gif
└── train
    ├── model.py
    ├── rank_dataset.py
    └── triplet_rank_train.py

/README.md:
--------------------------------------------------------------------------------
# Self-Supervised Learning for Facial Action Unit Recognition through Temporal Consistency (BMVC 2020 Accepted)

This repository contains the PyTorch implementation of [Self-Supervised Learning for Facial Action Unit Recognition through Temporal Consistency](https://www.bmvc2020-conference.com/assets/papers/0861.pdf).

![Image of Overview](https://github.com/intelligent-human-perception-laboratory/temporal-consistency/blob/master/img/overview.png)

The proposed parallel-encoder network takes a sequence of frames extracted from a
video. The anchor frame is selected at time t, the sibling frame at t + 1, and the following
frames at equal intervals from t + 1 + k to t + 1 + Nk. All input frames are fed to ResNet-18
encoders with shared weights, followed by a fully-connected layer that generates 256-d
embeddings. L2 normalization is applied to the output embeddings. We then compute triplet losses for
adjacent frame pairs along with the fixed anchor frame. In each adjacent pair, the preceding
frame is the positive sample and the following frame is the negative sample. Finally, all
triplet losses are added to form the ranking triplet loss.

## Overview Video
[![](http://img.youtube.com/vi/B4AZU4gsK7o/0.jpg)](http://www.youtube.com/watch?v=B4AZU4gsK7o "")

## Temporal Consistency
![](https://github.com/intelligent-human-perception-laboratory/temporal-consistency/blob/master/img/temporal.gif)

## Pre-trained Weights
Pretrained weights can be downloaded from [here](https://drive.google.com/file/d/10BUSkTADsSc8W5jBX59-vtoHXcH0q9o9/view?usp=sharing).
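
Below is a minimal loading sketch (not part of the original release). It assumes the checkpoint follows the layout written by `train/triplet_rank_train.py` — a dict with `state_dict` and `optimizer` keys, saved from an encoder wrapped in `nn.DataParallel` — and uses a placeholder file name; adjust both if the downloaded file differs. Run it from the `train/` folder so `model.py` is importable.

```python
import torch
from model import resnet18_encoder  # defined in train/model.py

encoder = resnet18_encoder()
ckpt = torch.load('path/to/pretrained_weights.pth', map_location='cpu')  # placeholder path
state_dict = ckpt.get('state_dict', ckpt)  # fall back to a plain state dict
# Strip the 'module.' prefix added by nn.DataParallel, if present.
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
encoder.load_state_dict(state_dict)
encoder.eval()
```
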
--------------------------------------------------------------------------------
/img/overview.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ihp-lab/temporal-consistency/c18a3860ea5f04da9b1c319482369802e746cf3f/img/overview.pdf
--------------------------------------------------------------------------------
/img/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ihp-lab/temporal-consistency/c18a3860ea5f04da9b1c319482369802e746cf3f/img/overview.png
--------------------------------------------------------------------------------
/img/temporal.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ihp-lab/temporal-consistency/c18a3860ea5f04da9b1c319482369802e746cf3f/img/temporal.gif
--------------------------------------------------------------------------------
/train/model.py:
--------------------------------------------------------------------------------
import os
import os.path as ops
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import models


class resnet18_encoder(nn.Module):
    # ResNet-18 backbone (ImageNet pretrained) with its classifier removed, followed by
    # a fully-connected layer that produces L2-normalized 256-d embeddings.
    def __init__(self):
        super(resnet18_encoder, self).__init__()
        resnet18 = models.resnet18(pretrained=True)
        resnet18_layers = list(resnet18.children())[:-1]
        self.resnet18 = nn.Sequential(*resnet18_layers)
        self.fc1 = nn.Linear(512, 256)

    def forward(self, x):
        output = self.resnet18(x)
        output = torch.flatten(output, 1)
        output = self.fc1(output)
        output = F.normalize(output, p=2, dim=1)
        return output


class densenet121_encoder(nn.Module):
    # DenseNet-121 backbone with its classifier removed; global average pooling followed
    # by a fully-connected layer produces L2-normalized 256-d embeddings.
    def __init__(self):
        super(densenet121_encoder, self).__init__()
        densenet = models.densenet121(pretrained=True)
        densenet_layers = list(densenet.children())[:-1]
        self.densenet = nn.Sequential(*densenet_layers)
        self.fc1 = nn.Linear(1024, 256)

    def forward(self, x):
        output = self.densenet(x)
        output = F.relu(output, inplace=True)
        output = F.adaptive_avg_pool2d(output, (1, 1))
        output = torch.flatten(output, 1)
        output = self.fc1(output)
        output = F.normalize(output, p=2, dim=1)
        return output


class mobilenet_encoder(nn.Module):
    # MobileNetV2 backbone with its classifier removed; global average pooling followed
    # by a fully-connected layer produces L2-normalized 256-d embeddings.
    def __init__(self):
        super(mobilenet_encoder, self).__init__()
        mobilenet = models.mobilenet_v2(pretrained=True)
        mobilenet_layers = list(mobilenet.children())[:-1]
        self.mobilenet = nn.Sequential(*mobilenet_layers)
        self.fc1 = nn.Linear(1280, 256)

    def forward(self, x):
        output = self.mobilenet(x)
        output = nn.functional.adaptive_avg_pool2d(output, 1).reshape(output.shape[0], -1)
        output = torch.flatten(output, 1)
        output = self.fc1(output)
        output = F.normalize(output, p=2, dim=1)
        return output
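
# --- Usage sketch (added for illustration; not part of the original file) ---
# Minimal check of the encoders above: each maps a batch of 3x256x256 images to
# L2-normalized 256-d embeddings. The 256x256 input size and the random batch
# below are assumptions for demonstration only.
if __name__ == '__main__':
    encoder = resnet18_encoder()
    encoder.eval()
    with torch.no_grad():
        dummy_batch = torch.randn(4, 3, 256, 256)  # fake batch of four RGB frames
        embeddings = encoder(dummy_batch)
    print(embeddings.shape)        # expected: torch.Size([4, 256])
    print(embeddings.norm(dim=1))  # each row should be (approximately) unit norm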
--------------------------------------------------------------------------------
/train/rank_dataset.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data.dataset import Dataset  # for custom datasets
from torchvision.transforms import ToTensor, Scale, Compose, Pad, RandomHorizontalFlip, CenterCrop, RandomCrop, ToPILImage
import matplotlib.pyplot as plt
import os, sys
from tqdm import tqdm

image_base_folder = '../vox2_crop_fps25'


class CustomDatasetFromImages(Dataset):
    def __init__(self, transformations, spacing):
        """
        Args:
            transformations: torchvision transforms applied to every frame
            spacing (int): interval k between the spaced "far" frames
        """

        self.seed = np.random.seed(567)
        self.transform = transformations
        self.to_tensor = transforms.ToTensor()
        # List the per-video frame folders.
        self.video_label = sorted(os.listdir(image_base_folder))
        frame_index = []
        self.video_label_suff = []

        # Each sample needs the anchor, the sibling at t + 1 and 10 spaced frames.
        length = spacing * 10 + 1

        for video in tqdm(self.video_label):
            path = image_base_folder + '/' + video
            frames = sorted(os.listdir(path), key=lambda x: int(x[:-4]))

            # Keep only videos long enough to hold a full sample; pick one random anchor per video.
            if len(frames) - length > 0:
                index = np.random.choice(range(len(frames) - length), size=1)
                frame_index.append(index)
                self.video_label_suff.append(video)

        self.frame_index = frame_index
        self.data_len = len(self.video_label_suff)
        self.spacing = spacing

    def __getitem__(self, index):
        # One anchor frame is sampled per video, so the dataset index maps directly to a video.
        video_off = int(index % 1)
        video_base_index = int((index - video_off) / 1)
        anchor_index = self.frame_index[video_base_index][video_off]
        pos_index = anchor_index + 1

        path = image_base_folder + '/' + self.video_label_suff[video_base_index]
        frames = sorted(os.listdir(path), key=lambda x: int(x[:-4]))
        random_frame = frames[anchor_index]
        close_frame = frames[pos_index]

        far_frame_name_list = []

        # The 10 "far" frames follow the sibling frame at equal intervals of `spacing`.
        for i in range(1, 11):
            neg_index = i * self.spacing + anchor_index + 1
            far_frame = frames[neg_index]
            far_name = path + '/' + far_frame
            far_frame_name_list.append(far_name)

        source_name = path + '/' + random_frame
        close_name = path + '/' + close_frame

        # Open the images; fall back to the first sample if any frame is missing.
        try:
            s_img = Image.open(source_name)
            c_img = Image.open(close_name)

            f_imgs = []
            for far in far_frame_name_list:
                f_img = Image.open(far)
                f_imgs.append(f_img)

        except FileNotFoundError:
            print("Sample missing, falling back to the first sample")
            return self.__getitem__(0)

        # 12 frames per sample: anchor, sibling, then the 10 spaced frames.
        imgs = [0] * 12
        img_s = self.transform(s_img)
        imgs[0] = img_s

        img_c = self.transform(c_img)
        imgs[1] = img_c

        for i in range(10):
            img_f = self.transform(f_imgs[i])
            imgs[2 + i] = img_f
        single_image_label = self.video_label_suff[video_base_index]

        return (imgs, single_image_label)

    def __len__(self):
        return self.data_len
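
# --- Usage sketch (added for illustration; not part of the original file) ---
# Quick check of the dataset: each sample is a list of 12 transformed frames
# (anchor, sibling at t + 1, then 10 frames spaced `spacing` apart) plus the video
# folder name as a label. It assumes the VoxCeleb2 frame folders exist at
# image_base_folder; the simple transform here is an assumption, and training
# actually uses the pipeline defined in triplet_rank_train.py.
if __name__ == '__main__':
    demo_transform = Compose([Scale((256, 256)), ToTensor()])
    dataset = CustomDatasetFromImages(demo_transform, spacing=1)
    imgs, label = dataset[0]
    print(len(imgs), imgs[0].shape, label)  # expected: 12 torch.Size([3, 256, 256]) <video id>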
--------------------------------------------------------------------------------
/train/triplet_rank_train.py:
--------------------------------------------------------------------------------
import os
import os.path as ops
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import models
from rank_dataset import CustomDatasetFromImages
from torchvision.transforms import ToTensor, Scale, Compose, Pad, RandomHorizontalFlip, CenterCrop, RandomCrop, ToPILImage
import sys
from tqdm import tqdm
from model import resnet18_encoder, densenet121_encoder, mobilenet_encoder
import argparse

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


arguments = argparse.ArgumentParser()
arguments.add_argument('--lr', type=float, default=0.001)
arguments.add_argument('--momentum', type=float, default=0.9)
arguments.add_argument('--num_workers', type=int, default=12)
arguments.add_argument('--batch_size', type=int, default=48)
arguments.add_argument('--num_epoch', type=int, default=50)
arguments.add_argument('--spacing_size', type=int, default=1)
arguments.add_argument('--random_seed', type=int, default=123)
args = arguments.parse_args()


batch_size = args.batch_size
epoch = args.num_epoch
save_path = 'model_save/'

# Normalization constants for ImageNet-pretrained backbones.
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

# Random crop size (drawn once per run), followed by a resize back to 256x256.
crop = 200
rng = np.random.RandomState(args.random_seed)
precrop = crop + 24
crop = rng.randint(crop, precrop)
transformations = Compose([
    Scale((256, 256)),
    Pad((24, 24, 24, 24)),
    CenterCrop(precrop),
    RandomCrop(crop),
    Scale((256, 256)),
    ToTensor(),
    normalize])


# Batch-wise squared L2 distances.
def criterion_l2(input_f, target_f):
    # Per-sample squared L2 distance, summed over the last of three dimensions.
    res = (input_f - target_f)
    res = res * res
    return res.sum(dim=2)


def criterion_l2_2(input_f, target_f):
    # Per-sample squared L2 distance, summed over the embedding dimension.
    res = (input_f - target_f)
    res = res * res
    return res.sum(dim=1)


def criterion_cos(input_f, target_f):
    cos = nn.CosineSimilarity(dim=2, eps=1e-6)
    return cos(input_f, target_f)


def criterion_cos2(input_f, target_f):
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    return cos(input_f, target_f)


def tuplet_loss(anchor, close, sequence):
    # Ranking triplet loss: the anchor-sibling pair must be closer than the first spaced
    # frame, and each spaced frame must be closer to the anchor than the next one.
    delta = 3e-2 * torch.ones(anchor.size(0), device='cuda')
    # sequence is N x 10 x 256
    anchors = torch.unsqueeze(anchor, dim=1)
    positives = torch.unsqueeze(close, dim=1)

    close_distance = criterion_l2(anchors, positives).view(-1)
    far_distance = criterion_l2(anchors.expand(sequence.size(0), 10, 256), sequence)
    loss = torch.max(torch.zeros(anchor.size(0), device='cuda'), (close_distance - far_distance[:, 0] + delta))

    for i in range(1, 10):
        loss += torch.max(torch.zeros(anchor.size(0), device='cuda'), (far_distance[:, i - 1] - far_distance[:, i] + delta))

    return loss.mean()


torch.manual_seed(args.random_seed)
combine_sets = CustomDatasetFromImages(transformations, spacing=args.spacing_size)
train_size = int(0.8 * len(combine_sets))
test_size = len(combine_sets) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(combine_sets, [train_size, test_size])

train_dataset_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True, num_workers=args.num_workers)

test_dataset_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True, num_workers=args.num_workers)


model = resnet18_encoder()
model = torch.nn.DataParallel(model, device_ids=[0, 1])
model.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


train_loss_list = []
test_acc_list = []
test_loss_list = []


def checkpoint(model, my_save_path, epoch):
    print("saving model, current epoch: {0}".format(epoch))
    os.makedirs(my_save_path, exist_ok=True)  # make sure the checkpoint directory exists
    final_save_path = my_save_path + '_' + str(epoch) + '.pth'
    checkpoint_state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }

    torch.save(checkpoint_state, final_save_path)


def train_model(model, epoches):
    total_loss = 0
    for i, (batch, label) in enumerate(tqdm(train_dataset_loader)):
        optimizer.zero_grad()
        input_images = Variable(batch[0]).cuda()
        close_images = Variable(batch[1]).cuda()
        far_images = Variable(torch.stack(batch[2:], dim=1)).cuda().view(-1, 3, 256, 256)

        input_emb = model(input_images)
        close_emb = model(close_images)
        far_emb = model(far_images).view(input_emb.size(0), -1, 256)

        loss = tuplet_loss(input_emb, close_emb, far_emb)
        loss.backward()
        optimizer.step()
        total_loss += loss.cpu().detach().numpy()

    total_loss = total_loss * 1.0 / len(train_dataset_loader)
    train_loss_list.append(total_loss)
    print("train loss at epoch %d is %f" % (epoches, total_loss))


def test_model(model, epoches):
    with torch.no_grad():
        total_loss = 0
        sep_acc = np.zeros(10)
        for i, (batch, label) in enumerate(tqdm(test_dataset_loader)):
            input_images = Variable(batch[0]).cuda()
            close_images = Variable(batch[1]).cuda()
            far_images = Variable(torch.stack(batch[2:], dim=1)).cuda().view(-1, 3, 256, 256)
            # torchvision.utils.save_image(far_images[0:10, :, :, :].data, 'test.png')

            input_emb = model(input_images)
            close_emb = model(close_images)
            far_emb = model(far_images).view(input_emb.size(0), -1, 256)

            loss = tuplet_loss(input_emb, close_emb, far_emb)
            total_loss += loss.cpu().detach().numpy()

            # Per-step ordering accuracy: in each adjacent pair, the earlier frame
            # should be closer to the anchor than the later one.
            far_emb = far_emb.permute(1, 0, 2)
            close_dist = criterion_l2_2(input_emb, close_emb)
            far_dist = criterion_l2_2(input_emb, far_emb[0, :, :])
            diff = -close_dist + far_dist
            diff = diff.cpu().numpy()
            correct_index = np.where(diff > 0.0)[0]
            sep_acc[0] += correct_index.shape[0] * 1.0 / batch_size

            for j in range(1, 10):
                far_dist = criterion_l2_2(input_emb, far_emb[j, :, :])
                close_dist = criterion_l2_2(input_emb, far_emb[j - 1, :, :])
                diff = -close_dist + far_dist
                diff = diff.cpu().numpy()
                correct_index2 = np.where(diff > 0.0)[0]
                sep_acc[j] += correct_index2.shape[0] * 1.0 / batch_size

        sep_acc /= len(test_dataset_loader)
        total_loss = total_loss * 1.0 / len(test_dataset_loader)
        total_acc = np.mean(sep_acc)

        print("test accuracy at epoch %d is %f" % (epoches, total_acc))
        print("test loss at epoch %d is %f" % (epoches, total_loss))
        print("test per-step accuracy at epoch %d:" % (epoches), sep_acc)

        test_acc_list.append(total_acc)
        test_loss_list.append(total_loss)


if __name__ == '__main__':
    # Sanity check: evaluate before any training.
    model.eval()
    test_model(model, -1)

    for i in range(epoch):
        model.train()
        train_model(model, i)

        model.eval()
        test_model(model, i)

        checkpoint(model, save_path, i)

    print(test_acc_list)
    print(test_loss_list)
    test_acc_list = np.array(test_acc_list)
    test_loss_list = np.array(test_loss_list)

    np.savez_compressed(save_path + 'loss_logger', testa=test_acc_list, testl=test_loss_list)
--------------------------------------------------------------------------------