├── .gitigore ├── LICENSE ├── README.md ├── data ├── .DS_Store ├── F014.jpg ├── F016.jpg ├── M014.jpg └── examples.txt ├── load_data.py ├── network.py ├── run_demo.py └── utils.py /.gitigore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-FAU 2 | 3 | The Pytorch implementation of **Facial Action Unit Intensity Estimation**. 4 | 5 | ## Environment 6 | 7 | - Ubuntu 18.04.4 8 | - Python 3.7 9 | - PyTorch 1.3.0 10 | - Torchvision 11 | - Python-OpenCV 12 | 13 | ***Datasets*** 14 | 15 | For data preparation, please make a request for the [BP4D database](http://www.cs.binghamton.edu/~lijun/Research/3DFE/3DFE_Analysis.html) and the [DISFA database](http://mohammadmahoor.com/disfa/). 16 | 17 | ***Usage*** 18 | 19 | The pre-trained model can be obtained from the [link](https://drive.google.com/file/d/15cJtFEkrOrbt5FfZOxnWulQhaOKfaLtN/view?usp=sharing). Please download it under your own path. You can change default path by modifying `--model_path`. 20 | 21 | - `run_demo.py`: visualizes the predicted AU heatmaps and intensities for the example images (`data/*.jpg`). 22 | 23 | ```python 24 | cd Pytorch-FAU/ 25 | python run_demo.py 26 | ``` 27 | 28 | The full code will be available soon. 29 | 30 | ## Citation 31 | 32 | @inproceedings{fan2020fau, 33 | title = {Facial Action Unit Intensity Estimation via Semantic 34 | Correspondence Learning with Dynamic Graph Convolution}, 35 | author = {Fan, Yingruo and Lam, Jacqueline and Li, Victor}, 36 | booktitle = {Thirty-Fourth AAAI Conference on Artificial Intelligence}, 37 | year={2020} 38 | } 39 | 40 | ## Acknowledgement 41 | 42 | The code partially refers open-sourced [Action-Units-Heatmaps](https://github.com/ESanchezLozano/Action-Units-Heatmaps). Thanks to them for the great work. 43 | 44 | 45 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmdydxr/Pytorch-FAU/51e6c5e7ebe8be1514bfc2649344ff6186f47c5f/data/.DS_Store -------------------------------------------------------------------------------- /data/F014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmdydxr/Pytorch-FAU/51e6c5e7ebe8be1514bfc2649344ff6186f47c5f/data/F014.jpg -------------------------------------------------------------------------------- /data/F016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmdydxr/Pytorch-FAU/51e6c5e7ebe8be1514bfc2649344ff6186f47c5f/data/F016.jpg -------------------------------------------------------------------------------- /data/M014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmdydxr/Pytorch-FAU/51e6c5e7ebe8be1514bfc2649344ff6186f47c5f/data/M014.jpg -------------------------------------------------------------------------------- /data/examples.txt: -------------------------------------------------------------------------------- 1 | ./data/F014.jpg 2 | ./data/M014.jpg 3 | ./data/F016.jpg 4 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load the BP4D or the DISFA dataset. 3 | """ 4 | import os 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import inspect 9 | from torch.utils.data import Dataset 10 | 11 | class MyDatasets(Dataset): 12 | def __init__(self, sigma=2, size=256, heatmap=32,AU_positions=10, database=''): 13 | if database == 'BP4D': 14 | txt_file = open('/data0/evelyn/BP4D_AUCoding/train_BP4D.txt','r') 15 | if database == 'DISFA': 16 | txt_file = open('/data0/evelyn/DISFA/train_DISFA.txt','r') 17 | if database == 'BP4D-val': 18 | txt_file = open('/data0/evelyn/BP4D_AUCoding/val_BP4D.txt','r') 19 | if database == 'DISFA-val': 20 | txt_file = open('/data0/evelyn/DISFA/test_DISFA.txt','r') 21 | if database == 'demo': 22 | txt_file = open('./data/examples.txt','r') 23 | lines = txt_file.readlines()[0::] 24 | names = [l.split()[0] for l in lines] 25 | coords = [l.split()[1::] for l in lines] 26 | self.database = database 27 | self.data = dict(zip(names,coords)) 28 | self.imgs = list(set(names)) 29 | self.len = len(self.imgs) 30 | 31 | def generate_target(self, points, intensity): 32 | target = np.zeros((self.AU_positions,self.heatmap,self.heatmap),dtype=np.float32) 33 | gs_range = self.sigma * 15 34 | for point_id in range(self.AU_positions): 35 | feat_stride = self.size / self.heatmap 36 | mu_x = int(points[point_id][0] / feat_stride + 0.5) 37 | mu_y = int(points[point_id][1] / feat_stride + 0.5) 38 | ul = [int(mu_x - gs_range), int(mu_y - gs_range)] 39 | br = [int(mu_x + gs_range + 1), int(mu_y + gs_range + 1)] 40 | x = np.arange(0, 2*gs_range+1, 1, np.float32) 41 | y = x[:, np.newaxis] 42 | x_center = y_center = (2*gs_range+1) // 2 43 | g = np.exp(- ((x - x_center) ** 2 + (y - y_center) ** 2) / (2 * self.sigma ** 2)) 44 | g_x = max(0, -ul[0]), min(br[0], self.heatmap) - ul[0] 45 | g_y = max(0, -ul[1]), min(br[1], self.heatmap) - ul[1] 46 | img_x = max(0, ul[0]), min(br[0], self.heatmap) 47 | img_y = max(0, ul[1]), min(br[1], self.heatmap) 48 | target[point_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = intensity[point_id]*g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 49 | return target*255.0 50 | 51 | def fetch(self,index): 52 | path_to_img = self.imgs[index] 53 | image = cv2.cvtColor(cv2.imread(path_to_img), cv2.COLOR_BGR2RGB) 54 | if self.database == 'demo': 55 | return image, [0], [0] 56 | AUs = self.data[self.imgs[index]] 57 | AUs = np.float32(self.data[self.imgs[index]]).reshape(-1,3) 58 | AU_coords = AUs[:,:2] 59 | AU_intensity = AUs[:,2] 60 | return image, AU_coords, AU_intensity 61 | 62 | def __getitem__(self,index): 63 | image, AU_coords, AU_intensity = self.fetch(index) 64 | nimg = len(image) 65 | sample = dict.fromkeys(['Im'], None) 66 | out = dict.fromkeys(['image','points']) 67 | image_np = torch.from_numpy((image/255.0).swapaxes(2,1).swapaxes(1,0)) 68 | out['image'] = image_np.type_as(torch.FloatTensor()) 69 | out['AU_coords'] = np.floor(AU_coords) 70 | if not self.database == 'demo': 71 | target = self.generate_target(out['AU_coords'], AU_intensity) 72 | target = torch.from_numpy(target).type_as(torch.FloatTensor()) 73 | sample['target'] = target 74 | sample['pts'] = out['AU_coords'] 75 | sample['intensity'] = AU_intensity 76 | sample['Im'] = out['image'] 77 | return sample 78 | 79 | def __len__(self): 80 | return len(self.imgs) 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | # ------------------------------------------------------------------------------ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import logging 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | 15 | BN_MOMENTUM = 0.1 16 | logger = logging.getLogger(__name__) 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | return nn.Conv2d( 20 | in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False 22 | ) 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 62 | padding=1, bias=False) 63 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 64 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 65 | bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 67 | momentum=BN_MOMENTUM) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | return out 92 | 93 | class ResNet(nn.Module): 94 | #Res-18 BasicBlock,[2, 2, 2, 2]; Res-50 Bottleneck,[3,4,6,3] 95 | def __init__(self, block=Bottleneck, num_maps=10, layers=[3,4,6,3]): 96 | self.inplanes = 64 97 | self.deconv_with_bias = False 98 | 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | 110 | # used for deconv layers 111 | self.deconv_layers_1 = self._make_deconv_layer(0, [256, 256, 256],[4, 4, 4]) 112 | self.deconv_layers_2 = self._make_deconv_layer(1, [256, 256, 256],[4, 4, 4]) 113 | self.deconv_layers_3 = self._make_deconv_layer(2, [256, 256, 256],[4, 4, 4]) 114 | 115 | self.final_layer = nn.Conv2d( 116 | in_channels=256, 117 | out_channels=num_maps, 118 | kernel_size=1, 119 | stride=1, 120 | padding=0 121 | ) 122 | 123 | self._initialize_weights() 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for i in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def _get_deconv_cfg(self, deconv_kernel, index): 143 | if deconv_kernel == 4: 144 | padding = 1 145 | output_padding = 0 146 | elif deconv_kernel == 3: 147 | padding = 1 148 | output_padding = 1 149 | elif deconv_kernel == 2: 150 | padding = 0 151 | output_padding = 0 152 | 153 | return deconv_kernel, padding, output_padding 154 | 155 | def _make_deconv_layer(self, i , num_filters, num_kernels): 156 | 157 | layers = [] 158 | 159 | kernel, padding, output_padding = \ 160 | self._get_deconv_cfg(num_kernels[i], i) 161 | 162 | planes = num_filters[i] 163 | layers.append( 164 | nn.ConvTranspose2d( 165 | in_channels=self.inplanes, 166 | out_channels=planes, 167 | kernel_size=kernel, 168 | stride=2, 169 | padding=padding, 170 | output_padding=output_padding, 171 | bias=self.deconv_with_bias)) 172 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 173 | layers.append(nn.ReLU(inplace=True)) 174 | self.inplanes = planes 175 | 176 | return nn.Sequential(*layers) 177 | 178 | def forward(self, x): 179 | x = self.conv1(x) 180 | x = self.bn1(x) 181 | x = self.relu(x) 182 | x = self.maxpool(x) 183 | 184 | x = self.layer1(x) 185 | x = self.layer2(x) 186 | x = self.layer3(x) 187 | x = self.layer4(x) 188 | 189 | x = self.deconv_layers_1(x) 190 | x = self.deconv_layers_2(x) 191 | x = self.deconv_layers_3(x) 192 | x = self.final_layer(x) 193 | 194 | return x 195 | 196 | def _initialize_weights(self): 197 | for m in self.modules(): 198 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 199 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 200 | m.weight.data.normal_(0, math.sqrt(2. / n)) 201 | if m.bias is not None: 202 | m.bias.data.zero_() 203 | elif isinstance(m, nn.BatchNorm2d): 204 | m.weight.data.fill_(1) 205 | m.bias.data.zero_() 206 | elif isinstance(m, nn.Linear): 207 | n = m.weight.size(1) 208 | m.weight.data.normal_(0, 0.01) 209 | m.bias.data.zero_() 210 | 211 | resnet_spec = { 212 | 18: (BasicBlock, [2, 2, 2, 2]), 213 | 34: (BasicBlock, [3, 4, 6, 3]), 214 | 50: (Bottleneck, [3, 4, 6, 3]), 215 | 101: (Bottleneck, [3, 4, 23, 3]), 216 | 152: (Bottleneck, [3, 8, 36, 3]) 217 | } 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize the predicted AU heatmaps and intensities for the example images (./data/F014.jpg,./data/F016.jpg, ./data/M014.jpg ) 3 | """ 4 | import torch, numpy as np 5 | from load_data import MyDatasets 6 | from utils import * 7 | from network import ResNet 8 | from torch.utils.data import Dataset, DataLoader 9 | import os, pickle 10 | from torchvision.utils import save_image 11 | import argparse 12 | from PIL import Image 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--K', default=10, help='Number of AU positions')#24,10 16 | parser.add_argument('--dataset', default='BP4D', type=str, help='database')#BP4D,DISFA 17 | parser.add_argument('--dataset_test', default='demo', type=str)#BP4D-val, DISFA-val 18 | parser.add_argument('--model_path', type=str,default='./data/model.pth', help='model path') 19 | parser.add_argument('--cuda', default='5', type=str, help='cuda') 20 | parser.add_argument('--size', default=256, help='Image size') 21 | 22 | def loadnet(npoints=10,path_to_model=None): 23 | # Load the trained model. 24 | net = ResNet(num_maps=npoints) 25 | checkpoint = torch.load(path_to_model) 26 | checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()} 27 | net.load_state_dict(checkpoint,strict=False) 28 | return net.to('cuda') 29 | 30 | def predict(loader,OUT,net): 31 | preds = [] 32 | with torch.no_grad(): 33 | count = 0 34 | for sample in loader: 35 | img = sample['Im'] 36 | heatmap = net(img.cuda()) 37 | out = OUT(heatmap) 38 | preds.append(out) 39 | images = None 40 | maps_AU6 = None 41 | maps_AU10 = None 42 | maps_AU12= None 43 | maps_AU14= None 44 | maps_AU17= None 45 | threshold = 0.1 46 | font = cv2.FONT_HERSHEY_SIMPLEX 47 | 48 | for (index,item) in enumerate(sample['Im'].to('cpu').detach()): 49 | img_ori = (255*item.permute(1,2,0).numpy()).astype(np.uint8).copy() 50 | AU_intensities = out[index] 51 | AU06_intensity = round((out[index][0]+out[index][1])/2.0,2) 52 | AU10_intensity = round((out[index][2]+out[index][3])/2.0,2) 53 | AU12_intensity = round((out[index][4]+out[index][5])/2.0,2) 54 | AU14_intensity = round((out[index][6]+out[index][7])/2.0,2) 55 | AU17_intensity = round((out[index][8]+out[index][9])/2.0,2) 56 | """ 57 | Visualization of the predicted AU6 heatmap. 58 | """ 59 | heatmap_AU6_0 = heatmap[index][0].to('cpu').detach() 60 | heatmap_AU6_0[heatmap_AU6_0255*5.0]=255*5.0 62 | heatmap_AU6_0_np = (heatmap_AU6_0.numpy()/5.0).astype(np.uint8).copy() 63 | heatmap_AU6_0_rz = cv2.resize(heatmap_AU6_0_np,(256,256)) 64 | map_AU6_0 = cv2.applyColorMap(heatmap_AU6_0_rz, cv2.COLORMAP_JET) 65 | map_AU6_0=cv2.cvtColor(map_AU6_0, cv2.COLOR_RGB2BGR) 66 | heatmap_AU6_1 = heatmap[index][1].to('cpu').detach() 67 | heatmap_AU6_1[heatmap_AU6_1255*5.0]=255*5.0 69 | heatmap_AU6_1_np = (heatmap_AU6_1.numpy()/5.0).astype(np.uint8).copy() 70 | heatmap_AU6_1_rz = cv2.resize(heatmap_AU6_1_np,(256,256)) 71 | map_AU6_1 = cv2.applyColorMap(heatmap_AU6_1_rz, cv2.COLORMAP_JET) 72 | map_AU6_1=cv2.cvtColor(map_AU6_1, cv2.COLOR_RGB2BGR) 73 | map_AU6 = map_AU6_0*0.5+map_AU6_1*0.5+img_ori*0.5 74 | cv2.putText(map_AU6,"AU6: "+str(AU06_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 75 | """ 76 | Visualization of the predicted AU10 heatmap. 77 | """ 78 | heatmap_AU10_0 = heatmap[index][2].to('cpu').detach() 79 | heatmap_AU10_0[heatmap_AU10_0255*5.0]=255*5.0 81 | heatmap_AU10_0_np = (heatmap_AU10_0.numpy()/5.0).astype(np.uint8).copy() 82 | heatmap_AU10_0_rz = cv2.resize(heatmap_AU10_0_np,(256,256)) 83 | map_AU10_0 = cv2.applyColorMap(heatmap_AU10_0_rz, cv2.COLORMAP_JET) 84 | map_AU10_0=cv2.cvtColor(map_AU10_0, cv2.COLOR_RGB2BGR) 85 | heatmap_AU10_1 = heatmap[index][3].to('cpu').detach() 86 | heatmap_AU10_1[heatmap_AU10_1255*5.0]=255*5.0 88 | heatmap_AU10_1_np = (heatmap_AU10_1.numpy()/5.0).astype(np.uint8).copy() 89 | heatmap_AU10_1_rz = cv2.resize(heatmap_AU10_1_np,(256,256)) 90 | map_AU10_1 = cv2.applyColorMap(heatmap_AU10_1_rz, cv2.COLORMAP_JET) 91 | map_AU10_1=cv2.cvtColor(map_AU10_1, cv2.COLOR_RGB2BGR) 92 | map_AU10 = map_AU10_0*0.5+map_AU10_1*0.5+img_ori*0.5 93 | cv2.putText(map_AU10,"AU10: "+str(AU10_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 94 | """ 95 | Visualization of the predicted AU12 heatmap. 96 | """ 97 | heatmap_AU12_0 = heatmap[index][4].to('cpu').detach() 98 | heatmap_AU12_0[heatmap_AU12_0255*5.0]=255*5.0 100 | heatmap_AU12_0_np = (heatmap_AU12_0.numpy()/5.0).astype(np.uint8).copy() 101 | heatmap_AU12_0_rz = cv2.resize(heatmap_AU12_0_np,(256,256)) 102 | map_AU12_0 = cv2.applyColorMap(heatmap_AU12_0_rz, cv2.COLORMAP_JET) 103 | map_AU12_0=cv2.cvtColor(map_AU12_0, cv2.COLOR_RGB2BGR) 104 | heatmap_AU12_1 = heatmap[index][5].to('cpu').detach() 105 | heatmap_AU12_1[heatmap_AU12_1255*5.0]=255*5.0 107 | heatmap_AU12_1_np = (heatmap_AU12_1.numpy()/5.0).astype(np.uint8).copy() 108 | heatmap_AU12_1_rz = cv2.resize(heatmap_AU12_1_np,(256,256)) 109 | map_AU12_1 = cv2.applyColorMap(heatmap_AU12_1_rz, cv2.COLORMAP_JET) 110 | map_AU12_1=cv2.cvtColor(map_AU12_1, cv2.COLOR_RGB2BGR) 111 | map_AU12 = map_AU12_0*0.5+map_AU12_1*0.5+img_ori*0.5 112 | cv2.putText(map_AU12,"AU12: "+str(AU12_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 113 | """ 114 | Visualization of the predicted AU14 heatmap. 115 | """ 116 | heatmap_AU14_0 = heatmap[index][6].to('cpu').detach() 117 | heatmap_AU14_0[heatmap_AU14_0255*5.0]=255*5.0 119 | heatmap_AU14_0_np = (heatmap_AU14_0.numpy()/5.0).astype(np.uint8).copy() 120 | heatmap_AU14_0_rz = cv2.resize(heatmap_AU14_0_np,(256,256)) 121 | map_AU14_0 = cv2.applyColorMap(heatmap_AU14_0_rz, cv2.COLORMAP_JET) 122 | map_AU14_0=cv2.cvtColor(map_AU14_0, cv2.COLOR_RGB2BGR) 123 | heatmap_AU14_1 = heatmap[index][7].to('cpu').detach() 124 | heatmap_AU14_1[heatmap_AU14_1255*5.0]=255*5.0 126 | heatmap_AU14_1_np = (heatmap_AU14_1.numpy()/5.0).astype(np.uint8).copy() 127 | heatmap_AU14_1_rz = cv2.resize(heatmap_AU14_1_np,(256,256)) 128 | map_AU14_1 = cv2.applyColorMap(heatmap_AU14_1_rz, cv2.COLORMAP_JET) 129 | map_AU14_1=cv2.cvtColor(map_AU14_1, cv2.COLOR_RGB2BGR) 130 | map_AU14 = map_AU14_0*0.5+map_AU14_1*0.5+img_ori*0.5 131 | cv2.putText(map_AU14,"AU14: "+str(AU14_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 132 | """ 133 | Visualization of the predicted AU17 heatmap. 134 | """ 135 | heatmap_AU17_0 = heatmap[index][8].to('cpu').detach() 136 | heatmap_AU17_0[heatmap_AU17_0255*5.0]=255*5.0 138 | heatmap_AU17_0_np = (heatmap_AU17_0.numpy()/5.0).astype(np.uint8).copy() 139 | heatmap_AU17_0_rz = cv2.resize(heatmap_AU17_0_np,(256,256)) 140 | map_AU17_0 = cv2.applyColorMap(heatmap_AU17_0_rz, cv2.COLORMAP_JET) 141 | map_AU17_0=cv2.cvtColor(map_AU17_0, cv2.COLOR_RGB2BGR) 142 | heatmap_AU17_1 = heatmap[index][9].to('cpu').detach() 143 | heatmap_AU17_1[heatmap_AU17_1255*5.0]=255*5.0 145 | heatmap_AU17_1_np = (heatmap_AU17_1.numpy()/5.0).astype(np.uint8).copy() 146 | heatmap_AU17_1_rz = cv2.resize(heatmap_AU17_1_np,(256,256)) 147 | map_AU17_1 = cv2.applyColorMap(heatmap_AU17_1_rz, cv2.COLORMAP_JET) 148 | map_AU17_1=cv2.cvtColor(map_AU17_1, cv2.COLOR_RGB2BGR) 149 | map_AU17 = map_AU17_0*0.5+map_AU17_1*0.5+img_ori*0.5 150 | cv2.putText(map_AU17,"AU17: "+str(AU17_intensity), (5, 25), font, 0.85, (255, 255, 255), 2) 151 | 152 | if images is None: 153 | images = np.expand_dims(img_ori,axis=0) 154 | else: 155 | images = np.concatenate((images, np.expand_dims(img_ori,axis=0))) 156 | 157 | if maps_AU6 is None: 158 | maps_AU6 = np.expand_dims(map_AU6,axis=0) 159 | else: 160 | maps_AU6 = np.concatenate((maps_AU6, np.expand_dims(map_AU6,axis=0))) 161 | if maps_AU10 is None: 162 | maps_AU10 = np.expand_dims(map_AU10,axis=0) 163 | else: 164 | maps_AU10 = np.concatenate((maps_AU10, np.expand_dims(map_AU10,axis=0))) 165 | if maps_AU12 is None: 166 | maps_AU12 = np.expand_dims(map_AU12,axis=0) 167 | else: 168 | maps_AU12 = np.concatenate((maps_AU12, np.expand_dims(map_AU12,axis=0))) 169 | if maps_AU14 is None: 170 | maps_AU14 = np.expand_dims(map_AU14,axis=0) 171 | else: 172 | maps_AU14 = np.concatenate((maps_AU14, np.expand_dims(map_AU14,axis=0))) 173 | if maps_AU17 is None: 174 | maps_AU17 = np.expand_dims(map_AU17,axis=0) 175 | else: 176 | maps_AU17 = np.concatenate((maps_AU17, np.expand_dims(map_AU17,axis=0))) 177 | 178 | # Save the visualized AU heatmaps in path "./visualize/" 179 | if not os.path.exists('./visualize/'): 180 | os.makedirs('./visualize/') 181 | 182 | save_AU6 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU6/255.0).permute(0,3,1,2),scale_factor=0.5) 183 | save_image(save_AU6, './visualize/Subject{}_AU06.png'.format(count)) 184 | save_AU10 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU10/255.0).permute(0,3,1,2),scale_factor=0.5) 185 | save_image(save_AU10, './visualize/Subject{}_AU10.png'.format(count)) 186 | save_AU12 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU12/255.0).permute(0,3,1,2),scale_factor=0.5) 187 | save_image(save_AU12, './visualize/Subject{}_AU12.png'.format(count)) 188 | save_AU14 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU14/255.0).permute(0,3,1,2),scale_factor=0.5) 189 | save_image(save_AU14, './visualize/Subject{}_AU14.png'.format(count)) 190 | save_AU17 = torch.nn.functional.interpolate(torch.from_numpy(maps_AU17/255.0).permute(0,3,1,2),scale_factor=0.5) 191 | save_image(save_AU17, './visualize/Subject{}_AU17.png'.format(count)) 192 | count += 1 193 | 194 | return np.concatenate(preds) 195 | 196 | def test_epoch( dataset_test, model_path,size, npoints): 197 | net = loadnet(npoints,model_path) 198 | OUT = OutIntensity().to('cuda') 199 | # Load data 200 | database = MyDatasets(size=size,database=dataset_test) 201 | dbloader = DataLoader(database, batch_size=1, shuffle=False, num_workers=0, pin_memory=False) 202 | pred = predict(dbloader,OUT,net) 203 | 204 | def main(): 205 | global args 206 | args = parser.parse_args() 207 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda) 208 | test_epoch(dataset_test=args.dataset_test,model_path=args.model_path,size=args.size,npoints=args.K) 209 | 210 | if __name__ == '__main__': 211 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, sys, torch, math 2 | import cv2 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class OutIntensity(torch.nn.Module): 8 | """Infer AU intensity from a heatmap: :(x, y) = argmax H """ 9 | def __init__(self): 10 | super(OutIntensity,self).__init__() 11 | 12 | def forward(self,x): 13 | batch_size = x.shape[0] 14 | num_points = x.shape[1] 15 | width = x.shape[2] 16 | x_ = x.to('cpu').detach().numpy().astype(np.float32).copy() 17 | heatmaps_reshaped = x_.reshape((batch_size, num_points, -1)) 18 | intensity = heatmaps_reshaped.max(axis=2) 19 | return intensity/255.0 20 | --------------------------------------------------------------------------------