├── README.md ├── config ├── 1_belfusion_vae.yaml └── 2_belfusion_ldm.yaml ├── dataset.py ├── evaluate.py ├── external ├── FaceVerse │ ├── FaceVerseModel.py │ ├── LICENSE │ ├── ModelRenderer.py │ └── __init__.py └── PIRender │ ├── LICENSE.md │ ├── __init__.py │ ├── base_function.py │ ├── face_model.py │ └── flow_util.py ├── metric ├── ACC.py ├── FRC.py ├── FRD.py ├── FRDvs.py ├── FRVar.py ├── S_MSE.py ├── TLCC.py ├── __init__.py └── metric.py ├── model ├── BasicBlock.py ├── TransformerVAE.py ├── __init__.py ├── belfusion │ ├── diffusion.py │ ├── matchers.py │ ├── mlp_diffae.py │ ├── resample.py │ ├── rnn.py │ └── torch.py └── losses.py ├── regnn ├── datasets │ ├── __init__.py │ └── datasets.py ├── evaluation.py ├── feature_extraction.py ├── metric │ ├── ACC.py │ ├── FRC.py │ ├── FRD.py │ ├── FRDvs.py │ ├── FRVar.py │ ├── S_MSE.py │ ├── TLCC.py │ ├── __init__.py │ └── metric.py ├── models │ ├── LipschitzGraph.py │ ├── MULTModel.py │ ├── __init__.py │ ├── cognitive.py │ ├── distribution.py │ ├── mhp.py │ ├── model_utils.py │ ├── multimodel │ │ ├── multihead_attention.py │ │ ├── position_embedding.py │ │ └── transformer.py │ ├── perceptual.py │ ├── swin_transformer.py │ └── torchvggish │ │ ├── mel_features.py │ │ ├── vggish.py │ │ ├── vggish_input.py │ │ └── vggish_params.py ├── scripts │ ├── inference.sh │ └── train.sh ├── tool │ └── matrix_split.py ├── train.py ├── trainers.py └── utils │ ├── __init__.py │ ├── compute_distance_fun.py │ ├── evaluate.py │ ├── file.py │ ├── logging.py │ ├── loss.py │ ├── lr_scheduler.py │ └── meters.py ├── render.py ├── requirements.txt ├── run_baselines.py ├── tool ├── audio_visual_clip.py ├── data_indices.csv └── matrix_split.py ├── train.py ├── train_belfusion.py └── utils.py /config/1_belfusion_vae.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: AutoencoderRNN_VAE_v2 3 | args: 4 | emotion_dim: 25 5 | coeff_3dmm_dim: 58 6 | 7 | emb_dims: [64, 64] 8 | num_layers: 2 9 | hidden_dim: 128 10 | z_dim: 128 11 | feature_dim: 128 12 | rnn_type: 'gru' 13 | dropout: 0.0 14 | 15 | window_size: 50 16 | seq_len: 750 17 | 18 | loss: 19 | type: MSELoss_AE_v2 20 | args: 21 | w_mse: 1 22 | w_kld: 0.00001 23 | w_coeff: 1 24 | 25 | optimizer: 26 | lr: 0.001 27 | weight_decay: 5e-4 28 | 29 | trainer: 30 | mode: autoencode 31 | epochs: 1000 32 | resume: 33 | out_dir: ./results 34 | save_period: 2 35 | val_period: 5 36 | 37 | dataset: 38 | dataset_path: ./data 39 | split: train 40 | 41 | img_size: 256 42 | crop_size: 224 43 | clip_length: 750 44 | 45 | batch_size: 32 46 | shuffle: True 47 | num_workers: 4 48 | 49 | load_video: false 50 | load_audio: false 51 | load_emotion: true 52 | load_3dmm: true 53 | 54 | validation_dataset: 55 | dataset_path: ./data/ 56 | split: val 57 | 58 | img_size: 256 59 | crop_size: 224 60 | clip_length: 750 61 | 62 | batch_size: 32 63 | shuffle: False 64 | num_workers: 4 65 | 66 | load_video: false 67 | load_audio: false 68 | load_emotion: true 69 | load_3dmm: true 70 | 71 | test_dataset: 72 | dataset_path: ./data 73 | split: test 74 | 75 | img_size: 256 76 | crop_size: 224 77 | clip_length: 750 78 | 79 | batch_size: 32 80 | shuffle: False 81 | num_workers: 4 82 | 83 | load_video: false 84 | load_audio: false 85 | load_emotion: true 86 | load_3dmm: true -------------------------------------------------------------------------------- /config/2_belfusion_ldm.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: LatentMLPMatcher 3 | args: 4 | 
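    # diffusion_steps: number of denoising steps of the latent diffusion model
    # k: presumably the number of reaction samples drawn per speaker input
    #    (the evaluation code in metric/ assumes 10 predictions per sample)
    # emb_emotion_path: presumably the stage-1 VAE checkpoint produced by
    #    1_belfusion_vae.yaml (epochs: 1000 -> checkpoint_999.pth); its encoder
    #    is kept frozen via freeze_emotion_encoder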
diffusion_steps: 10 5 | k: 10 6 | 7 | num_hid_channels: 1024 8 | num_layers: 10 9 | 10 | emotion_dim: 25 11 | 12 | dropout: 0.0 13 | emb_length: 128 14 | emb_emotion_path: results/All_VAEv2_W50/checkpoint_999.pth 15 | emb_preprocessing: "normalize" 16 | freeze_emotion_encoder: true 17 | 18 | window_size: 50 19 | seq_len: 750 20 | online: false 21 | 22 | loss: 23 | type: BeLFusionLoss 24 | args: 25 | losses: 26 | - L1Loss 27 | - MSELoss 28 | losses_multipliers: 29 | - 10 30 | - 10 31 | losses_decoded: 32 | - false 33 | - true 34 | 35 | optimizer: 36 | lr: 0.0001 37 | weight_decay: 5e-4 38 | 39 | trainer: 40 | mode: predict 41 | epochs: 100 42 | resume: 43 | out_dir: ./results 44 | save_period: 50 45 | val_period: 10 46 | 47 | dataset: 48 | dataset_path: ./data 49 | split: train 50 | 51 | img_size: 256 52 | crop_size: 224 53 | clip_length: 750 54 | 55 | batch_size: 32 56 | shuffle: True 57 | num_workers: 4 58 | 59 | load_video: false 60 | load_audio: false 61 | load_emotion: true 62 | load_3dmm: true 63 | 64 | validation_dataset: 65 | dataset_path: ./data 66 | split: val 67 | 68 | img_size: 256 69 | crop_size: 224 70 | clip_length: 750 71 | 72 | batch_size: 32 73 | shuffle: False 74 | num_workers: 4 75 | 76 | load_video: false 77 | load_audio: false 78 | load_emotion: true 79 | load_3dmm: true 80 | 81 | test_dataset: 82 | dataset_path: ./data 83 | split: test 84 | 85 | img_size: 256 86 | crop_size: 224 87 | clip_length: 750 88 | 89 | batch_size: 32 90 | shuffle: False 91 | num_workers: 4 92 | 93 | load_video: false 94 | load_audio: false 95 | load_emotion: true 96 | load_3dmm: true -------------------------------------------------------------------------------- /external/FaceVerse/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Evelyn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /external/FaceVerse/ModelRenderer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from pytorch3d.structures import Meshes 5 | from pytorch3d.renderer import ( 6 | look_at_view_transform, 7 | FoVPerspectiveCameras, 8 | PointLights, 9 | RasterizationSettings, 10 | MeshRenderer, 11 | MeshRasterizer, 12 | HardFlatShader, 13 | TexturesVertex, 14 | blending 15 | ) 16 | 17 | class ModelRenderer: 18 | def __init__(self, focal=1315, img_size=224, device='cuda:0'): 19 | self.img_size = img_size 20 | self.focal = focal 21 | self.device = device 22 | 23 | self.alb_renderer = self._get_renderer(albedo=True) 24 | self.sha_renderer = self._get_renderer(albedo=False) 25 | 26 | def _get_renderer(self, albedo=True): 27 | R, T = look_at_view_transform(10, 0, 0) # camera's position 28 | cameras = FoVPerspectiveCameras(device=self.device, R=R, T=T, znear=0.01, zfar=50, 29 | fov=2 * np.arctan(self.img_size // 2 / self.focal) * 180. / np.pi) 30 | 31 | if albedo: 32 | lights = PointLights(device=self.device, location=[[0.0, 0.0, 1e5]], 33 | ambient_color=[[1, 1, 1]], 34 | specular_color=[[0., 0., 0.]], diffuse_color=[[0., 0., 0.]]) 35 | else: 36 | lights = PointLights(device=self.device, location=[[0.0, 0.0, 1e5]], 37 | ambient_color=[[0.1, 0.1, 0.1]], 38 | specular_color=[[0.0, 0.0, 0.0]], diffuse_color=[[0.95, 0.95, 0.95]]) 39 | 40 | raster_settings = RasterizationSettings( 41 | image_size=self.img_size, 42 | blur_radius=0.0, 43 | faces_per_pixel=1, 44 | ) 45 | blend_params = blending.BlendParams(background_color=[0, 0, 0]) 46 | 47 | renderer = MeshRenderer( 48 | rasterizer=MeshRasterizer( 49 | cameras=cameras, 50 | raster_settings=raster_settings 51 | ), 52 | shader=HardFlatShader( 53 | device=self.device, 54 | cameras=cameras, 55 | lights=lights, 56 | blend_params=blend_params 57 | ) 58 | ) 59 | return renderer 60 | 61 | -------------------------------------------------------------------------------- /external/FaceVerse/__init__.py: -------------------------------------------------------------------------------- 1 | from .FaceVerseModel import FaceVerseModel 2 | import numpy as np 3 | 4 | def get_faceverse(**kargs): 5 | faceverse_dict = np.load('external/FaceVerse/data/faceverse_simple_v2.npy', allow_pickle=True).item() 6 | faceverse_model = FaceVerseModel(faceverse_dict, **kargs) 7 | return faceverse_model, faceverse_dict 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /external/PIRender/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_model import FaceGenerator -------------------------------------------------------------------------------- /external/PIRender/face_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from external.PIRender import flow_util 9 | from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder 10 | 11 | 12 | 13 | 14 | 15 | class FaceGenerator(nn.Module): 16 | def __init__(self, ): 17 | super(FaceGenerator, self).__init__() 18 | self.mapping_net = MappingNet() 19 | self.warpping_net = WarpingNet() 20 | self.editing_net = EditingNet() 21 | 22 | def forward( 23 | self, 24 | input_image, 25 | driving_source, 26 | 
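        # input_image:    source identity frame, (B, 3, H, W)
        # driving_source: driving 3DMM coefficient sequence, presumably (B, 58, T)
        #                 to match MappingNet's Conv1d input (flame_coeff_nc=58,
        #                 cf. coeff_3dmm_dim: 58 in the configs)
        # stage:          unused here; the 'warp'-only branch below is commented out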
stage=None 27 | ): 28 | # if stage == 'warp': 29 | # descriptor = self.mapping_net(driving_source) 30 | # output = self.warpping_net(input_image, descriptor) 31 | # else: 32 | descriptor = self.mapping_net(driving_source) 33 | output = self.warpping_net(input_image, descriptor) 34 | output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor) 35 | return output 36 | 37 | 38 | 39 | class MappingNet(nn.Module): 40 | def __init__(self, flame_coeff_nc = 58, coeff_nc = 73, descriptor_nc = 256, layer = 3): 41 | super( MappingNet, self).__init__() 42 | 43 | self.layer = layer 44 | nonlinearity = nn.LeakyReLU(0.1) 45 | 46 | self.pre = torch.nn.Conv1d(flame_coeff_nc, coeff_nc, kernel_size=1, padding=0, bias=True) 47 | 48 | self.first = nn.Sequential( 49 | torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) 50 | 51 | for i in range(layer): 52 | net = nn.Sequential(nonlinearity, 53 | torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) 54 | setattr(self, 'encoder' + str(i), net) 55 | 56 | self.pooling = nn.AdaptiveAvgPool1d(1) 57 | self.output_nc = descriptor_nc 58 | 59 | def forward(self, input_3dmm): 60 | out = self.pre(input_3dmm) 61 | out = self.first(out) 62 | # print("out:", out.shape) 63 | for i in range(self.layer): 64 | model = getattr(self, 'encoder' + str(i)) 65 | out = model(out) + out[:,:,3:-3] 66 | out = self.pooling(out) 67 | return out 68 | 69 | 70 | 71 | class WarpingNet(nn.Module): 72 | def __init__( 73 | self, 74 | image_nc = 3, 75 | descriptor_nc = 256, 76 | base_nc = 32, 77 | max_nc = 256, 78 | encoder_layer = 5 , 79 | decoder_layer = 3, 80 | use_spect = False 81 | ): 82 | super( WarpingNet, self).__init__() 83 | 84 | nonlinearity = nn.LeakyReLU(0.1) 85 | norm_layer = functools.partial(LayerNorm2d, affine=True) 86 | kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect} 87 | 88 | self.descriptor_nc = descriptor_nc 89 | self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc, 90 | max_nc, encoder_layer, decoder_layer, **kwargs) 91 | 92 | self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), 93 | nonlinearity, 94 | nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3)) 95 | 96 | self.pool = nn.AdaptiveAvgPool2d(1) 97 | 98 | def forward(self, input_image, descriptor): 99 | final_output={} 100 | output = self.hourglass(input_image, descriptor) 101 | final_output['flow_field'] = self.flow_out(output) 102 | 103 | deformation = flow_util.convert_flow_to_deformation(final_output['flow_field']) 104 | final_output['warp_image'] = flow_util.warp_image(input_image, deformation) 105 | return final_output 106 | 107 | 108 | 109 | class EditingNet(nn.Module): 110 | def __init__( 111 | self, 112 | image_nc = 3, 113 | descriptor_nc = 256, 114 | layer = 3, 115 | base_nc = 64, 116 | max_nc = 256, 117 | num_res_blocks = 2, 118 | use_spect = False): 119 | super(EditingNet, self).__init__() 120 | 121 | nonlinearity = nn.LeakyReLU(0.1) 122 | norm_layer = functools.partial(LayerNorm2d, affine=True) 123 | kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} 124 | self.descriptor_nc = descriptor_nc 125 | 126 | # encoder part 127 | self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs) 128 | self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) 129 | 130 | def forward(self, input_image, warp_image, descriptor): 131 | x = torch.cat([input_image, warp_image], 1) 132 | 
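        # x now has image_nc*2 (= 6) channels: the source frame stacked with its
        # warped version, matching the FineEncoder constructed with image_nc*2 above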
x = self.encoder(x) 133 | gen_image = self.decoder(x, descriptor) 134 | return gen_image 135 | 136 | 137 | # if __name__ == '__main__': 138 | # # model = FaceGenerator() 139 | # # x = torch.randn(1,3,256,256) 140 | # # latent = torch.randn(1,73, 100) 141 | # # output = model(x, latent) 142 | # # 143 | # # checkpoint = torch.load('/home/luocheng/Reaction/PIRender-main/checkpoints/face/epoch_00190_iteration_000400000_checkpoint.pt')['net_G_ema'] 144 | # # 145 | # # model.load_state_dict(checkpoint) 146 | # 147 | # # print(checkpoint['net_G_ema']) 148 | # 149 | # # print(output) 150 | # # a = np.load('/home/luocheng/Datasets/S-L/3D_files/NoXI/001_2016-03-17_Paris/Expert_video/1.npy') 151 | # # print(a.shape) -------------------------------------------------------------------------------- /external/PIRender/flow_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def convert_flow_to_deformation(flow): 4 | r"""convert flow fields to deformations. 5 | 6 | Args: 7 | flow (tensor): Flow field obtained by the model 8 | Returns: 9 | deformation (tensor): The deformation used for warpping 10 | """ 11 | b,c,h,w = flow.shape 12 | flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1) 13 | grid = make_coordinate_grid(flow) 14 | deformation = grid + flow_norm.permute(0,2,3,1) 15 | return deformation 16 | 17 | def make_coordinate_grid(flow): 18 | r"""obtain coordinate grid with the same size as the flow filed. 19 | 20 | Args: 21 | flow (tensor): Flow field obtained by the model 22 | Returns: 23 | grid (tensor): The grid with the same size as the input flow 24 | """ 25 | b,c,h,w = flow.shape 26 | 27 | x = torch.arange(w).to(flow) 28 | y = torch.arange(h).to(flow) 29 | 30 | x = (2 * (x / (w - 1)) - 1) 31 | y = (2 * (y / (h - 1)) - 1) 32 | 33 | yy = y.view(-1, 1).repeat(1, w) 34 | xx = x.view(1, -1).repeat(h, 1) 35 | 36 | meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) 37 | meshed = meshed.expand(b, -1, -1, -1) 38 | return meshed 39 | 40 | 41 | def warp_image(source_image, deformation): 42 | r"""warp the input image according to the deformation 43 | 44 | Args: 45 | source_image (tensor): source images to be warpped 46 | deformation (tensor): deformations used to warp the images; value in range (-1, 1) 47 | Returns: 48 | output (tensor): the warpped images 49 | """ 50 | _, h_old, w_old, _ = deformation.shape 51 | _, _, h, w = source_image.shape 52 | if h_old != h or w_old != w: 53 | deformation = deformation.permute(0, 3, 1, 2) 54 | deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear') 55 | deformation = deformation.permute(0, 2, 3, 1) 56 | return torch.nn.functional.grid_sample(source_image, deformation) -------------------------------------------------------------------------------- /metric/ACC.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.distance import pdist 2 | import numpy as np 3 | import torch 4 | import os 5 | -------------------------------------------------------------------------------- /metric/FRC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import pdist 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | from tslearn.metrics import dtw 7 | from functools import partial 8 | import multiprocessing as mp 9 | 10 | 11 | def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue, *, 12 | 
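             # appears to be adapted from numpy.corrcoef, with np.nan_to_num added
             # below so zero-variance inputs yield 0 instead of NaN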
dtype=None): 13 | 14 | if bias is not np._NoValue or ddof is not np._NoValue: 15 | # 2015-03-15, 1.10 16 | warnings.warn('bias and ddof have no effect and are deprecated', 17 | DeprecationWarning, stacklevel=3) 18 | c = np.cov(x, y, rowvar, dtype=dtype) 19 | 20 | try: 21 | d = np.diag(c) 22 | except ValueError: 23 | # scalar covariance 24 | # nan if incorrect value (nan, inf, 0), 1 otherwise 25 | return c / c 26 | stddev = np.sqrt(d.real) 27 | c /= stddev[:, None] 28 | c /= stddev[None, :] 29 | c = np.nan_to_num(c) 30 | 31 | # Clip real and imaginary parts to [-1, 1]. This does not guarantee 32 | # abs(a[i,j]) <= 1 for complex arrays, but is the best we can do without 33 | # excessive work. 34 | np.clip(c.real, -1, 1, out=c.real) 35 | if np.iscomplexobj(c): 36 | np.clip(c.imag, -1, 1, out=c.imag) 37 | return c 38 | 39 | def concordance_correlation_coefficient(y_true, y_pred, 40 | sample_weight=None, 41 | multioutput='uniform_average'): 42 | """Concordance correlation coefficient. 43 | The concordance correlation coefficient is a measure of inter-rater agreement. 44 | It measures the deviation of the relationship between predicted and true values 45 | from the 45 degree angle. 46 | Read more: https://en.wikipedia.org/wiki/Concordance_correlation_coefficient 47 | Original paper: Lawrence, I., and Kuei Lin. "A concordance correlation coefficient to evaluate reproducibility." Biometrics (1989): 255-268. 48 | Parameters 49 | ---------- 50 | y_true : array-like of shape = (n_samples) or (n_samples, n_outputs) 51 | Ground truth (correct) target values. 52 | y_pred : array-like of shape = (n_samples) or (n_samples, n_outputs) 53 | Estimated target values. 54 | Returns 55 | ------- 56 | loss : A float in the range [-1,1]. A value of 1 indicates perfect agreement 57 | between the true and the predicted values. 
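    Notes
    -----
    For 1-D inputs the coefficient is computed as
        CCC = 2 * rho * sd_true * sd_pred / (var_true + var_pred + (mean_true - mean_pred)**2)
    where rho is the Pearson correlation between y_true and y_pred (a small epsilon
    is added to the denominator for numerical stability). For 2-D inputs the CCC is
    computed per output dimension and averaged.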
58 | Examples 59 | -------- 60 | >>> from sklearn.metrics import concordance_correlation_coefficient 61 | >>> y_true = [3, -0.5, 2, 7] 62 | >>> y_pred = [2.5, 0.0, 2, 8] 63 | >>> concordance_correlation_coefficient(y_true, y_pred) 64 | 0.97678916827853024 65 | """ 66 | if len(y_true.shape) >1: 67 | ccc_list = [] 68 | for i in range(y_true.shape[1]): 69 | cor = corrcoef(y_true[:,i], y_pred[:, i])[0][1] 70 | mean_true = np.mean(y_true[:,i]) 71 | 72 | mean_pred = np.mean(y_pred[:,i]) 73 | 74 | 75 | var_true = np.var(y_true[:,i]) 76 | var_pred = np.var(y_pred[:,i]) 77 | 78 | sd_true = np.std(y_true[:,i]) 79 | sd_pred = np.std(y_pred[:,i]) 80 | 81 | numerator = 2 * cor * sd_true * sd_pred 82 | 83 | denominator = var_true + var_pred + (mean_true - mean_pred) ** 2 84 | 85 | ccc = numerator / (denominator + 1e-8) 86 | 87 | ccc_list.append(ccc) 88 | ccc = np.mean(ccc_list) 89 | else: 90 | cor = np.corrcoef(y_true, y_pred)[0][1] 91 | mean_true = np.mean(y_true) 92 | mean_pred = np.mean(y_pred) 93 | 94 | var_true = np.var(y_true) 95 | var_pred = np.var(y_pred) 96 | 97 | sd_true = np.std(y_true) 98 | sd_pred = np.std(y_pred) 99 | 100 | numerator = 2 * cor * sd_true * sd_pred 101 | 102 | denominator = var_true + var_pred + (mean_true - mean_pred) ** 2 103 | ccc = numerator / (denominator + 1e-8) 104 | return ccc 105 | 106 | 107 | 108 | 109 | 110 | def compute_FRC(args, pred, listener_em, val_test='val'): 111 | pred = pred 112 | listener_em = listener_em 113 | if val_test == 'val': 114 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_val.npy')) 115 | else: 116 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_test.npy')) 117 | 118 | all_FRC_list = [] 119 | for i in range(pred.shape[1]): 120 | FRC_list = [] 121 | for k in range(pred.shape[0]): 122 | speaker_neighbour_index = np.argwhere(speaker_neighbour_matrix[k] == 1).reshape(-1) 123 | speaker_neighbour_index_len = len(speaker_neighbour_index) 124 | ccc_list = [] 125 | for n_index in range(speaker_neighbour_index_len): 126 | 127 | ''' 128 | listener_em order :[listener1, listener2, listener3, ....., listener_n, speaker1, speaker2, speaker3, ....., speaker_n] 129 | listener1: [1, emotion_dim] 130 | 131 | listener_em[speaker_neighbour_index[n_index]]: 132 | 1. speaker_neighbour_index[n_index]: speaker_j (with similar emotion as the speaker_k) 133 | 2. 
listener_em[speaker_neighbour_index[n_index]]: emotion features of listener_j (speaker_j -> listener_j) 134 | So we can get an additional GT listener_j to listener_k (i.e., speaker_j -> listener_k) 135 | ''' 136 | 137 | similar_listener_emotion = listener_em[speaker_neighbour_index[n_index]] 138 | ccc = concordance_correlation_coefficient(similar_listener_emotion.numpy(), pred[k,i].numpy()) 139 | ccc_list.append(ccc) 140 | max_ccc = max(ccc_list) 141 | FRC_list.append(max_ccc) 142 | all_FRC_list.append(np.mean(FRC_list)) 143 | return sum(all_FRC_list) 144 | 145 | 146 | 147 | def _func(k_neighbour_matrix, k_pred, em=None): 148 | neighbour_index = np.argwhere(k_neighbour_matrix == 1).reshape(-1) 149 | neighbour_index_len = len(neighbour_index) 150 | max_ccc_sum = 0 151 | for i in range(k_pred.shape[0]): 152 | ccc_list = [] 153 | for n_index in range(neighbour_index_len): 154 | emotion = em[neighbour_index[n_index]] 155 | ccc = concordance_correlation_coefficient(emotion, k_pred[i]) 156 | ccc_list.append(ccc) 157 | max_ccc_sum += max(ccc_list) 158 | return max_ccc_sum 159 | 160 | 161 | 162 | def compute_FRC_mp(args, pred, em, val_test='val', p=1): 163 | # pred: N 10 750 25 164 | # speaker: N 750 25 165 | if val_test == 'val': 166 | neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_val.npy')) 167 | else: 168 | neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_test.npy')) 169 | 170 | FRC_list = [] 171 | with mp.Pool(processes=p) as pool: 172 | # use map 173 | _func_partial = partial(_func, em=em.numpy()) 174 | FRC_list += pool.starmap(_func_partial, zip(neighbour_matrix, pred.numpy())) 175 | return np.mean(FRC_list) 176 | -------------------------------------------------------------------------------- /metric/FRD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tslearn.metrics import dtw 4 | from functools import partial 5 | import multiprocessing as mp 6 | 7 | 8 | def _func(k_neighbour_matrix, k_pred, em=None): 9 | neighbour_index = np.argwhere(k_neighbour_matrix == 1).reshape(-1) 10 | neighbour_index_len = len(neighbour_index) 11 | min_dwt_sum = 0 12 | for i in range(k_pred.shape[0]): 13 | dwt_list = [] 14 | for n_index in range(neighbour_index_len): 15 | emotion = em[neighbour_index[n_index]] 16 | res = 0 17 | for st, ed, weight in [(0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8)]: 18 | res += weight * dtw(k_pred[i].astype(np.float32)[:, st: ed], emotion.astype(np.float32)[:, st: ed]) 19 | dwt_list.append(res) 20 | min_dwt_sum += min(dwt_list) 21 | return min_dwt_sum 22 | 23 | 24 | def compute_FRD_mp(args, pred, em, val_test='val', p=4): 25 | # pred: N 10 750 25 26 | # speaker: N 750 25 27 | 28 | if val_test == 'val': 29 | neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_val.npy')) 30 | else: 31 | neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_test.npy')) 32 | 33 | FRD_list = [] 34 | with mp.Pool(processes=p) as pool: 35 | # use map 36 | _func_partial = partial(_func, em=em.numpy()) 37 | FRD_list += pool.starmap(_func_partial, zip(neighbour_matrix, pred.numpy())) 38 | 39 | return np.mean(FRD_list) 40 | 41 | 42 | 43 | def compute_FRD(args, pred, listener_em, val_test='val'): 44 | if val_test == 'val': 45 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_val.npy')) 46 | else: 47 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 
'neighbour_emotion_test.npy')) 48 | all_FRD_list = [] 49 | for i in range(pred.shape[1]): 50 | FRD_list = [] 51 | for k in range(pred.shape[0]): 52 | speaker_neighbour_index = np.argwhere(speaker_neighbour_matrix[k] == 1).reshape(-1) 53 | speaker_neighbour_index_len = len(speaker_neighbour_index) 54 | dwt_list = [] 55 | for n_index in range(speaker_neighbour_index_len): 56 | emotion = listener_em[speaker_neighbour_index[n_index]] 57 | res = 0 58 | for st, ed, weight in [(0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8)]: 59 | res += weight * dtw(pred[k, i].numpy().astype(np.float32)[:, st: ed], 60 | emotion.numpy().astype(np.float32)[:, st: ed]) 61 | dwt_list.append(res) 62 | min_dwt = min(dwt_list) 63 | FRD_list.append(min_dwt) 64 | all_FRD_list.append(np.mean(FRD_list)) 65 | return sum(all_FRD_list) 66 | -------------------------------------------------------------------------------- /metric/FRDvs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_FRDvs(preds): 5 | # preds: (N, 10, 750, 25) 6 | preds_ = preds.reshape(preds.shape[0], preds.shape[1], -1) 7 | preds_ = preds_.transpose(0, 1) 8 | # preds_: (10, N, 750*25) 9 | dist = torch.pow(torch.cdist(preds_, preds_), 2) 10 | # dist: (10, N, N) 11 | dist = torch.sum(dist) / (preds.shape[0] * (preds.shape[0] - 1) * preds.shape[1]) 12 | return dist / preds_.shape[-1] 13 | -------------------------------------------------------------------------------- /metric/FRVar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compute_FRVar(preds): 4 | if len(preds.shape) == 3: 5 | # preds: (10, 750, 25) 6 | var = torch.var(preds, dim=1) 7 | return torch.mean(var) 8 | elif len(preds.shape) == 4: 9 | # preds: (N, 10, 750, 25) 10 | var = torch.var(preds, dim=2) 11 | return torch.mean(var) 12 | -------------------------------------------------------------------------------- /metric/S_MSE.py: -------------------------------------------------------------------------------- 1 | # from scipy.spatial.distance import pdist 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | 7 | def compute_s_mse(preds): 8 | # preds: (B, 10, 750, 25) 9 | dist = 0 10 | for b in range(preds.shape[0]): 11 | preds_item = preds[b] 12 | if preds_item.shape[0] == 1: 13 | return 0.0 14 | preds_item_ = preds_item.reshape(preds_item.shape[0], -1) 15 | dist_ = torch.pow(torch.cdist(preds_item_, preds_item_), 2) 16 | dist_ = torch.sum(dist_) / (preds_item.shape[0] * (preds_item.shape[0] - 1) * preds_item_.shape[1]) 17 | dist += dist_ 18 | return dist / preds.shape[0] -------------------------------------------------------------------------------- /metric/TLCC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import multiprocessing as mp 4 | 5 | def crosscorr(datax, datay, lag=0, dim=25): 6 | pcc_list = [] 7 | for i in range(dim): 8 | cn_1, cn_2 = shift(datax[:, i], datay[:, i], lag) 9 | pcc_i = np.corrcoef(cn_1, cn_2)[0, 1] 10 | # pcc_i = torch.corrcoef(torch.stack([cn_1, cn_2], dim=0).float())[0, 1] 11 | pcc_list.append(pcc_i.item()) 12 | return torch.mean(torch.Tensor(pcc_list)) 13 | 14 | 15 | def calculate_tlcc(pred, sp, seconds=2, fps=25): 16 | rs = [crosscorr(pred, sp, lag, sp.shape[-1]) for lag in range(-int(seconds * fps - 1), int(seconds * fps))] 17 | peak = max(rs) 18 | center = rs[len(rs) // 2] 19 | offset = len(rs) // 2 - torch.argmax(torch.Tensor(rs)) 20 | 
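    # rs holds cross-correlations for lags in [-(seconds*fps - 1), seconds*fps - 1]:
    #   peak   - highest cross-correlation over all lags
    #   center - cross-correlation at zero lag (middle of rs)
    #   offset - distance in frames from zero lag to the best-correlated lag;
    #            compute_TLCC below averages its absolute value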
return peak, center, offset 21 | 22 | def compute_TLCC(pred, speaker): 23 | # pred: N 10 750 25 24 | # speaker: N 750 25 25 | offset_list = [] 26 | for k in range(speaker.shape[0]): 27 | pred_item = pred[k] 28 | sp_item = speaker[k] 29 | for i in range(pred_item.shape[0]): 30 | peak, center, offset = calculate_tlcc(pred_item[i].float().numpy(), sp_item.float().numpy()) 31 | offset_list.append(torch.abs(offset).item()) 32 | return torch.mean(torch.Tensor(offset_list)).item() 33 | 34 | 35 | def _func(pred_item, sp_item): 36 | for i in range(pred_item.shape[0]): 37 | peak, center, offset = calculate_tlcc(pred_item[i], sp_item) 38 | return torch.abs(offset).item() 39 | 40 | def compute_TLCC_mp(pred, speaker, p=8): 41 | # pred: N 10 750 25 42 | # speaker: N 750 25 43 | offset_list = [] 44 | # process each speaker in parallel 45 | np.seterr(divide='ignore', invalid='ignore') 46 | 47 | with mp.Pool(processes=p) as pool: 48 | # use map 49 | offset_list += pool.starmap(_func, zip(pred.float().numpy(), speaker.float().numpy())) 50 | return torch.mean(torch.Tensor(offset_list)).item() 51 | 52 | 53 | def SingleTLCC(pred, speaker): 54 | # pred: 10 750 25 55 | # speaker: 750 25 56 | offset_list = [] 57 | for i in range(pred.shape[0]): 58 | peak, center, offset = calculate_tlcc(pred[i].float(), speaker.float()) 59 | offset_list.append(torch.abs(offset).item()) 60 | return torch.mean(torch.Tensor(offset_list)).item() 61 | 62 | 63 | def shift(x, y, lag): 64 | if lag > 0: 65 | return x[lag:], y[:-lag] 66 | elif lag < 0: 67 | return x[:lag], y[-lag:] 68 | else: 69 | return x, y 70 | -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .FRC import compute_FRC_mp, compute_FRC 2 | from .FRD import compute_FRD_mp, compute_FRD 3 | from .FRDvs import compute_FRDvs 4 | from .FRVar import compute_FRVar 5 | from .S_MSE import compute_s_mse 6 | from .TLCC import compute_TLCC, compute_TLCC_mp 7 | -------------------------------------------------------------------------------- /metric/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def s_mse(preds): 5 | # preds: (B, 10, 750, 25) 6 | dist = 0 7 | for b in range(preds.shape[0]): 8 | preds_item = preds[b] 9 | if preds_item.shape[0] == 1: 10 | return 0.0 11 | preds_item_ = preds_item.reshape(preds_item.shape[0], -1) 12 | dist_ = torch.pow(torch.cdist(preds_item_, preds_item_), 2) 13 | dist_ = torch.sum(dist_) / (preds_item.shape[0] * (preds_item.shape[0] - 1) * preds_item_.shape[1]) 14 | dist += dist_ 15 | return dist / preds.shape[0] 16 | 17 | 18 | def FRVar(preds): 19 | if len(preds.shape) == 3: 20 | # preds: (10, 750, 25) 21 | var = torch.var(preds, dim=1) 22 | return torch.mean(var) 23 | elif len(preds.shape) == 4: 24 | # preds: (N, 10, 750, 25) 25 | var = torch.var(preds, dim=2) 26 | return torch.mean(var) 27 | 28 | 29 | 30 | def FRDvs(preds): 31 | # preds: (N, 10, 750, 25) 32 | preds_ = preds.reshape(preds.shape[0], preds.shape[1], -1) 33 | preds_ = preds_.transpose(0, 1) 34 | # preds_: (10, N, 750*25) 35 | dist = torch.pow(torch.cdist(preds_, preds_), 2) 36 | # dist: (10, N, N) 37 | dist = torch.sum(dist) / (preds.shape[0] * (preds.shape[0] - 1) * preds.shape[1]) 38 | return dist / preds_.shape[-1] 39 | 40 | 41 | from scipy.spatial.distance import pdist 42 | import numpy as np 43 | import torch 44 | import os 45 | 46 | def 
compute_FRVar(pred): 47 | FRVar_list = [] 48 | for k in range(pred.shape[0]): 49 | pred_item = pred[k] 50 | for i in range(0, pred_item.shape[0]): 51 | var = np.mean(np.var(pred_item[i].numpy().astype(np.float32), axis=0)) 52 | FRVar_list.append(var) 53 | return np.mean(FRVar_list) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /model/BasicBlock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | 6 | class ConvBlock(nn.Module): 7 | def __init__(self, in_planes=3, planes=128): 8 | super(ConvBlock, self).__init__() 9 | self.planes = planes 10 | self.conv1 = nn.Conv3d(in_planes, planes // 4, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), 11 | bias=False) 12 | self.in1 = nn.InstanceNorm3d(planes // 4) 13 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 0, 0)) 14 | 15 | self.conv2 = nn.Conv3d(planes // 4, planes, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), 16 | bias=False) 17 | self.in2 = nn.InstanceNorm3d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | 20 | self.conv3 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False) 21 | self.in3 = nn.InstanceNorm3d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.conv4 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False) 25 | self.in4 = nn.InstanceNorm3d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | 28 | 29 | self.conv5 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False) 30 | self.in5 = nn.InstanceNorm3d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | 33 | def forward(self, x): 34 | """ 35 | input: 36 | speaker_video_frames x: (batch_size, 3, seq_len, img_size, img_size) 37 | 38 | output: 39 | speaker_temporal_tokens y: (batch_size, token_dim, seq_len) 40 | 41 | """ 42 | 43 | x = self.relu(self.in1(self.conv1(x))) 44 | x = self.maxpool(x) 45 | x = self.relu(self.in2(self.conv2(x))) 46 | x = self.relu(self.in3(self.conv3(x))) 47 | x = x + self.relu(self.in4(self.conv4(x))) 48 | x = self.relu(self.in5(self.conv5(x))) 49 | 50 | y = x.mean(dim=-1).mean(dim=-1) 51 | return y 52 | 53 | 54 | 55 | 56 | class PositionalEncoding(nn.Module): 57 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=True): 58 | super().__init__() 59 | self.batch_first = batch_first 60 | 61 | self.dropout = nn.Dropout(p=dropout) 62 | 63 | pe = torch.zeros(max_len, d_model) 64 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 65 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 66 | pe[:, 0::2] = torch.sin(position * div_term) 67 | pe[:, 1::2] = torch.cos(position * div_term) 68 | pe = pe.unsqueeze(0).transpose(0, 1) 69 | 70 | self.register_buffer('pe', pe) 71 | 72 | def forward(self, x): 73 | # not used in the final model 74 | if self.batch_first: 75 | x = x + self.pe.permute(1, 0, 2)[:, :x.shape[1], :] 76 | else: 77 | x = x + self.pe[:x.shape[0], :] 78 | return self.dropout(x) 79 | 80 | 81 | 82 | def lengths_to_mask(lengths, device): 83 | lengths = torch.tensor(lengths, device=device) 84 | max_len = max(lengths) 85 | mask = torch.arange(max_len, device=device).expand(len(lengths), max_len) < lengths.unsqueeze(1) 86 | return mask 87 | 88 | 89 | 90 | # Temporal Bias, inspired by ALiBi: 
https://github.com/ofirpress/attention_with_linear_biases 91 | def init_biased_mask(n_head, max_seq_len, period): 92 | def get_slopes(n): 93 | def get_slopes_power_of_2(n): 94 | start = (2**(-2**-(math.log2(n)-3))) 95 | ratio = start 96 | return [start*ratio**i for i in range(n)] 97 | if math.log2(n).is_integer(): 98 | return get_slopes_power_of_2(n) 99 | else: 100 | closest_power_of_2 = 2**math.floor(math.log2(n)) 101 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2] 102 | slopes = torch.Tensor(get_slopes(n_head)) 103 | bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period) 104 | bias = - torch.flip(bias,dims=[0]) 105 | alibi = torch.zeros(max_seq_len, max_seq_len) 106 | for i in range(max_seq_len): 107 | alibi[i, :i+1] = bias[-(i+1):] 108 | alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0) 109 | mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1) 110 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 111 | mask = mask.unsqueeze(0) + alibi 112 | return mask -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .TransformerVAE import TransformerVAE 2 | from .belfusion.matchers import LatentMLPMatcher 3 | from .belfusion.rnn import AutoencoderRNN_VAE_v2 -------------------------------------------------------------------------------- /model/belfusion/mlp_diffae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import NamedTuple, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import init 9 | 10 | 11 | 12 | # https://github.com/phizaz/diffae/ 13 | 14 | DTYPES = { 15 | 'float16': torch.float16, 16 | 'float32': torch.float32, 17 | 'float64': torch.float64 18 | } 19 | 20 | class LatentNetType(Enum): 21 | none = 'none' 22 | # injecting inputs into the hidden layers 23 | skip = 'skip' 24 | 25 | 26 | class Activation(Enum): 27 | none = 'none' 28 | relu = 'relu' 29 | lrelu = 'lrelu' 30 | silu = 'silu' 31 | tanh = 'tanh' 32 | 33 | def get_act(self): 34 | if self == Activation.none: 35 | return nn.Identity() 36 | elif self == Activation.relu: 37 | return nn.ReLU() 38 | elif self == Activation.lrelu: 39 | return nn.LeakyReLU(negative_slope=0.2) 40 | elif self == Activation.silu: 41 | return nn.SiLU() 42 | elif self == Activation.tanh: 43 | return nn.Tanh() 44 | else: 45 | raise NotImplementedError() 46 | 47 | 48 | def timestep_embedding(timesteps, dim, max_period=10000, dtype=torch.float32): 49 | """ 50 | Create sinusoidal timestep embeddings. 51 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 52 | These may be fractional. 53 | :param dim: the dimension of the output. 54 | :param max_period: controls the minimum frequency of the embeddings. 55 | :return: an [N x dim] Tensor of positional embeddings. 
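    The embedding is [cos(t * f_0), ..., cos(t * f_{half-1}), sin(t * f_0), ...],
    with half = dim // 2 and geometrically spaced frequencies
    f_k = exp(-ln(max_period) * k / half); if dim is odd, a zero column is appended.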
56 | """ 57 | half = dim // 2 58 | freqs = torch.exp( 59 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half 60 | ).to(device=timesteps.device) 61 | args = timesteps[:, None].type(dtype) * freqs[None] 62 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 63 | if dim % 2: 64 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 65 | return embedding 66 | 67 | 68 | class MLPLNAct(nn.Module): 69 | def __init__( 70 | self, 71 | in_channels: int, 72 | out_channels: int, 73 | norm: bool, 74 | use_cond: bool, 75 | activation: Activation, 76 | cond_channels: int, 77 | condition_bias: float = 0, 78 | dropout: float = 0, 79 | ): 80 | super().__init__() 81 | self.activation = activation 82 | self.condition_bias = condition_bias 83 | self.use_cond = use_cond 84 | 85 | self.linear = nn.Linear(in_channels, out_channels) 86 | self.act = activation.get_act() 87 | if self.use_cond: 88 | self.linear_emb = nn.Linear(cond_channels, out_channels) 89 | self.cond_layers = nn.Sequential(self.act, self.linear_emb) 90 | if norm: 91 | self.norm = nn.LayerNorm(out_channels) 92 | else: 93 | self.norm = nn.Identity() 94 | 95 | if dropout > 0: 96 | self.dropout = nn.Dropout(p=dropout) 97 | else: 98 | self.dropout = nn.Identity() 99 | 100 | self.init_weights() 101 | 102 | def init_weights(self): 103 | for module in self.modules(): 104 | if isinstance(module, nn.Linear): 105 | if self.activation == Activation.relu: 106 | init.kaiming_normal_(module.weight, 107 | a=0, 108 | nonlinearity='relu') 109 | elif self.activation == Activation.lrelu: 110 | init.kaiming_normal_(module.weight, 111 | a=0.2, 112 | nonlinearity='leaky_relu') 113 | elif self.activation == Activation.silu: 114 | init.kaiming_normal_(module.weight, 115 | a=0, 116 | nonlinearity='relu') 117 | else: 118 | # leave it as default 119 | pass 120 | 121 | def forward(self, x, cond=None): 122 | x = self.linear(x) 123 | if self.use_cond: 124 | # (n, c) or (n, c * 2) 125 | cond = self.cond_layers(cond) 126 | cond = (cond, None) 127 | 128 | # scale shift first 129 | x = x * (self.condition_bias + cond[0]) 130 | if cond[1] is not None: 131 | x = x + cond[1] 132 | # then norm 133 | x = self.norm(x) 134 | else: 135 | # no condition 136 | x = self.norm(x) 137 | x = self.act(x) 138 | x = self.dropout(x) 139 | return x 140 | 141 | 142 | class MLPSkipNet(nn.Module): 143 | """ 144 | concat x to hidden layers 145 | default MLP for the latent DPM in the paper! 
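    The network stacks num_layers MLPLNAct blocks of width num_hid_channels.
    The denoising target x (num_channels wide) is concatenated back onto the hidden
    state at every layer listed in skip_layers, and every block except the last is
    modulated by a conditioning vector built by concatenating the timestep embedding
    with the condition embedding.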
146 | """ 147 | def __init__(self, 148 | num_channels=512, # this is the number of features in the input (denoising target) 149 | skip_layers='all', 150 | num_hid_channels=128, 151 | num_layers=20, 152 | num_cond_emb_channels=64, # this is the number of features in the cond embedding 153 | num_time_emb_channels=64, # this is the number of features in the time embedding 154 | activation=Activation.silu, 155 | use_norm=True, 156 | condition_bias=0, 157 | dropout=0, 158 | last_act=Activation.none, 159 | num_time_layers=2, 160 | time_last_act=False, 161 | num_cond_layers=2, 162 | cond_last_act=False, 163 | dtype='float32'): 164 | super().__init__() 165 | 166 | self.num_time_emb_channels = num_time_emb_channels 167 | self.skip_layers = skip_layers if skip_layers != "all" else list(range(num_layers)) 168 | self.dtype = DTYPES[dtype] 169 | 170 | layers = [] 171 | for i in range(num_time_layers): 172 | if i == 0: 173 | a = num_time_emb_channels 174 | b = num_channels 175 | else: 176 | a = num_channels 177 | b = num_channels 178 | layers.append(nn.Linear(a, b)) 179 | if i < num_time_layers - 1 or time_last_act: 180 | layers.append(activation.get_act()) 181 | self.time_embed = nn.Sequential(*layers) 182 | 183 | layers = [] 184 | for i in range(num_cond_layers): 185 | if i == 0: 186 | a = num_cond_emb_channels 187 | b = num_channels 188 | else: 189 | a = num_channels 190 | b = num_channels 191 | layers.append(nn.Linear(a, b)) 192 | if i < num_cond_layers - 1 or cond_last_act: 193 | layers.append(activation.get_act()) 194 | self.cond_embed = nn.Sequential(*layers) 195 | 196 | self.layers = nn.ModuleList([]) 197 | for i in range(num_layers): 198 | if i == 0: 199 | act = activation 200 | norm = use_norm 201 | cond = True 202 | a, b = num_channels, num_hid_channels 203 | dropout = dropout 204 | elif i == num_layers - 1: 205 | act = Activation.none 206 | norm = False 207 | cond = False 208 | a, b = num_hid_channels, num_channels 209 | dropout = 0 210 | else: 211 | act = activation 212 | norm = use_norm 213 | cond = True 214 | a, b = num_hid_channels, num_hid_channels 215 | dropout = dropout 216 | 217 | if i in self.skip_layers: 218 | a += num_channels 219 | 220 | self.layers.append( 221 | MLPLNAct( 222 | a, 223 | b, 224 | norm=norm, 225 | activation=act, 226 | cond_channels=2*num_channels, 227 | use_cond=cond, 228 | condition_bias=condition_bias, 229 | dropout=dropout, 230 | )) 231 | self.last_act = last_act.get_act() 232 | 233 | def forward(self, x, t, cond): 234 | t = timestep_embedding(t, self.num_time_emb_channels, dtype=self.dtype) 235 | time_emb = self.time_embed(t) 236 | cond_emb = self.cond_embed(cond) 237 | #print("AFTER", "timestep embedding", time_emb.shape, "cond", cond_emb.shape) 238 | cond = torch.cat([time_emb, cond_emb], dim=1) 239 | h = x 240 | for i in range(len(self.layers)): 241 | if i in self.skip_layers: 242 | # injecting input into the hidden layers 243 | h = torch.cat([h, x], dim=1) 244 | h = self.layers[i].forward(x=h, cond=cond) 245 | h = self.last_act(h) 246 | return h 247 | 248 | -------------------------------------------------------------------------------- /model/belfusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 
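    Supported names are "uniform" (equal probability for every timestep) and
    "loss-second-moment" (importance sampling weighted by the square root of the
    running second moment of each timestep's recent losses).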
11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 
110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() -------------------------------------------------------------------------------- /model/belfusion/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim import lr_scheduler 4 | import torch.nn as nn 5 | 6 | # https://raw.githubusercontent.com/Khrylx/DLow/master/utils/torch.py 7 | 8 | tensor = torch.tensor 9 | DoubleTensor = torch.DoubleTensor 10 | FloatTensor = torch.FloatTensor 11 | LongTensor = torch.LongTensor 12 | ByteTensor = torch.ByteTensor 13 | ones = torch.ones 14 | zeros = torch.zeros 15 | 16 | class to_cpu: 17 | 18 | def __init__(self, *models): 19 | self.models = list(filter(lambda x: x is not None, models)) 20 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models] 21 | for x in self.models: 22 | x.to(torch.device('cpu')) 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, *args): 28 | for x, device in zip(self.models, self.prev_devices): 29 | x.to(device) 30 | return False 31 | 32 | 33 | class to_device: 34 | 35 | def __init__(self, device, *models): 36 | self.models = list(filter(lambda x: x is not None, models)) 37 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models] 38 | for x in self.models: 39 | x.to(device) 40 | 41 | def __enter__(self): 42 | pass 43 | 44 | def __exit__(self, *args): 45 | for x, device in zip(self.models, self.prev_devices): 46 | x.to(device) 47 | return False 48 | 49 | 50 | class to_test: 51 | 52 | def __init__(self, *models): 53 | self.models = list(filter(lambda x: x is not None, models)) 54 | self.prev_modes = [x.training for x in self.models] 55 | for x in 
self.models: 56 | x.train(False) 57 | 58 | def __enter__(self): 59 | pass 60 | 61 | def __exit__(self, *args): 62 | for x, mode in zip(self.models, self.prev_modes): 63 | x.train(mode) 64 | return False 65 | 66 | 67 | class to_train: 68 | 69 | def __init__(self, *models): 70 | self.models = list(filter(lambda x: x is not None, models)) 71 | self.prev_modes = [x.training for x in self.models] 72 | for x in self.models: 73 | x.train(True) 74 | 75 | def __enter__(self): 76 | pass 77 | 78 | def __exit__(self, *args): 79 | for x, mode in zip(self.models, self.prev_modes): 80 | x.train(mode) 81 | return False 82 | 83 | 84 | def batch_to(dst, *args): 85 | return [x.to(dst) if x is not None else None for x in args] 86 | 87 | 88 | def get_flat_params_from(models): 89 | if not hasattr(models, '__iter__'): 90 | models = (models, ) 91 | params = [] 92 | for model in models: 93 | for param in model.parameters(): 94 | params.append(param.data.view(-1)) 95 | 96 | flat_params = torch.cat(params) 97 | return flat_params 98 | 99 | 100 | def set_flat_params_to(model, flat_params): 101 | prev_ind = 0 102 | for param in model.parameters(): 103 | flat_size = int(np.prod(list(param.size()))) 104 | param.data.copy_( 105 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 106 | prev_ind += flat_size 107 | 108 | 109 | def get_flat_grad_from(inputs, grad_grad=False): 110 | grads = [] 111 | for param in inputs: 112 | if grad_grad: 113 | grads.append(param.grad.grad.view(-1)) 114 | else: 115 | if param.grad is None: 116 | grads.append(zeros(param.view(-1).shape)) 117 | else: 118 | grads.append(param.grad.view(-1)) 119 | 120 | flat_grad = torch.cat(grads) 121 | return flat_grad 122 | 123 | 124 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False): 125 | if create_graph: 126 | retain_graph = True 127 | 128 | inputs = list(inputs) 129 | params = [] 130 | for i, param in enumerate(inputs): 131 | if i not in filter_input_ids: 132 | params.append(param) 133 | 134 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph) 135 | 136 | j = 0 137 | out_grads = [] 138 | for i, param in enumerate(inputs): 139 | if i in filter_input_ids: 140 | out_grads.append(zeros(param.view(-1).shape)) 141 | else: 142 | out_grads.append(grads[j].view(-1)) 143 | j += 1 144 | grads = torch.cat(out_grads) 145 | 146 | for param in params: 147 | param.grad = None 148 | return grads 149 | 150 | 151 | def set_optimizer_lr(optimizer, lr): 152 | for param_group in optimizer.param_groups: 153 | param_group['lr'] = lr 154 | 155 | 156 | def filter_state_dict(state_dict, filter_keys): 157 | for key in list(state_dict.keys()): 158 | for f_key in filter_keys: 159 | if f_key in key: 160 | del state_dict[key] 161 | break 162 | 163 | 164 | def get_scheduler(optimizer, policy, nepoch_fix=None, nepoch=None, decay_step=None): 165 | if policy == 'lambda': 166 | def lambda_rule(epoch): 167 | lr_l = 1.0 - max(0, epoch - nepoch_fix) / float(nepoch - nepoch_fix + 1) 168 | return lr_l 169 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 170 | elif policy == 'step': 171 | scheduler = lr_scheduler.StepLR( 172 | optimizer, step_size=decay_step, gamma=0.1) 173 | elif policy == 'plateau': 174 | scheduler = lr_scheduler.ReduceLROnPlateau( 175 | optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 176 | else: 177 | return NotImplementedError('learning rate policy [%s] is not implemented', policy) 178 | return scheduler 179 | 180 | 181 | 182 | 
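# Illustrative sketch (not part of the original file): one way get_scheduler()
# could be attached to an optimizer. The model, learning rate and epoch counts
# below are placeholders chosen for the example only.
def _example_get_scheduler_usage():
    model = nn.Linear(16, 16)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # 'lambda' policy: keep the lr constant for nepoch_fix epochs, then decay
    # it linearly towards zero at epoch nepoch
    scheduler = get_scheduler(optimizer, policy='lambda', nepoch_fix=10, nepoch=100)
    for _ in range(100):
        optimizer.step()   # the actual training steps for one epoch go here
        scheduler.step()   # advance the schedule once per epoch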
def init_weights(module): 183 | if isinstance(module, nn.Conv2d): 184 | nn.init.kaiming_normal_(module.weight.data, mode="fan_out") 185 | elif isinstance(module, nn.BatchNorm2d): 186 | nn.init.constant_(module.weight, 1.0) 187 | nn.init.constant_(module.bias, 0.0) 188 | elif isinstance(module, nn.Linear): 189 | # print("weights ", module) 190 | for name, param in module.named_parameters(): 191 | if "bias" in name: 192 | nn.init.constant_(param, 0.0) 193 | elif "weight" in name: 194 | nn.init.xavier_uniform_(param) 195 | elif ( 196 | isinstance(module, nn.LSTM) 197 | or isinstance(module, nn.RNN) 198 | or isinstance(module, nn.LSTMCell) 199 | or isinstance(module, nn.RNNCell) 200 | or isinstance(module, nn.GRU) 201 | or isinstance(module, nn.GRUCell) 202 | ): 203 | # https://www.cse.iitd.ac.in/~mausam/courses/col772/spring2018/lectures/12-tricks.pdf 204 | # • It can take a while for a RNN to learn to remember information 205 | # • Initialize biases for LSTM’s forget gate to 1 to remember more by default. 206 | # • Similarly, initialize biases for GRU’s reset gate to -1. 207 | DIV = 3 if isinstance(module, nn.GRU) or isinstance(module, nn.GRUCell) else 4 208 | for name, param in module.named_parameters(): 209 | if "bias" in name: 210 | #print(name) 211 | nn.init.constant_( 212 | param, 0.0 213 | ) 214 | if isinstance(module, nn.LSTMCell) \ 215 | or isinstance(module, nn.LSTM): 216 | n = param.size(0) 217 | # LSTM: (W_ii|W_if|W_ig|W_io), W_if (forget gate) => bias 1 218 | start, end = n // DIV, n // 2 219 | param.data[start:end].fill_(1.) # to remember more by default 220 | elif isinstance(module, nn.GRU) \ 221 | or isinstance(module, nn.GRUCell): 222 | # GRU: (W_ir|W_iz|W_in), W_ir (reset gate) => bias -1 223 | end = param.size(0) // DIV 224 | param.data[:end].fill_(-1.) 
# to remember more by default 225 | elif "weight" in name: 226 | nn.init.xavier_normal_(param) 227 | if isinstance(module, nn.LSTMCell) \ 228 | or isinstance(module, nn.LSTM) \ 229 | or isinstance(module, nn.GRU) \ 230 | or isinstance(module, nn.GRUCell): 231 | if 'weight_ih' in name: # input -> hidden weights 232 | mul = param.shape[0] // DIV 233 | for idx in range(DIV): 234 | nn.init.xavier_uniform_(param[idx * mul:(idx + 1) * mul]) 235 | elif 'weight_hh' in name: # hidden -> hidden weights (recurrent) 236 | mul = param.shape[0] // DIV 237 | for idx in range(DIV): 238 | nn.init.orthogonal_(param[idx * mul:(idx + 1) * mul]) # orthogonal initialization https://arxiv.org/pdf/1702.00071.pdf 239 | else: 240 | print(f"[WARNING] Module not initialized: {module}") 241 | -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class KLLoss(nn.Module): 8 | def __init__(self): 9 | super(KLLoss, self).__init__() 10 | 11 | def forward(self, q, p): 12 | div = torch.distributions.kl_divergence(q, p) 13 | return div.mean() 14 | 15 | def __repr__(self): 16 | return "KLLoss()" 17 | 18 | 19 | 20 | class VAELoss(nn.Module): 21 | def __init__(self, kl_p=0.0002): 22 | super(VAELoss, self).__init__() 23 | self.mse = nn.MSELoss(reduce=True, size_average=True) 24 | self.kl_loss = KLLoss() 25 | self.kl_p = kl_p 26 | 27 | def forward(self, gt_emotion, gt_3dmm, pred_emotion, pred_3dmm, distribution): 28 | rec_loss = self.mse(pred_emotion, gt_emotion) + self.mse(pred_3dmm[:,:, :52], gt_3dmm[:,:, :52]) + 10*self.mse(pred_3dmm[:,:, 52:], gt_3dmm[:,:, 52:]) 29 | mu_ref = torch.zeros_like(distribution[0].loc).to(gt_emotion.get_device()) 30 | scale_ref = torch.ones_like(distribution[0].scale).to(gt_emotion.get_device()) 31 | distribution_ref = torch.distributions.Normal(mu_ref, scale_ref) 32 | 33 | kld_loss = 0 34 | for t in range(len(distribution)): 35 | kld_loss += self.kl_loss(distribution[t], distribution_ref) 36 | kld_loss = kld_loss / len(distribution) 37 | 38 | loss = rec_loss + self.kl_p * kld_loss 39 | 40 | 41 | return loss, rec_loss, kld_loss 42 | 43 | def __repr__(self): 44 | return "VAELoss()" 45 | 46 | 47 | 48 | def div_loss(Y_1, Y_2): 49 | loss = 0.0 50 | b,t,c = Y_1.shape 51 | Y_g = torch.cat([Y_1.view(b,1,-1), Y_2.view(b,1,-1)], dim = 1) 52 | for Y in Y_g: 53 | dist = F.pdist(Y, 2) ** 2 54 | loss += (-dist / 100).exp().mean() 55 | loss /= b 56 | return loss 57 | 58 | 59 | 60 | 61 | # ================================ BeLFUSION losses ==================================== 62 | 63 | def MSELoss_AE_v2(prediction, target, target_coefficients, mu, logvar, coefficients_3dmm, 64 | w_mse=1, w_kld=1, w_coeff=1, 65 | **kwargs): 66 | # loss for autoencoder. 
prediction and target have shape of [batch_size, seq_length, features] 67 | assert len(prediction.shape) == len(target.shape), "prediction and target must have the same shape" 68 | assert len(prediction.shape) == 3, "Only works with predictions of shape [batch_size, seq_length, features]" 69 | batch_size = prediction.shape[0] 70 | 71 | # join last two dimensions of prediction and target 72 | prediction = prediction.reshape(prediction.shape[0], -1) 73 | target = target.reshape(target.shape[0], -1) 74 | coefficients_3dmm = coefficients_3dmm.reshape(coefficients_3dmm.shape[0], -1) 75 | target_coefficients = target_coefficients.reshape(target_coefficients.shape[0], -1) 76 | 77 | MSE = ((prediction - target) ** 2).mean() 78 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size 79 | COEFF = ((coefficients_3dmm - target_coefficients) ** 2).mean() 80 | 81 | loss_r = w_mse * MSE + w_kld * KLD + w_coeff * COEFF 82 | return {"loss": loss_r, "mse": MSE, "kld": KLD, "coeff": COEFF} 83 | 84 | 85 | def MSELoss(prediction, target, reduction="mean", **kwargs): 86 | # prediction has shape of [batch_size, num_preds, features] 87 | # target has shape of [batch_size, num_preds, features] 88 | assert len(prediction.shape) == len(target.shape), "prediction and target must have the same shape" 89 | assert len(prediction.shape) == 3, "Only works with predictions of shape [batch_size, num_preds, features]" 90 | 91 | # manual implementation of MSE loss 92 | loss = ((prediction - target) ** 2).mean(axis=-1) 93 | 94 | # reduce across multiple predictions 95 | if reduction == "mean": 96 | loss = torch.mean(loss) 97 | elif reduction == "min": 98 | loss = loss.min(axis=-1)[0].mean() 99 | else: 100 | raise NotImplementedError("reduction {} not implemented".format(reduction)) 101 | return loss 102 | 103 | 104 | def L1Loss(prediction, target, reduction="mean", **kwargs): 105 | # prediction has shape of [batch_size, num_preds, features] 106 | # target has shape of [batch_size, num_preds, features] 107 | assert len(prediction.shape) == len(target.shape), "prediction and target must have the same shape" 108 | assert len(prediction.shape) == 3, "Only works with predictions of shape [batch_size, num_preds, features]" 109 | 110 | # manual implementation of L1 loss 111 | loss = (torch.abs(prediction - target)).mean(axis=-1) 112 | 113 | # reduce across multiple predictions 114 | if reduction == "mean": 115 | loss = torch.mean(loss) 116 | elif reduction == "min": 117 | loss = loss.min(axis=-1)[0].mean() 118 | else: 119 | raise NotImplementedError("reduction {} not implemented".format(reduction)) 120 | return loss 121 | 122 | 123 | def BeLFusionLoss(prediction, target, encoded_prediction, encoded_target, 124 | losses = [L1Loss, MSELoss], 125 | losses_multipliers = [1, 1], 126 | losses_decoded = [False, True], 127 | **kwargs): 128 | # encoded_prediction has shape of [batch_size, num_preds, features] 129 | # encoded_target has shape of [batch_size, num_preds, features] 130 | # prediction has shape of [batch_size, num_preds, seq_len, features] 131 | # target has shape of [batch_size, num_preds, seq_len, features] 132 | assert len(losses) == len(losses_multipliers), "losses and losses_multipliers must have the same length" 133 | assert len(losses) == len(losses_decoded), "losses and losses_decoded must have the same length" 134 | #assert len(encoded_prediction.shape) == 3 and len(prediction.shape) == 4, "BeLFusionLoss only works with multiple predictions" 135 | 136 | if len(encoded_prediction.shape) == 2 and 
len(prediction.shape) == 3: # --> for the test script to work, only a single pred is used 137 | # unsqueeze the first dimension, because we only have one prediction 138 | prediction = prediction.unsqueeze(1) 139 | target = target.unsqueeze(1) 140 | encoded_prediction = encoded_prediction.unsqueeze(1) 141 | encoded_target = encoded_target.unsqueeze(1) 142 | 143 | if len(encoded_prediction.shape) == 3 and len(prediction.shape) == 4: 144 | # join last two dimensions of prediction and target 145 | prediction = prediction.reshape(prediction.shape[0], prediction.shape[1], -1) 146 | target = target.reshape(target.shape[0], target.shape[1], -1) 147 | else: 148 | raise NotImplementedError("BeLFusionLoss only works with multiple predictions") 149 | 150 | # compute losses 151 | losses_dict = {"loss": 0} 152 | for loss_name, w, decoded in zip(losses, losses_multipliers, losses_decoded): 153 | loss_final_name = loss_name + f"_{'decoded' if decoded else 'encoded'}" 154 | if decoded: 155 | losses_dict[loss_final_name] = eval(loss_name)(prediction, target, reduction="min") 156 | else: 157 | losses_dict[loss_final_name] = eval(loss_name)(encoded_prediction, encoded_target, reduction="min") 158 | 159 | losses_dict["loss"] += losses_dict[loss_final_name] * w 160 | 161 | return losses_dict 162 | 163 | -------------------------------------------------------------------------------- /regnn/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import ActionData -------------------------------------------------------------------------------- /regnn/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | from torch.utils.data import Dataset 8 | 9 | class ActionData(Dataset): 10 | def __init__(self, root, data_type, neighbors, augmentation=None, 11 | num_frames=50, stride=50, neighbor_pattern='nearest'): 12 | 13 | self.root = root 14 | self.split_data = pd.read_csv(os.path.join(root, data_type + '.csv'), header=None, delimiter=',') 15 | self.split_data = self.split_data.drop(0) 16 | 17 | speaker_path = [path for path in list(self.split_data.values[:, 1])] 18 | listener_path = [path for path in list(self.split_data.values[:, 2])] 19 | 20 | self.neighbors = neighbors # None 21 | self.aug = augmentation # None 22 | self.num_frames = num_frames # 50 23 | self.stride = stride # 25 24 | self.neighbor_pattern = neighbor_pattern # 'all' 25 | self.all_data = self.get_site_id_clip(speaker_path, listener_path, data_type) 26 | 27 | def get_site_id_clip(self, speaker_path, listener_path, data_type): 28 | all_data = [] 29 | 30 | for index in range(len(speaker_path)): 31 | site, group, pid, clip = speaker_path[index].split('/') 32 | all_data.extend([data_type, site, group, pid, clip, str(i)] 33 | for i in range(0, 750 - self.num_frames + 1, self.stride)) 34 | 35 | site, group, pid, clip = listener_path[index].split('/') 36 | all_data.extend([data_type, site, group, pid, clip, str(i)] 37 | for i in range(0, 750 - self.num_frames + 1, self.stride)) 38 | 39 | return all_data 40 | 41 | def __getitem__(self, index): 42 | dtype_site_group_pid_clip_idx = self.all_data[index] 43 | 44 | v_inputs = self.load_video_pth(dtype_site_group_pid_clip_idx) 45 | a_inputs = self.load_audio_pth(dtype_site_group_pid_clip_idx) 46 | 47 | if self.neighbor_pattern in {'nearest', 'all'}: 48 | # target: 
'train+NoXI+065_2016-04-14_Nottingham+speaker+1+0' 49 | # 'dtype+site+group+pid+clip+idx' 50 | targets = '+'.join(dtype_site_group_pid_clip_idx) 51 | 52 | return v_inputs, a_inputs, targets 53 | 54 | def __len__(self): 55 | return len(self.all_data) 56 | 57 | def load_video_pth(self, dtype_site_group_pid_clip_idx): 58 | dtype, site, group, pid, clip, idx = dtype_site_group_pid_clip_idx 59 | idx = int(idx) 60 | video_pth = os.path.join(self.root, dtype, 'Video_features', site, group, pid, clip + '.pth') 61 | video_inputs = torch.load(video_pth, map_location='cpu')[idx:idx + self.num_frames] 62 | return video_inputs 63 | 64 | def load_audio_pth(self, dtype_site_group_pid_clip_idx): 65 | dtype, site, group, pid, clip, idx = dtype_site_group_pid_clip_idx 66 | idx = int(idx) 67 | audio_pth = os.path.join(self.root, dtype, 'Audio_features', site, group, pid, clip + '.pth') 68 | audio_inputs = torch.load(audio_pth, map_location='cpu')[idx:idx + self.num_frames] 69 | if audio_inputs.shape[0] != self.num_frames: 70 | audio_inputs = torch.cat([audio_inputs, audio_inputs[-1].unsqueeze(dim=0)]) 71 | return audio_inputs -------------------------------------------------------------------------------- /regnn/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from metric import * 8 | 9 | SAMPLE_NUMS = 10 10 | 11 | def parse_arg(): 12 | parser = argparse.ArgumentParser(description='PyTorch Training') 13 | # Param 14 | parser.add_argument('--data-dir', default="../data/react_clean", type=str, help="dataset path") 15 | parser.add_argument('--pred-dir', default="../data/react_clean/outputs/Gmm-logs", type=str, help="the path of saved predictions") 16 | parser.add_argument('--split', type=str, help="split of dataset", choices=["val", "test"], required=True) 17 | parser.add_argument('--threads', default=32, type=int, help="num max of threads") 18 | 19 | args = parser.parse_args() 20 | return args 21 | 22 | def replace_pid(emotion_path): 23 | if 'NoXI' in emotion_path: 24 | emotion_path=emotion_path.replace('Novice_video','P2') 25 | emotion_path=emotion_path.replace('Expert_video','P1') 26 | 27 | if 'Emotion/RECOLA/group' in emotion_path: 28 | emotion_path=emotion_path.replace('P25','P1') 29 | emotion_path=emotion_path.replace('P26','P2') 30 | emotion_path=emotion_path.replace('P41','P1') 31 | emotion_path=emotion_path.replace('P42','P2') 32 | emotion_path=emotion_path.replace('P45','P1') 33 | emotion_path=emotion_path.replace('P46','P2') 34 | 35 | return emotion_path 36 | 37 | def main(args): 38 | _list_path = pd.read_csv(os.path.join(args.data_dir, args.split + '.csv'), header=None, delimiter=',') 39 | _list_path = _list_path.drop(0) 40 | 41 | speaker_path_list = [path for path in list(_list_path.values[:, 1])] + [path for path in list(_list_path.values[:, 2])] 42 | listener_path_list = [path for path in list(_list_path.values[:, 2])] + [path for path in list(_list_path.values[:, 1])] 43 | 44 | listener_emotion_gt_list = [] 45 | listener_emotion_pred_list = [] 46 | speaker_emotion_list = [] 47 | 48 | for index in range(len(speaker_path_list)): 49 | # speaker emotion 50 | speaker_path = speaker_path_list[index] 51 | speaker_emotion_path = os.path.join(args.data_dir, args.split, 'Emotion', speaker_path+'.csv') 52 | 53 | speaker_emotion_path = replace_pid(speaker_emotion_path) 54 | speaker_emotion = pd.read_csv(speaker_emotion_path, header=None, delimiter=',') 55 | 
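        # (added note) each Emotion CSV appears to store one clip as a header row plus
        # 750 frames x 25 emotion dimensions; dropping row 0 and casting below yields a
        # (750, 25) float tensor, matching the "N 750 25" shapes the metrics assume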
speaker_emotion = torch.from_numpy(np.array(speaker_emotion.drop(0)).astype(np.float32)) 56 | 57 | speaker_emotion_list.append(speaker_emotion) 58 | 59 | # listener emotion 60 | listener_path = listener_path_list[index] 61 | listener_emotion_path = os.path.join(args.data_dir, args.split, 'Emotion', listener_path+'.csv') 62 | 63 | listener_emotion_path = replace_pid(listener_emotion_path) 64 | listener_emotion = pd.read_csv(listener_emotion_path, header=None, delimiter=',') 65 | listener_emotion = torch.from_numpy(np.array(listener_emotion.drop(0)).astype(np.float32)) 66 | 67 | listener_emotion_gt_list.append(listener_emotion) 68 | 69 | # predicted listener's emotions 70 | listener_emotion_pred = [] 71 | for j in range(SAMPLE_NUMS): 72 | pred_path = os.path.join(args.pred_dir, args.split, speaker_path, 'result-' + str(j) + '.pth') 73 | listener_emotion_pred.append(torch.load(pred_path).cpu()) 74 | 75 | listener_emotion_pred = torch.stack(listener_emotion_pred) 76 | 77 | listener_emotion_pred_list.append(listener_emotion_pred) 78 | 79 | speaker_emotion_gt = torch.stack(speaker_emotion_list, dim = 0) 80 | listener_emotion_gt = torch.stack(listener_emotion_gt_list, dim = 0) 81 | all_listener_emotion_pred = torch.stack(listener_emotion_pred_list) 82 | 83 | print("-----------------Evaluating Metric-----------------") 84 | 85 | p = args.threads 86 | 87 | # If you have problems running function compute_TLCC_mp, please replace this function with function compute_TLCC 88 | TLCC = compute_TLCC_mp(all_listener_emotion_pred, speaker_emotion_gt, p=p) 89 | 90 | # If you have problems running function compute_FRC_mp, please replace this function with function compute_FRC 91 | FRC = compute_FRC_mp(args.data_dir, all_listener_emotion_pred, listener_emotion_gt, val_test=args.split, p=p) 92 | 93 | # If you have problems running function compute_FRD_mp, please replace this function with function compute_FRD 94 | FRD = compute_FRD_mp(args.data_dir, all_listener_emotion_pred, listener_emotion_gt, val_test=args.split, p=p) 95 | 96 | FRDvs = compute_FRDvs(all_listener_emotion_pred) 97 | FRVar = compute_FRVar(all_listener_emotion_pred) 98 | smse = compute_s_mse(all_listener_emotion_pred) 99 | 100 | print("Metric: | FRC: {:.5f} | FRD: {:.5f} | S-MSE: {:.5f} | FRVar: {:.5f} | FRDvs: {:.5f} | TLCC: {:.5f}".format(FRC, FRD, smse, FRVar, FRDvs, TLCC)) 101 | print("Latex-friendly --> model_name & {:.2f} & {:.2f} & {:.4f} & {:.4f} & {:.4f} & - & {:.2f} \\\\".format( FRC, FRD, smse, FRVar, FRDvs, TLCC)) 102 | 103 | if __name__=="__main__": 104 | args = parse_arg() 105 | os.environ["NUMEXPR_MAX_THREADS"] = '32' 106 | main(args) -------------------------------------------------------------------------------- /regnn/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from models import SwinTransformer, VGGish 5 | import pandas as pd 6 | import os 7 | from PIL import Image 8 | 9 | from torch.utils import data 10 | from torchvision import transforms 11 | 12 | import numpy as np 13 | import random 14 | import pandas as pd 15 | from PIL import Image 16 | 17 | from decord import VideoReader 18 | from decord import cpu 19 | 20 | import argparse 21 | 22 | def parse_arg(): 23 | parser = argparse.ArgumentParser(description='PyTorch Training') 24 | # Param 25 | parser.add_argument('--data-dir', default="../data/react/cropped_face", type=str, help="dataset path") 26 | parser.add_argument('--save-dir', 
default="../data/react_clean", type=str, help="the dir to save features") 27 | parser.add_argument('--split', type=str, help="split of dataset", choices=["train", "val", "test"], required=True) 28 | parser.add_argument('--type', type=str, help="type of features to extract", choices=["audio", "video"], required=True) 29 | 30 | args = parser.parse_args() 31 | return args 32 | 33 | class Transform(object): 34 | def __init__(self, img_size=256, crop_size=224): 35 | self.img_size = img_size 36 | self.crop_size = crop_size 37 | 38 | def __call__(self, img): 39 | 40 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 41 | 42 | transform = transforms.Compose([ 43 | transforms.Resize(self.img_size), 44 | transforms.CenterCrop(self.crop_size), 45 | transforms.ToTensor(), 46 | normalize 47 | ]) 48 | 49 | img = transform(img) 50 | return img 51 | 52 | def extract_audio_features(args): 53 | model = VGGish(preprocess=True) 54 | model = model.cuda() 55 | model.eval() 56 | 57 | _list_path = pd.read_csv(os.path.join(args.data_dir, args.split + '.csv'), header=None, delimiter=',') 58 | _list_path = _list_path.drop(0) 59 | 60 | all_path = [path for path in list(_list_path.values[:, 1])] + [path for path in list(_list_path.values[:, 2])] 61 | 62 | for path in all_path: 63 | ab_audio_path = os.path.join(args.data_dir, args.split, 'Audio_files', path+'.wav') 64 | 65 | with torch.no_grad(): 66 | audio_features = model.forward(ab_audio_path, fs=25).cpu() 67 | 68 | site, group, pid, clip = path.split('/') 69 | if not os.path.exists(os.path.join(args.save_dir, args.split, 'Audio_features', site, group, pid)): 70 | os.makedirs(os.path.join(args.save_dir, args.split, 'Audio_features', site, group, pid)) 71 | 72 | torch.save(audio_features, os.path.join(args.save_dir, args.split, 'Audio_features', path+'.pth')) 73 | 74 | def extract_video_features(args): 75 | _transform = Transform(img_size=256, crop_size=224) 76 | 77 | model = SwinTransformer(embed_dim = 96, depths = [2, 2, 6, 2], num_heads = [3, 6, 12, 24], window_size = 7, drop_path_rate = 0.2, num_classes=7) 78 | # Load the weights of pre-trained SwinTransformer 79 | model.load_state_dict(torch.load(r"/scratch/recface/hz204/react_data/pretrained/swin_fer.pth", map_location='cpu')) 80 | model = model.cuda() 81 | model.eval() 82 | 83 | _list_path = pd.read_csv(os.path.join(args.data_dir, args.split + '.csv'), header=None, delimiter=',') 84 | _list_path = _list_path.drop(0) 85 | 86 | all_path = [path for path in list(_list_path.values[:, 1])] + [path for path in list(_list_path.values[:, 2])] 87 | 88 | total_length = 751 89 | 90 | for path in all_path: 91 | clip = [] 92 | ab_video_path = os.path.join(args.data_dir, path+'.mp4') 93 | with open(ab_video_path, 'rb') as f: 94 | vr = VideoReader(f, ctx=cpu(0)) 95 | for i in range(total_length): 96 | frame = vr[i] 97 | img=Image.fromarray(frame.asnumpy()) 98 | img = _transform(img) 99 | clip.append(img.unsqueeze(0)) 100 | 101 | video_clip = torch.cat(clip, dim=0).cuda() 102 | with torch.no_grad(): 103 | video_features = model.forward_features(video_clip).cpu() 104 | 105 | site, group, pid, clip = path.split('/') 106 | if not os.path.exists(os.path.join(args.save_dir, args.split, 'Video_features', site, group, pid)): 107 | os.makedirs(os.path.join(args.save_dir, args.split, 'Video_features', site, group, pid)) 108 | 109 | torch.save(video_features, os.path.join(args.save_dir, args.split, 'Video_features', path+'.pth')) 110 | 111 | def main(args): 112 | if args.type == 'video': 113 | 
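        # (added note) example invocations, e.g. from the regnn/ directory, relying on the
        # argparse defaults above for the dataset paths:
        #   python feature_extraction.py --split train --type video
        #   python feature_extraction.py --split train --type audio --data-dir ../data/react/cropped_face --save-dir ../data/react_clean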
extract_video_features(args) 114 | elif args.type == 'audio': 115 | extract_audio_features(args) 116 | 117 | # --------------------------------------------------------------------------------- 118 | 119 | 120 | if __name__=="__main__": 121 | args = parse_arg() 122 | os.environ["NUMEXPR_MAX_THREADS"] = '32' 123 | main(args) 124 | -------------------------------------------------------------------------------- /regnn/metric/ACC.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.distance import pdist 2 | import numpy as np 3 | import torch 4 | import os 5 | -------------------------------------------------------------------------------- /regnn/metric/FRC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import pdist 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | from tslearn.metrics import dtw 7 | from functools import partial 8 | import multiprocessing as mp 9 | import warnings 10 | 11 | 12 | def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue, *, 13 | dtype=None): 14 | 15 | if bias is not np._NoValue or ddof is not np._NoValue: 16 | # 2015-03-15, 1.10 17 | warnings.warn('bias and ddof have no effect and are deprecated', 18 | DeprecationWarning, stacklevel=3) 19 | c = np.cov(x, y, rowvar, dtype=dtype) 20 | 21 | try: 22 | d = np.diag(c) 23 | except ValueError: 24 | # scalar covariance 25 | # nan if incorrect value (nan, inf, 0), 1 otherwise 26 | return c / c 27 | stddev = np.sqrt(d.real) 28 | c /= stddev[:, None] 29 | c /= stddev[None, :] 30 | c = np.nan_to_num(c) 31 | 32 | # Clip real and imaginary parts to [-1, 1]. This does not guarantee 33 | # abs(a[i,j]) <= 1 for complex arrays, but is the best we can do without 34 | # excessive work. 35 | np.clip(c.real, -1, 1, out=c.real) 36 | if np.iscomplexobj(c): 37 | np.clip(c.imag, -1, 1, out=c.imag) 38 | return c 39 | 40 | def concordance_correlation_coefficient(y_true, y_pred, 41 | sample_weight=None, 42 | multioutput='uniform_average'): 43 | """Concordance correlation coefficient. 44 | The concordance correlation coefficient is a measure of inter-rater agreement. 45 | It measures the deviation of the relationship between predicted and true values 46 | from the 45 degree angle. 47 | Read more: https://en.wikipedia.org/wiki/Concordance_correlation_coefficient 48 | Original paper: Lawrence, I., and Kuei Lin. "A concordance correlation coefficient to evaluate reproducibility." Biometrics (1989): 255-268. 49 | Parameters 50 | ---------- 51 | y_true : array-like of shape = (n_samples) or (n_samples, n_outputs) 52 | Ground truth (correct) target values. 53 | y_pred : array-like of shape = (n_samples) or (n_samples, n_outputs) 54 | Estimated target values. 55 | Returns 56 | ------- 57 | loss : A float in the range [-1,1]. A value of 1 indicates perfect agreement 58 | between the true and the predicted values. 
59 | Examples 60 | -------- 61 | >>> from sklearn.metrics import concordance_correlation_coefficient 62 | >>> y_true = [3, -0.5, 2, 7] 63 | >>> y_pred = [2.5, 0.0, 2, 8] 64 | >>> concordance_correlation_coefficient(y_true, y_pred) 65 | 0.97678916827853024 66 | """ 67 | if len(y_true.shape) >1: 68 | ccc_list = [] 69 | for i in range(y_true.shape[1]): 70 | cor = corrcoef(y_true[:,i], y_pred[:, i])[0][1] 71 | mean_true = np.mean(y_true[:,i]) 72 | 73 | mean_pred = np.mean(y_pred[:,i]) 74 | 75 | 76 | var_true = np.var(y_true[:,i]) 77 | var_pred = np.var(y_pred[:,i]) 78 | 79 | sd_true = np.std(y_true[:,i]) 80 | sd_pred = np.std(y_pred[:,i]) 81 | 82 | numerator = 2 * cor * sd_true * sd_pred 83 | 84 | denominator = var_true + var_pred + (mean_true - mean_pred) ** 2 85 | 86 | ccc = numerator / (denominator + 1e-8) 87 | 88 | ccc_list.append(ccc) 89 | ccc = np.mean(ccc_list) 90 | else: 91 | cor = np.corrcoef(y_true, y_pred)[0][1] 92 | mean_true = np.mean(y_true) 93 | mean_pred = np.mean(y_pred) 94 | 95 | var_true = np.var(y_true) 96 | var_pred = np.var(y_pred) 97 | 98 | sd_true = np.std(y_true) 99 | sd_pred = np.std(y_pred) 100 | 101 | numerator = 2 * cor * sd_true * sd_pred 102 | 103 | denominator = var_true + var_pred + (mean_true - mean_pred) ** 2 104 | ccc = numerator / (denominator + 1e-8) 105 | return ccc 106 | 107 | 108 | 109 | 110 | 111 | def compute_FRC(args, pred, listener_em, val_test='val'): 112 | pred = pred 113 | listener_em = listener_em 114 | if val_test == 'val': 115 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_val.npy')) 116 | else: 117 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_test.npy')) 118 | 119 | all_FRC_list = [] 120 | for i in range(pred.shape[1]): 121 | FRC_list = [] 122 | for k in range(pred.shape[0]): 123 | speaker_neighbour_index = np.argwhere(speaker_neighbour_matrix[k] == 1).reshape(-1) 124 | speaker_neighbour_index_len = len(speaker_neighbour_index) 125 | ccc_list = [] 126 | for n_index in range(speaker_neighbour_index_len): 127 | 128 | ''' 129 | listener_em order :[listener1, listener2, listener3, ....., listener_n, speaker1, speaker2, speaker3, ....., speaker_n] 130 | listener1: [1, emotion_dim] 131 | 132 | listener_em[speaker_neighbour_index[n_index]]: 133 | 1. speaker_neighbour_index[n_index]: speaker_j (with similar emotion as the speaker_k) 134 | 2. 
listener_em[speaker_neighbour_index[n_index]]: emotion features of listener_j (speaker_j -> listener_j) 135 | So we can get an additional GT listener_j to listener_k (i.e., speaker_j -> listener_k) 136 | ''' 137 | 138 | similar_listener_emotion = listener_em[speaker_neighbour_index[n_index]] 139 | ccc = concordance_correlation_coefficient(similar_listener_emotion.numpy(), pred[k,i].numpy()) 140 | ccc_list.append(ccc) 141 | max_ccc = max(ccc_list) 142 | FRC_list.append(max_ccc) 143 | all_FRC_list.append(np.mean(FRC_list)) 144 | return sum(all_FRC_list) 145 | 146 | 147 | 148 | def _func(k_neighbour_matrix, k_pred, em=None): 149 | neighbour_index = np.argwhere(k_neighbour_matrix == 1).reshape(-1) 150 | neighbour_index_len = len(neighbour_index) 151 | max_ccc_sum = 0 152 | for i in range(k_pred.shape[0]): 153 | ccc_list = [] 154 | for n_index in range(neighbour_index_len): 155 | emotion = em[neighbour_index[n_index]] 156 | ccc = concordance_correlation_coefficient(emotion, k_pred[i]) 157 | ccc_list.append(ccc) 158 | max_ccc_sum += max(ccc_list) 159 | return max_ccc_sum 160 | 161 | 162 | 163 | def compute_FRC_mp(data_dir, pred, em, val_test='val', p=1): 164 | # pred: N 10 750 25 165 | # speaker: N 750 25 166 | 167 | if val_test == 'val': 168 | neighbour_matrix = np.load(os.path.join(data_dir, 'neighbour_emotion_val.npy')) 169 | else: 170 | neighbour_matrix = np.load(os.path.join(data_dir, 'neighbour_emotion_test.npy')) 171 | 172 | FRC_list = [] 173 | with mp.Pool(processes=p) as pool: 174 | # use map 175 | _func_partial = partial(_func, em=em.numpy()) 176 | FRC_list += pool.starmap(_func_partial, zip(neighbour_matrix, pred.numpy())) 177 | return np.mean(FRC_list) 178 | -------------------------------------------------------------------------------- /regnn/metric/FRD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tslearn.metrics import dtw 4 | from functools import partial 5 | import multiprocessing as mp 6 | 7 | 8 | def _func(k_neighbour_matrix, k_pred, em=None): 9 | neighbour_index = np.argwhere(k_neighbour_matrix == 1).reshape(-1) 10 | neighbour_index_len = len(neighbour_index) 11 | min_dwt_sum = 0 12 | for i in range(k_pred.shape[0]): 13 | dwt_list = [] 14 | for n_index in range(neighbour_index_len): 15 | emotion = em[neighbour_index[n_index]] 16 | res = 0 17 | for st, ed, weight in [(0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8)]: 18 | res += weight * dtw(k_pred[i].astype(np.float32)[:, st: ed], emotion.astype(np.float32)[:, st: ed]) 19 | dwt_list.append(res) 20 | min_dwt_sum += min(dwt_list) 21 | return min_dwt_sum 22 | 23 | 24 | def compute_FRD_mp(data_dir, pred, em, val_test='val', p=4): 25 | # pred: N 10 750 25 26 | # speaker: N 750 25 27 | 28 | if val_test == 'val': 29 | neighbour_matrix = np.load(os.path.join(data_dir, 'neighbour_emotion_val.npy')) 30 | else: 31 | neighbour_matrix = np.load(os.path.join(data_dir, 'neighbour_emotion_test.npy')) 32 | 33 | FRD_list = [] 34 | with mp.Pool(processes=p) as pool: 35 | # use map 36 | _func_partial = partial(_func, em=em.numpy()) 37 | FRD_list += pool.starmap(_func_partial, zip(neighbour_matrix, pred.numpy())) 38 | 39 | return np.mean(FRD_list) 40 | 41 | 42 | def compute_FRD(args, pred, listener_em, val_test='val'): 43 | # pred: N 10 750 25 44 | # speaker: N 750 25 45 | 46 | if val_test == 'val': 47 | speaker_neighbour_matrix = np.load(os.path.join(args.dataset_path, 'neighbour_emotion_val.npy')) 48 | else: 49 | speaker_neighbour_matrix = 
np.load(os.path.join(args.dataset_path, 'neighbour_emotion_test.npy')) 50 | all_FRD_list = [] 51 | for i in range(pred.shape[1]): 52 | FRD_list = [] 53 | for k in range(pred.shape[0]): 54 | speaker_neighbour_index = np.argwhere(speaker_neighbour_matrix[k] == 1).reshape(-1) 55 | speaker_neighbour_index_len = len(speaker_neighbour_index) 56 | dwt_list = [] 57 | for n_index in range(speaker_neighbour_index_len): 58 | emotion = listener_em[speaker_neighbour_index[n_index]] 59 | res = 0 60 | for st, ed, weight in [(0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8)]: 61 | res += weight * dtw(pred[k, i].numpy().astype(np.float32)[:, st: ed], 62 | emotion.numpy().astype(np.float32)[:, st: ed]) 63 | dwt_list.append(res) 64 | min_dwt = min(dwt_list) 65 | FRD_list.append(min_dwt) 66 | all_FRD_list.append(np.mean(FRD_list)) 67 | return sum(all_FRD_list) 68 | -------------------------------------------------------------------------------- /regnn/metric/FRDvs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_FRDvs(preds): 5 | # preds: (N, 10, 750, 25) 6 | preds_ = preds.reshape(preds.shape[0], preds.shape[1], -1) 7 | preds_ = preds_.transpose(0, 1) 8 | # preds_: (10, N, 750*25) 9 | dist = torch.pow(torch.cdist(preds_, preds_), 2) 10 | # dist: (10, N, N) 11 | dist = torch.sum(dist) / (preds.shape[0] * (preds.shape[0] - 1) * preds.shape[1]) 12 | return dist / preds_.shape[-1] 13 | -------------------------------------------------------------------------------- /regnn/metric/FRVar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compute_FRVar(preds): 4 | if len(preds.shape) == 3: 5 | # preds: (10, 750, 25) 6 | var = torch.var(preds, dim=1) 7 | return torch.mean(var) 8 | elif len(preds.shape) == 4: 9 | # preds: (N, 10, 750, 25) 10 | var = torch.var(preds, dim=2) 11 | return torch.mean(var) 12 | -------------------------------------------------------------------------------- /regnn/metric/S_MSE.py: -------------------------------------------------------------------------------- 1 | # from scipy.spatial.distance import pdist 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | 7 | def compute_s_mse(preds): 8 | # preds: (B, 10, 750, 25) 9 | dist = 0 10 | for b in range(preds.shape[0]): 11 | preds_item = preds[b] 12 | if preds_item.shape[0] == 1: 13 | return 0.0 14 | preds_item_ = preds_item.reshape(preds_item.shape[0], -1) 15 | dist_ = torch.pow(torch.cdist(preds_item_, preds_item_), 2) 16 | dist_ = torch.sum(dist_) / (preds_item.shape[0] * (preds_item.shape[0] - 1) * preds_item_.shape[1]) 17 | dist += dist_ 18 | return dist / preds.shape[0] -------------------------------------------------------------------------------- /regnn/metric/TLCC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import multiprocessing as mp 4 | 5 | def crosscorr(datax, datay, lag=0, dim=25): 6 | pcc_list = [] 7 | for i in range(dim): 8 | cn_1, cn_2 = shift(datax[:, i], datay[:, i], lag) 9 | pcc_i = np.corrcoef(cn_1, cn_2)[0, 1] 10 | # pcc_i = torch.corrcoef(torch.stack([cn_1, cn_2], dim=0).float())[0, 1] 11 | pcc_list.append(pcc_i.item()) 12 | return torch.mean(torch.Tensor(pcc_list)) 13 | 14 | 15 | def calculate_tlcc(pred, sp, seconds=2, fps=25): 16 | rs = [crosscorr(pred, sp, lag, sp.shape[-1]) for lag in range(-int(seconds * fps - 1), int(seconds * fps))] 17 | peak = max(rs) 18 | center = rs[len(rs) // 2] 19 | 
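    # (added note) `rs` holds cross-correlations for lags -(seconds*fps - 1) .. +(seconds*fps - 1),
    # so its middle element is the zero-lag correlation; `offset` below is the signed distance
    # (in frames) from lag 0 to the correlation peak, and TLCC reports the mean of |offset| --
    # smaller values indicate tighter speaker/listener synchrony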
offset = len(rs) // 2 - torch.argmax(torch.Tensor(rs)) 20 | return peak, center, offset 21 | 22 | def compute_TLCC(pred, speaker): 23 | # pred: N 10 750 25 24 | # speaker: N 750 25 25 | offset_list = [] 26 | for k in range(speaker.shape[0]): 27 | pred_item = pred[k] 28 | sp_item = speaker[k] 29 | for i in range(pred_item.shape[0]): 30 | peak, center, offset = calculate_tlcc(pred_item[i].float().numpy(), sp_item.float().numpy()) 31 | offset_list.append(torch.abs(offset).item()) 32 | return torch.mean(torch.Tensor(offset_list)).item() 33 | 34 | 35 | def _func(pred_item, sp_item): 36 | for i in range(pred_item.shape[0]): 37 | peak, center, offset = calculate_tlcc(pred_item[i], sp_item) 38 | return torch.abs(offset).item() 39 | 40 | def compute_TLCC_mp(pred, speaker, p=8): 41 | # pred: N 10 750 25 42 | # speaker: N 750 25 43 | offset_list = [] 44 | # process each speaker in parallel 45 | np.seterr(divide='ignore', invalid='ignore') 46 | 47 | with mp.Pool(processes=p) as pool: 48 | # use map 49 | offset_list += pool.starmap(_func, zip(pred.float().numpy(), speaker.float().numpy())) 50 | return torch.mean(torch.Tensor(offset_list)).item() 51 | 52 | 53 | def SingleTLCC(pred, speaker): 54 | # pred: 10 750 25 55 | # speaker: 750 25 56 | offset_list = [] 57 | for i in range(pred.shape[0]): 58 | peak, center, offset = calculate_tlcc(pred[i].float(), speaker.float()) 59 | offset_list.append(torch.abs(offset).item()) 60 | return torch.mean(torch.Tensor(offset_list)).item() 61 | 62 | 63 | def shift(x, y, lag): 64 | if lag > 0: 65 | return x[lag:], y[:-lag] 66 | elif lag < 0: 67 | return x[:lag], y[-lag:] 68 | else: 69 | return x, y 70 | -------------------------------------------------------------------------------- /regnn/metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .FRC import compute_FRC_mp, compute_FRC 2 | from .FRD import compute_FRD_mp, compute_FRD 3 | from .FRDvs import compute_FRDvs 4 | from .FRVar import compute_FRVar 5 | from .S_MSE import compute_s_mse 6 | from .TLCC import compute_TLCC, compute_TLCC_mp 7 | -------------------------------------------------------------------------------- /regnn/metric/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def s_mse(preds): 5 | # preds: (B, 10, 750, 25) 6 | dist = 0 7 | for b in range(preds.shape[0]): 8 | preds_item = preds[b] 9 | if preds_item.shape[0] == 1: 10 | return 0.0 11 | preds_item_ = preds_item.reshape(preds_item.shape[0], -1) 12 | dist_ = torch.pow(torch.cdist(preds_item_, preds_item_), 2) 13 | dist_ = torch.sum(dist_) / (preds_item.shape[0] * (preds_item.shape[0] - 1) * preds_item_.shape[1]) 14 | dist += dist_ 15 | return dist / preds.shape[0] 16 | 17 | 18 | def FRVar(preds): 19 | if len(preds.shape) == 3: 20 | # preds: (10, 750, 25) 21 | var = torch.var(preds, dim=1) 22 | return torch.mean(var) 23 | elif len(preds.shape) == 4: 24 | # preds: (N, 10, 750, 25) 25 | var = torch.var(preds, dim=2) 26 | return torch.mean(var) 27 | 28 | 29 | 30 | def FRDvs(preds): 31 | # preds: (N, 10, 750, 25) 32 | preds_ = preds.reshape(preds.shape[0], preds.shape[1], -1) 33 | preds_ = preds_.transpose(0, 1) 34 | # preds_: (10, N, 750*25) 35 | dist = torch.pow(torch.cdist(preds_, preds_), 2) 36 | # dist: (10, N, N) 37 | dist = torch.sum(dist) / (preds.shape[0] * (preds.shape[0] - 1) * preds.shape[1]) 38 | return dist / preds_.shape[-1] 39 | 40 | 41 | from scipy.spatial.distance import pdist 42 | import 
numpy as np 43 | import torch 44 | import os 45 | 46 | def compute_FRVar(pred): 47 | FRVar_list = [] 48 | for k in range(pred.shape[0]): 49 | pred_item = pred[k] 50 | for i in range(0, pred_item.shape[0]): 51 | var = np.mean(np.var(pred_item[i].numpy().astype(np.float32), axis=0)) 52 | FRVar_list.append(var) 53 | return np.mean(FRVar_list) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /regnn/models/MULTModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from .multimodel.transformer import TransformerEncoder 6 | 7 | class MULTModel(nn.Module): 8 | def __init__(self): 9 | """ 10 | Construct a MulT model. 11 | """ 12 | super(MULTModel, self).__init__() 13 | self.orig_d_a, self.orig_d_v =128, 768 14 | self.d_a, self.d_v = 128, 128 15 | self.vonly = True 16 | self.num_heads = 4 17 | self.layers = 5 18 | self.attn_dropout = 0.1 19 | self.attn_dropout_a = 0.0 20 | self.attn_dropout_v = 0.0 21 | self.relu_dropout = 0.1 22 | self.res_dropout = 0.1 23 | self.out_dropout = 0.0 24 | self.embed_dropout = 0.25 25 | self.attn_mask = False 26 | 27 | combined_dim = self.d_v # assuming d_l == d_a == d_v 28 | 29 | output_dim = 64 # This is actually not a hyperparameter :-) 30 | 31 | # 1. Temporal convolutional layers 32 | self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=1, padding=0, bias=False) 33 | self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=1, padding=0, bias=False) 34 | 35 | # 2. Crossmodal Attentions 36 | if self.vonly: 37 | self.trans_v_with_a = self.get_network(self_type='va') 38 | 39 | # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.) 40 | # [e.g., self.trans_x_mem = nn.LSTM(self.d_x, self.d_x, 1) 41 | # self.trans_a_mem = self.get_network(self_type='a_mem', layers=3) 42 | self.trans_v_mem = self.get_network(self_type='av', layers=3) 43 | 44 | # Projection layers 45 | self.proj1 = nn.Linear(combined_dim, combined_dim) 46 | self.proj2 = nn.Linear(combined_dim, combined_dim) 47 | self.out_layer = nn.Linear(combined_dim, output_dim) 48 | 49 | def get_network(self, self_type='l', layers=-1): 50 | if self_type in ['l', 'al', 'vl']: 51 | embed_dim, attn_dropout = self.d_l, self.attn_dropout 52 | elif self_type in ['a', 'la', 'va']: 53 | embed_dim, attn_dropout = self.d_a, self.attn_dropout_a 54 | elif self_type in ['v', 'lv', 'av']: 55 | embed_dim, attn_dropout = self.d_v, self.attn_dropout_v 56 | elif self_type == 'l_mem': 57 | embed_dim, attn_dropout = 2 * self.d_l, self.attn_dropout 58 | elif self_type == 'a_mem': 59 | embed_dim, attn_dropout = 2 * self.d_a, self.attn_dropout 60 | elif self_type == 'v_mem': 61 | embed_dim, attn_dropout = 2 * self.d_v, self.attn_dropout 62 | else: 63 | raise ValueError("Unknown network type") 64 | 65 | return TransformerEncoder(embed_dim=embed_dim, 66 | num_heads=self.num_heads, 67 | layers=max(self.layers, layers), 68 | attn_dropout=attn_dropout, 69 | relu_dropout=self.relu_dropout, 70 | res_dropout=self.res_dropout, 71 | embed_dropout=self.embed_dropout, 72 | attn_mask=self.attn_mask) 73 | 74 | def forward(self, x_v, x_a): 75 | """ 76 | text, audio, and vision should have dimension [batch_size, seq_len, n_features] 77 | """ 78 | # B N C 79 | x_a = x_a.transpose(1, 2) 80 | x_v = x_v.transpose(1, 2) 81 | 82 | # Project the textual/visual/audio features 83 | proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a) 84 | proj_x_v = x_v if 
self.orig_d_v == self.d_v else self.proj_v(x_v) 85 | proj_x_a = proj_x_a.permute(2, 0, 1) 86 | proj_x_v = proj_x_v.permute(2, 0, 1) 87 | # N B C 88 | if self.vonly: 89 | # (L,A) --> V 90 | h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a) 91 | h_vs = h_v_with_as 92 | h_vs = self.trans_v_mem(h_vs) 93 | if type(h_vs) == tuple: 94 | h_vs = h_vs[0] 95 | last_hs = h_vs 96 | 97 | # last_hs = torch.cat([proj_x_v, last_h_v], dim=1) 98 | 99 | # A residual block 100 | last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training)) 101 | last_hs_proj += last_hs 102 | 103 | output = self.out_layer(last_hs_proj) 104 | return output 105 | 106 | -------------------------------------------------------------------------------- /regnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .LipschitzGraph import LipschitzGraph 3 | from .swin_transformer import SwinTransformer 4 | from .torchvggish.vggish import VGGish 5 | from .cognitive import CognitiveProcessor 6 | from .perceptual import PercepProcessor 7 | from .mhp import MHP 8 | 9 | from .MULTModel import MULTModel -------------------------------------------------------------------------------- /regnn/models/cognitive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from .distribution import GmmDistribution, MultiGaussian, Gaussian 5 | from .model_utils import Mlp, MultiNodeMlp 6 | 7 | 8 | class EdgeLayer(nn.Module): 9 | def __init__(self, dim, mlp_before=False, n_channels=8, bias=False, neighbors=5): 10 | super().__init__() 11 | self.n_channels = n_channels 12 | self.neighbors = neighbors 13 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 14 | self.scale = dim ** -0.5 15 | if not mlp_before: 16 | self.qk = nn.Linear(dim, dim * 2 * self.n_channels, bias=bias) 17 | else: 18 | self.qk = nn.Sequential(Mlp(dim, dim), 19 | nn.Linear(dim, dim * 2 * self.n_channels, bias=bias) 20 | ) 21 | 22 | def forward(self, x): 23 | B, N, C = x.shape 24 | # B, N, 2, P, C 25 | # | 26 | # 2, B, P, N, C 27 | qk = self.qk(x).reshape(B, N, 2, self.n_channels, C).permute(2, 0, 3, 1, 4) 28 | q, k = qk[0], qk[1] 29 | # B, P, N, N 30 | attn = (q @ k.transpose(-2, -1)) * self.scale 31 | attn = attn.softmax(dim=-1) 32 | edge = self.clip(attn) 33 | 34 | return edge 35 | 36 | def clip(self, edge, norm=True): 37 | _edge = edge.detach().clone() 38 | sum_edge = torch.sum(_edge, dim=1) 39 | value, index = torch.topk(sum_edge, self.neighbors, dim=-1) 40 | masks = [] 41 | for i in range(index.shape[0]): 42 | inds = index[i] 43 | mask = torch.eye(sum_edge.shape[-1]) 44 | for j, ind in enumerate(inds): 45 | for ind_ in inds: 46 | mask[j][ind_] = 1. 
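            # (added note) `mask` starts as an identity matrix and is filled with 1s at the
            # columns chosen by torch.topk above; stacked over the batch and broadcast over
            # the attention channels, it zeroes out all remaining edges before norm_edge()
            # re-normalises the clipped graph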
47 | 48 | masks.append(mask) 49 | 50 | masks = torch.stack(masks, dim=0).to(edge.device) 51 | 52 | new_edge = masks.unsqueeze(1) * edge 53 | if norm: 54 | new_edge = self.norm_edge(new_edge) 55 | return new_edge 56 | 57 | def norm_edge(self, edge): 58 | # edge B P N N 59 | norm_row_edge = edge / (torch.sum(edge, dim=-1, keepdims=True) + 1e-16) 60 | norm_col_edge = norm_row_edge / (torch.sum(norm_row_edge, dim=-2, keepdims=True) + 1e-16) 61 | normed_edge = torch.matmul(norm_row_edge, norm_col_edge.transpose(-1, -2)) 62 | return normed_edge 63 | 64 | 65 | class CognitiveProcessor(nn.Module): 66 | def __init__(self, n_elements=25, input_dim=64, convert_type='indirect', mlp_in_edge=True, n_channels=8, 67 | num_features=50, k=6, multi_var=False, differ_muk=False, muk_func=False, dis_type='Gmm'): 68 | super().__init__() 69 | if convert_type == 'indirect': 70 | self.convert_layer = MultiNodeMlp(in_dim=input_dim, out_dim=n_elements, n_nodes=50) 71 | # for i in range(n_elements): 72 | # FC_layers = nn.Sequential(nn.Linear(input_dim, input_dim // 2), 73 | # nn.Linear(input_dim // 2, 1) 74 | # ) 75 | # self.__setattr__(name='fc_' + str(i), value=FC_layers) 76 | elif convert_type == 'direct': 77 | self.convert_layer = Mlp(in_features=input_dim, hidden_features=input_dim // 2, out_features=n_elements) 78 | elif convert_type == 'none': 79 | self.convert_layer = nn.Identity() 80 | else: 81 | raise KeyError('Invalid Convert Type') 82 | 83 | self.dis_type = dis_type 84 | if dis_type == 'Gmm': 85 | self.dll = GmmDistribution(frame_size=num_features, k=k, multi_var=multi_var, 86 | differ_muk=differ_muk, muk_func=muk_func) 87 | elif dis_type == 'MultiGaussian': 88 | self.dll = MultiGaussian(frame_size=num_features, n_elements=n_elements) 89 | 90 | elif dis_type == 'Gaussian': 91 | self.dll = Gaussian(frame_size=50, n_elements=n_elements) 92 | 93 | else: 94 | self.dll = None 95 | self.EdgeLayer = EdgeLayer(dim=num_features, mlp_before=mlp_in_edge, n_channels=n_channels) 96 | 97 | def forward(self, inputs): 98 | # inputs: B C T 99 | inputs = self.convert_layer(inputs) 100 | outputs = inputs.transpose(-1, -2) 101 | # print('OUTPUTS', outputs.shape) 102 | edge = self.EdgeLayer(outputs) 103 | if not self.dll is None: 104 | params = self.dll(outputs) 105 | else: 106 | params = 0.0 107 | return outputs, edge, params -------------------------------------------------------------------------------- /regnn/models/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class Mlp(nn.Module): 6 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 7 | super().__init__() 8 | out_features = out_features or in_features 9 | hidden_features = hidden_features or in_features 10 | self.fc1 = nn.Linear(in_features, hidden_features) 11 | self.act = act_layer() 12 | self.fc2 = nn.Linear(hidden_features, out_features) 13 | self.drop = nn.Dropout(drop) 14 | 15 | def forward(self, x): 16 | x = self.fc1(x) 17 | x = self.act(x) 18 | x = self.drop(x) 19 | x = self.fc2(x) 20 | x = self.drop(x) 21 | return x 22 | 23 | class GmmDistribution(nn.Module): 24 | def __init__(self, frame_size, k=6, multi_var=False, differ_muk=False, muk_func=False): 25 | super().__init__() 26 | self.k = k 27 | self.frame_size = frame_size 28 | self.multi_var = multi_var 29 | self.differ_muk = differ_muk 30 | self.muk_func = muk_func 31 | for i in range(k): 32 | self.__setattr__(name='mean_' + str(i), 
value=Mlp(in_features=frame_size)) 33 | if multi_var: 34 | self.__setattr__(name='var_' + str(i), value=Mlp(in_features=frame_size)) 35 | 36 | if differ_muk: 37 | if not muk_func: 38 | self.muk = nn.Parameter(torch.ones((k, 1), requires_grad=True)) 39 | else: 40 | self.muk_func = nn.Sequential( 41 | Mlp(in_features=frame_size, hidden_features=frame_size // 2, out_features=1), 42 | nn.Softmax(dim=-1) 43 | ) 44 | 45 | def forward(self, inputs): 46 | means = [] 47 | if self.multi_var: 48 | vars = [] 49 | 50 | for i in range(self.k): 51 | mean_i_func = self.__getattr__(name='mean_' + str(i)) 52 | mean_i = mean_i_func(inputs) 53 | means.append(mean_i) 54 | if self.multi_var: 55 | var_i_func = self.__getattr__(name='var_' + str(i)) 56 | var_i = var_i_func(inputs) 57 | vars.append(var_i) 58 | 59 | mean = torch.stack(means, dim=0) # k 25 frame_size 60 | if self.multi_var: 61 | var = torch.stack(vars, dim=0) 62 | else: 63 | var = torch.ones((self.k, 25, self.frame_size)) 64 | 65 | if self.differ_muk: 66 | if self.muk_func: 67 | muk = self.muk_func(mean) 68 | else: 69 | muk = self.muk / self.muk.sum() 70 | else: 71 | muk = 1 / self.k 72 | 73 | return mean, var, muk 74 | 75 | class MultiGaussian(nn.Module): 76 | def __init__(self, frame_size, n_elements): 77 | super().__init__() 78 | self.frame_size = frame_size 79 | self.n_elements = n_elements 80 | for i in range(n_elements): 81 | self.__setattr__(name='fc_' + str(i), value=Mlp(in_features=frame_size)) 82 | 83 | self.var_func = Mlp(in_features=2 * n_elements, hidden_features=n_elements, out_features=1) 84 | 85 | def forward(self, inputs): 86 | # inputs: (B, 25, frame_size) 87 | means = [] 88 | for i in range(self.n_elements): 89 | fc_i = self.__getattr__(name='fc_' + str(i)) 90 | mean_i = fc_i(inputs[i] if len(inputs.shape) == 2 else inputs[:, i]) 91 | means.append(mean_i) 92 | 93 | means = torch.cat(means, dim=0) 94 | 95 | var = self.var_func(torch.cat((inputs.squeeze(0) if len(inputs.shape) == 3 else inputs, 96 | means), dim=0).transpose(-2, -1) 97 | ) 98 | 99 | # means: (25, frame_size), var: (1, frame_size) 100 | return means, var 101 | 102 | class Gaussian(nn.Module): 103 | def __init__(self, frame_size, n_elements): 104 | super().__init__() 105 | self.frame_size = frame_size 106 | self.n_elements = n_elements 107 | self.mean_func =Mlp(in_features=n_elements * frame_size) 108 | 109 | self.var_func = Mlp(in_features=n_elements * frame_size) 110 | 111 | def forward(self, inputs): 112 | # inputs: (B, 25, frame_size) 113 | inputs = inputs.reshape(-1, self.n_elements * self.frame_size) 114 | mean = self.mean_func(inputs) 115 | var = self.var_func(inputs) 116 | var = torch.eye(self.n_elements * self.frame_size, device='cuda:0') 117 | # mean var : B 25*frame_size 118 | return mean, var -------------------------------------------------------------------------------- /regnn/models/mhp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.compute_distance_fun import compute_distance 4 | 5 | 6 | NUM_SAMPLE=10 7 | class MHP(nn.Module): 8 | def __init__(self, p=None, c=None, m=None, no_inverse=False, dist="MSE", neighbor_pattern='nearest'): 9 | super().__init__() 10 | self.perceptual_processor = p 11 | self.cognitive_processor = c 12 | self.motor_processor = m 13 | self.cal_dist = { 14 | "DTW": compute_distance, 15 | "MSE": nn.functional.mse_loss,} 16 | 17 | self.no_inverse = no_inverse 18 | self.neighbor_pattern = neighbor_pattern 19 | 20 | def forward_features(self, 
video_inputs, audio_inputs): 21 | if video_inputs is None: 22 | fused_features = audio_inputs 23 | elif audio_inputs is None: 24 | fused_features = video_inputs 25 | else: 26 | # B, T, 64 27 | fused_features = self.perceptual_processor(video_inputs, audio_inputs) 28 | # cog_outputs: 4, 25, 50 ---> B, N, T 29 | # edge: [4, 8, 25, 25] it does not change with T 30 | # params: 0.0? 31 | cog_outputs, edge, params = self.cognitive_processor(fused_features) 32 | return cog_outputs, edge, params 33 | 34 | def get_nearest(self, features, edge, targets): 35 | # B, 750, 25 36 | with torch.no_grad(): 37 | if not self.no_inverse: 38 | self.motor_processor.eval() 39 | predictions = self.motor_processor.inverse(features, edge=edge) 40 | else: 41 | predictions = features 42 | B = predictions.shape[0] 43 | nearest_idx = [] 44 | for i in range(B): 45 | if len(targets[1]) == None: 46 | nearest_idx.append(0) 47 | continue 48 | 49 | pred = predictions[i].unsqueeze(0) 50 | pair_targets = targets[i] 51 | min_dist = None 52 | for i, pair_target in enumerate(pair_targets): 53 | if pair_target == None: 54 | continue 55 | dist = self.cal_dist(pred, pair_target.transpose(1, 0).unsqueeze(0)) 56 | if min_dist == None or dist < min_dist: 57 | min_dist = dist 58 | min_inx = i 59 | 60 | nearest_idx.append(min_inx) 61 | 62 | nearest_targets = [targets[i][idx].transpose(1, 0) for i, idx in enumerate(nearest_idx)] 63 | nearest_targets = torch.stack(nearest_targets, dim=0) 64 | return nearest_targets 65 | 66 | def forward(self, video_inputs, audio_inputs, targets, lengthes=None): 67 | # print('-------------------- Forward --------------------') 68 | # speaker_feature: 4, 25, 50 69 | speaker_feature, edge, params = self.forward_features(video_inputs, audio_inputs) 70 | if not self.no_inverse: 71 | if self.neighbor_pattern == 'nearest': 72 | targets = self.get_nearest(speaker_feature, edge, targets) 73 | targets.requires_grad = True 74 | elif self.neighbor_pattern == 'all': 75 | edge = torch.repeat_interleave(edge, repeats=torch.tensor(lengthes, device=edge.device), dim=0) 76 | else: 77 | targets = targets 78 | 79 | self.motor_processor.train() 80 | # Encode all appropriate real facial reactions to a GMGD distribution 81 | # listener_feature: B, N, D 82 | listener_feature, logdets = self.motor_processor(targets, edge) 83 | 84 | return speaker_feature, listener_feature, params, edge, targets, logdets 85 | 86 | else: 87 | # Decode samples to listener appropriate facial reactions 88 | listerer_feature, logdets = self.motor_processor(speaker_feature, edge) 89 | nearest_targets = targets 90 | 91 | return listerer_feature, nearest_targets, logdets 92 | 93 | def inverse(self, video_inputs, audio_inputs, cal_norm, threshold=None): 94 | speaker_feature, edge, params = self.forward_features(video_inputs, audio_inputs) 95 | if not self.no_inverse: 96 | speaker_feature = self.sample(speaker_feature, threshold) 97 | predictions = self.motor_processor.inverse(speaker_feature, edge=edge, cal_norm=cal_norm) 98 | else: 99 | speaker_feature = self.sample(speaker_feature, threshold) 100 | predictions, _ = self.motor_processor(speaker_feature, edge) 101 | return predictions.transpose(2, 1) 102 | 103 | def onlyInverseMotor(self, features, edge): 104 | print('='*50) 105 | print('--------------------Only Inverse--------------------') 106 | with torch.no_grad(): 107 | self.motor_processor.eval() 108 | outputs = self.motor_processor.inverse(features, edge=edge) 109 | 110 | return outputs 111 | 112 | def sample(self, speaker_feature, threshold=None): 
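        # (added comment) draws standard Gaussian noise with the same shape as the speaker
        # feature and adds it; when `threshold` is given, the noise is rescaled so that its
        # largest absolute entry does not exceed sqrt(threshold), bounding how far a sampled
        # reaction can drift from the predicted feature
        #
        # minimal usage sketch (hypothetical tensors, assuming an already-built MHP `model`):
        #   feat = torch.zeros(4, 25, 50)                 # B, N, T speaker features
        #   noisy = model.sample(feat, threshold=0.25)    # max |noise| <= sqrt(0.25) = 0.5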
113 | noise = torch.randn(speaker_feature.shape, device=speaker_feature.device) 114 | if threshold is None: 115 | return speaker_feature + noise 116 | 117 | threshold = torch.sqrt(torch.tensor([threshold], device=speaker_feature.device)) 118 | 119 | max_abs = torch.max(torch.abs(noise)) 120 | if max_abs <= threshold: 121 | return speaker_feature + noise 122 | scale = threshold / max_abs 123 | scaled_noise = noise * scale 124 | return speaker_feature + scaled_noise 125 | 126 | -------------------------------------------------------------------------------- /regnn/models/multimodel/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | # Code adapted from the fairseq repo. 8 | 9 | class MultiheadAttention(nn.Module): 10 | """Multi-headed attention. 11 | See "Attention Is All You Need" for more details. 12 | """ 13 | 14 | def __init__(self, embed_dim, num_heads, attn_dropout=0., 15 | bias=True, add_bias_kv=False, add_zero_attn=False): 16 | super().__init__() 17 | self.embed_dim = embed_dim 18 | self.num_heads = num_heads 19 | self.attn_dropout = attn_dropout 20 | self.head_dim = embed_dim // num_heads 21 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 22 | self.scaling = self.head_dim ** -0.5 23 | 24 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 25 | self.register_parameter('in_proj_bias', None) 26 | if bias: 27 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 28 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 29 | 30 | if add_bias_kv: 31 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 32 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 33 | else: 34 | self.bias_k = self.bias_v = None 35 | 36 | self.add_zero_attn = add_zero_attn 37 | 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | nn.init.xavier_uniform_(self.in_proj_weight) 42 | nn.init.xavier_uniform_(self.out_proj.weight) 43 | if self.in_proj_bias is not None: 44 | nn.init.constant_(self.in_proj_bias, 0.) 45 | nn.init.constant_(self.out_proj.bias, 0.) 46 | if self.bias_k is not None: 47 | nn.init.xavier_normal_(self.bias_k) 48 | if self.bias_v is not None: 49 | nn.init.xavier_normal_(self.bias_v) 50 | 51 | def forward(self, query, key, value, attn_mask=None): 52 | """Input shape: Time x Batch x Channel 53 | Self-attention can be implemented by passing in the same arguments for 54 | query, key and value. Timesteps can be masked by supplying a T x T mask in the 55 | `attn_mask` argument. Padding elements can be excluded from 56 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 57 | batch x src_len, where padding elements are indicated by 1s. 
58 | """ 59 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 60 | kv_same = key.data_ptr() == value.data_ptr() 61 | 62 | tgt_len, bsz, embed_dim = query.size() 63 | assert embed_dim == self.embed_dim 64 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 65 | assert key.size() == value.size() 66 | 67 | aved_state = None 68 | 69 | if qkv_same: 70 | # self-attention 71 | q, k, v = self.in_proj_qkv(query) 72 | elif kv_same: 73 | # encoder-decoder attention 74 | q = self.in_proj_q(query) 75 | 76 | if key is None: 77 | assert value is None 78 | k = v = None 79 | else: 80 | k, v = self.in_proj_kv(key) 81 | else: 82 | q = self.in_proj_q(query) 83 | k = self.in_proj_k(key) 84 | v = self.in_proj_v(value) 85 | q = q * self.scaling 86 | 87 | if self.bias_k is not None: 88 | assert self.bias_v is not None 89 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 90 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 91 | if attn_mask is not None: 92 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 93 | 94 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 95 | if k is not None: 96 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 97 | if v is not None: 98 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 99 | 100 | src_len = k.size(1) 101 | 102 | if self.add_zero_attn: 103 | src_len += 1 104 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 105 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 106 | if attn_mask is not None: 107 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 108 | 109 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 110 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 111 | 112 | if attn_mask is not None: 113 | try: 114 | attn_weights += attn_mask.unsqueeze(0) 115 | except: 116 | print(attn_weights.shape) 117 | print(attn_mask.unsqueeze(0).shape) 118 | assert False 119 | 120 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) 121 | # attn_weights = F.relu(attn_weights) 122 | # attn_weights = attn_weights / torch.max(attn_weights) 123 | attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training) 124 | 125 | attn = torch.bmm(attn_weights, v) 126 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 127 | 128 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 129 | attn = self.out_proj(attn) 130 | 131 | # average attention weights over heads 132 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 133 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 134 | return attn, attn_weights 135 | 136 | def in_proj_qkv(self, query): 137 | return self._in_proj(query).chunk(3, dim=-1) 138 | 139 | def in_proj_kv(self, key): 140 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 141 | 142 | def in_proj_q(self, query, **kwargs): 143 | return self._in_proj(query, end=self.embed_dim, **kwargs) 144 | 145 | def in_proj_k(self, key): 146 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 147 | 148 | def in_proj_v(self, value): 149 | return self._in_proj(value, start=2 * self.embed_dim) 150 | 151 | def _in_proj(self, input, start=0, end=None, **kwargs): 152 | weight = kwargs.get('weight', self.in_proj_weight) 153 | bias = kwargs.get('bias', self.in_proj_bias) 154 | weight = 
weight[start:end, :] 155 | if bias is not None: 156 | bias = bias[start:end] 157 | return F.linear(input, weight, bias) 158 | -------------------------------------------------------------------------------- /regnn/models/multimodel/position_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Code adapted from the fairseq repo. 7 | 8 | def make_positions(tensor, padding_idx, left_pad): 9 | """Replace non-padding symbols with their position numbers. 10 | Position numbers begin at padding_idx+1. 11 | Padding symbols are ignored, but it is necessary to specify whether padding 12 | is added on the left side (left_pad=True) or right side (left_pad=False). 13 | """ 14 | max_pos = padding_idx + 1 + tensor.size(1) 15 | device = tensor.get_device() 16 | buf_name = f'range_buf_{device}' 17 | if not hasattr(make_positions, buf_name): 18 | setattr(make_positions, buf_name, tensor.new()) 19 | setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor)) 20 | if getattr(make_positions, buf_name).numel() < max_pos: 21 | torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name)) 22 | mask = tensor.ne(padding_idx) 23 | positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor) 24 | if left_pad: 25 | positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1) 26 | new_tensor = tensor.clone() 27 | return new_tensor.masked_scatter_(mask, positions[mask]).long() 28 | 29 | 30 | class SinusoidalPositionalEmbedding(nn.Module): 31 | """This module produces sinusoidal positional embeddings of any length. 32 | Padding symbols are ignored, but it is necessary to specify whether padding 33 | is added on the left side (left_pad=True) or right side (left_pad=False). 34 | """ 35 | 36 | def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128): 37 | super().__init__() 38 | self.embedding_dim = embedding_dim 39 | self.padding_idx = padding_idx 40 | self.left_pad = left_pad 41 | self.weights = dict() # device --> actual weight; due to nn.DataParallel :-( 42 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 43 | 44 | @staticmethod 45 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 46 | """Build sinusoidal embeddings. 47 | This matches the implementation in tensor2tensor, but differs slightly 48 | from the description in Section 3.5 of "Attention Is All You Need". 
49 | """ 50 | half_dim = embedding_dim // 2 51 | emb = math.log(10000) / (half_dim - 1) 52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 53 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 55 | if embedding_dim % 2 == 1: 56 | # zero pad 57 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 58 | if padding_idx is not None: 59 | emb[padding_idx, :] = 0 60 | return emb 61 | 62 | def forward(self, input): 63 | """Input is expected to be of size [bsz x seqlen].""" 64 | bsz, seq_len = input.size() 65 | max_pos = self.padding_idx + 1 + seq_len 66 | device = input.get_device() 67 | if device not in self.weights or max_pos > self.weights[device].size(0): 68 | # recompute/expand embeddings if needed 69 | self.weights[device] = SinusoidalPositionalEmbedding.get_embedding( 70 | max_pos, 71 | self.embedding_dim, 72 | self.padding_idx, 73 | ) 74 | self.weights[device] = self.weights[device].type_as(self._float_tensor) 75 | positions = make_positions(input, self.padding_idx, self.left_pad) 76 | return self.weights[device].index_select(0, positions.reshape(-1)).view(bsz, seq_len, -1).detach() 77 | 78 | def max_positions(self): 79 | """Maximum number of supported positions.""" 80 | return int(1e5) # an arbitrary large number -------------------------------------------------------------------------------- /regnn/models/multimodel/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .position_embedding import SinusoidalPositionalEmbedding 5 | from .multihead_attention import MultiheadAttention 6 | import math 7 | 8 | 9 | class TransformerEncoder(nn.Module): 10 | """ 11 | Transformer encoder consisting of *args.encoder_layers* layers. Each layer 12 | is a :class:`TransformerEncoderLayer`. 
13 | Args: 14 | embed_tokens (torch.nn.Embedding): input embedding 15 | num_heads (int): number of heads 16 | layers (int): number of layers 17 | attn_dropout (float): dropout applied on the attention weights 18 | relu_dropout (float): dropout applied on the first layer of the residual block 19 | res_dropout (float): dropout applied on the residual block 20 | attn_mask (bool): whether to apply mask on the attention weights 21 | """ 22 | 23 | def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0, 24 | embed_dropout=0.0, attn_mask=False): 25 | super().__init__() 26 | self.dropout = embed_dropout # Embedding dropout 27 | self.attn_dropout = attn_dropout 28 | self.embed_dim = embed_dim 29 | self.embed_scale = math.sqrt(embed_dim) 30 | self.embed_positions = SinusoidalPositionalEmbedding(embed_dim) 31 | 32 | self.attn_mask = attn_mask 33 | 34 | self.layers = nn.ModuleList([]) 35 | for layer in range(layers): 36 | new_layer = TransformerEncoderLayer(embed_dim, 37 | num_heads=num_heads, 38 | attn_dropout=attn_dropout, 39 | relu_dropout=relu_dropout, 40 | res_dropout=res_dropout, 41 | attn_mask=attn_mask) 42 | self.layers.append(new_layer) 43 | 44 | self.register_buffer('version', torch.Tensor([2])) 45 | self.normalize = True 46 | if self.normalize: 47 | self.layer_norm = LayerNorm(embed_dim) 48 | 49 | def forward(self, x_in, x_in_k = None, x_in_v = None): 50 | """ 51 | Args: 52 | x_in (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)` 53 | x_in_k (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)` 54 | x_in_v (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)` 55 | Returns: 56 | dict: 57 | - **encoder_out** (Tensor): the last encoder layer's output of 58 | shape `(src_len, batch, embed_dim)` 59 | - **encoder_padding_mask** (ByteTensor): the positions of 60 | padding elements of shape `(batch, src_len)` 61 | """ 62 | # embed tokens and positions 63 | x = self.embed_scale * x_in 64 | if self.embed_positions is not None: 65 | x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 66 | x = F.dropout(x, p=self.dropout, training=self.training) 67 | 68 | if x_in_k is not None and x_in_v is not None: 69 | # embed tokens and positions 70 | x_k = self.embed_scale * x_in_k 71 | x_v = self.embed_scale * x_in_v 72 | if self.embed_positions is not None: 73 | x_k += self.embed_positions(x_in_k.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 74 | x_v += self.embed_positions(x_in_v.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 75 | x_k = F.dropout(x_k, p=self.dropout, training=self.training) 76 | x_v = F.dropout(x_v, p=self.dropout, training=self.training) 77 | 78 | # encoder layers 79 | intermediates = [x] 80 | for layer in self.layers: 81 | if x_in_k is not None and x_in_v is not None: 82 | x = layer(x, x_k, x_v) 83 | else: 84 | x = layer(x) 85 | intermediates.append(x) 86 | 87 | if self.normalize: 88 | x = self.layer_norm(x) 89 | 90 | return x 91 | 92 | def max_positions(self): 93 | """Maximum input length supported by the encoder.""" 94 | if self.embed_positions is None: 95 | return self.max_source_positions 96 | return min(self.max_source_positions, self.embed_positions.max_positions()) 97 | 98 | 99 | class TransformerEncoderLayer(nn.Module): 100 | """Encoder layer block. 101 | In the original paper each operation (multi-head attention or FFN) is 102 | postprocessed with: `dropout -> add residual -> layernorm`. 
In the 103 | tensor2tensor code they suggest that learning is more robust when 104 | preprocessing each layer with layernorm and postprocessing with: 105 | `dropout -> add residual`. We default to the approach in the paper, but the 106 | tensor2tensor approach can be enabled by setting 107 | *args.encoder_normalize_before* to ``True``. 108 | Args: 109 | embed_dim: Embedding dimension 110 | """ 111 | 112 | def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1, 113 | attn_mask=False): 114 | super().__init__() 115 | self.embed_dim = embed_dim 116 | self.num_heads = num_heads 117 | 118 | self.self_attn = MultiheadAttention( 119 | embed_dim=self.embed_dim, 120 | num_heads=self.num_heads, 121 | attn_dropout=attn_dropout 122 | ) 123 | self.attn_mask = attn_mask 124 | 125 | self.relu_dropout = relu_dropout 126 | self.res_dropout = res_dropout 127 | self.normalize_before = True 128 | 129 | self.fc1 = Linear(self.embed_dim, 4*self.embed_dim) # The "Add & Norm" part in the paper 130 | self.fc2 = Linear(4*self.embed_dim, self.embed_dim) 131 | self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)]) 132 | 133 | def forward(self, x, x_k=None, x_v=None): 134 | """ 135 | Args: 136 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 137 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape 138 | `(batch, src_len)` where padding elements are indicated by ``1``. 139 | x_k (Tensor): same as x 140 | x_v (Tensor): same as x 141 | Returns: 142 | encoded output of shape `(batch, src_len, embed_dim)` 143 | """ 144 | residual = x 145 | x = self.maybe_layer_norm(0, x, before=True) 146 | mask = buffered_future_mask(x, x_k) if self.attn_mask else None 147 | if x_k is None and x_v is None: 148 | x, _ = self.self_attn(query=x, key=x, value=x, attn_mask=mask) 149 | else: 150 | x_k = self.maybe_layer_norm(0, x_k, before=True) 151 | x_v = self.maybe_layer_norm(0, x_v, before=True) 152 | x, _ = self.self_attn(query=x, key=x_k, value=x_v, attn_mask=mask) 153 | x = F.dropout(x, p=self.res_dropout, training=self.training) 154 | x = residual + x 155 | x = self.maybe_layer_norm(0, x, after=True) 156 | 157 | residual = x 158 | x = self.maybe_layer_norm(1, x, before=True) 159 | x = F.relu(self.fc1(x)) 160 | x = F.dropout(x, p=self.relu_dropout, training=self.training) 161 | x = self.fc2(x) 162 | x = F.dropout(x, p=self.res_dropout, training=self.training) 163 | x = residual + x 164 | x = self.maybe_layer_norm(1, x, after=True) 165 | return x 166 | 167 | def maybe_layer_norm(self, i, x, before=False, after=False): 168 | assert before ^ after 169 | if after ^ self.normalize_before: 170 | return self.layer_norms[i](x) 171 | else: 172 | return x 173 | 174 | def fill_with_neg_inf(t): 175 | """FP16-compatible function that fills a tensor with -inf.""" 176 | return t.float().fill_(float('-inf')).type_as(t) 177 | 178 | 179 | def buffered_future_mask(tensor, tensor2=None): 180 | dim1 = dim2 = tensor.size(0) 181 | if tensor2 is not None: 182 | dim2 = tensor2.size(0) 183 | future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1+abs(dim2-dim1)) 184 | if tensor.is_cuda: 185 | future_mask = future_mask.cuda() 186 | return future_mask[:dim1, :dim2] 187 | 188 | 189 | def Linear(in_features, out_features, bias=True): 190 | m = nn.Linear(in_features, out_features, bias) 191 | nn.init.xavier_uniform_(m.weight) 192 | if bias: 193 | nn.init.constant_(m.bias, 0.) 
194 | return m 195 | 196 | 197 | def LayerNorm(embedding_dim): 198 | m = nn.LayerNorm(embedding_dim) 199 | return m 200 | 201 | 202 | if __name__ == '__main__': 203 | encoder = TransformerEncoder(300, 4, 2) 204 | x = torch.rand(20, 2, 300) 205 | print(encoder(x).shape) 206 | -------------------------------------------------------------------------------- /regnn/models/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from .swin_transformer import SwinTransformer 5 | from .torchvggish.vggish import VGGish 6 | from .MULTModel import MULTModel 7 | # net_to_call(embed_dim = 96, depths = [2, 2, 6, 2], num_heads = [3, 6, 12, 24], window_size = 7, drop_path_rate = 0.2) 8 | 9 | class VideoProcessor(nn.Module): 10 | def __init__(self, base_type='Swin', frame_size=50, pretrained=False): 11 | super().__init__() 12 | self.frame_size = frame_size 13 | if base_type == 'Swin': 14 | self.model = SwinTransformer(embed_dim = 96, depths = [2, 2, 6, 2], num_heads = [3, 6, 12, 24], 15 | window_size = 7, drop_path_rate = 0.2) 16 | else: 17 | raise KeyError('Invalid model type') 18 | if pretrained: 19 | self.model.load_state_dict(torch.load(r"/scratch/recface/hz204/react_data/pretrained/swin_fer.pth", 20 | map_location='cpu')) 21 | # ["model"]) 22 | 23 | def forward(self, inputs): 24 | N, C, H, W = inputs.shape 25 | # print(N, C, H, W) 26 | assert N==self.frame_size 27 | # outputs = self.model(inputs) 28 | outputs = self.model.forward_features(inputs) 29 | assert outputs.shape[0]==N 30 | return outputs 31 | 32 | class AudioProcessor(nn.Module): 33 | def __init__(self, base_type='VGGish', ): 34 | super().__init__() 35 | if base_type == 'VGGish': 36 | # If the input is a tensor, use preprocess=False; if it is a numpy array or a file path (str), use preprocess=True 37 | self.model = VGGish(preprocess=False) 38 | else: 39 | raise KeyError('Invalid model type') 40 | 41 | def forward(self, file_name): 42 | outputs = self.model(file_name) 43 | return outputs 44 | 45 | 46 | class PercepProcessor(nn.Module): 47 | def __init__(self, v_type='Swin', a_type='VGGish', only_fuse=False): 48 | super().__init__() 49 | if not only_fuse: 50 | self.v_processor = VideoProcessor(base_type=v_type) 51 | self.a_processor = AudioProcessor(base_type=a_type) 52 | self.fuse_model = MULTModel() 53 | self.only_fuse = only_fuse 54 | 55 | # B T D 56 | # video_inputs: B, T, 768 57 | # audio_inputs: B, T, 128 58 | # outputs: B, T, 64 59 | def forward(self, video_inputs, audio_inputs): 60 | if not self.only_fuse: 61 | video_inputs = self.v_processor(video_inputs) 62 | audio_inputs = self.a_processor(audio_inputs) 63 | video_inputs = video_inputs.unsqueeze(0) 64 | audio_inputs = audio_inputs.unsqueeze(0) 65 | 66 | # print('*'*50) 67 | # print('V-SEQ:', v_seq.shape) 68 | # print('A-SEQ:', a_seq.shape) 69 | # print('*' * 50) 70 | # print('video:', video_inputs.shape) 71 | # print('audio:', audio_inputs.shape) 72 | # fused_feature: T, B, 64 73 | fused_feature = self.fuse_model(video_inputs, audio_inputs) 74 | 75 | return fused_feature.transpose(1, 0) 76 | 77 | 78 | -------------------------------------------------------------------------------- /regnn/models/torchvggish/vggish.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch import hub 5 | 6 | from . 
import vggish_input, vggish_params 7 | 8 | 9 | class VGG(nn.Module): 10 | def __init__(self, features): 11 | super(VGG, self).__init__() 12 | self.features = features 13 | self.embeddings = nn.Sequential( 14 | nn.Linear(512 * 4 * 6, 4096), 15 | nn.ReLU(True), 16 | nn.Linear(4096, 4096), 17 | nn.ReLU(True), 18 | nn.Linear(4096, 128), 19 | nn.ReLU(True)) 20 | 21 | def forward(self, x): 22 | x = self.features(x) 23 | 24 | # Transpose the output from features to 25 | # remain compatible with vggish embeddings 26 | x = torch.transpose(x, 1, 3) 27 | x = torch.transpose(x, 1, 2) 28 | x = x.contiguous() 29 | x = x.view(x.size(0), -1) 30 | 31 | return self.embeddings(x) 32 | 33 | 34 | class Postprocessor(nn.Module): 35 | """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a 36 | numpy array in order to preserve the gradient. 37 | 38 | "The initial release of AudioSet included 128-D VGGish embeddings for each 39 | segment of AudioSet. These released embeddings were produced by applying 40 | a PCA transformation (technically, a whitening transform is included as well) 41 | and 8-bit quantization to the raw embedding output from VGGish, in order to 42 | stay compatible with the YouTube-8M project which provides visual embeddings 43 | in the same format for a large set of YouTube videos. This class implements 44 | the same PCA (with whitening) and quantization transformations." 45 | """ 46 | 47 | def __init__(self): 48 | """Constructs a postprocessor.""" 49 | super(Postprocessor, self).__init__() 50 | # Create empty matrix, for user's state_dict to load 51 | self.pca_eigen_vectors = torch.empty( 52 | (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), 53 | dtype=torch.float, 54 | ) 55 | self.pca_means = torch.empty( 56 | (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float 57 | ) 58 | 59 | self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) 60 | self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) 61 | 62 | def postprocess(self, embeddings_batch): 63 | """Applies tensor postprocessing to a batch of embeddings. 64 | 65 | Args: 66 | embeddings_batch: An tensor of shape [batch_size, embedding_size] 67 | containing output from the embedding layer of VGGish. 68 | 69 | Returns: 70 | A tensor of the same shape as the input, containing the PCA-transformed, 71 | quantized, and clipped version of the input. 72 | """ 73 | assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( 74 | embeddings_batch.shape, 75 | ) 76 | assert ( 77 | embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE 78 | ), "Bad batch shape: %r" % (embeddings_batch.shape,) 79 | 80 | # Apply PCA. 81 | # - Embeddings come in as [batch_size, embedding_size]. 82 | # - Transpose to [embedding_size, batch_size]. 83 | # - Subtract pca_means column vector from each column. 84 | # - Premultiply by PCA matrix of shape [output_dims, input_dims] 85 | # where both are are equal to embedding_size in our case. 86 | # - Transpose result back to [batch_size, embedding_size]. 
87 | pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() 88 | 89 | # Quantize by: 90 | # - clipping to [min, max] range 91 | clipped_embeddings = torch.clamp( 92 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL 93 | ) 94 | # - convert to 8-bit in range [0.0, 255.0] 95 | quantized_embeddings = torch.round( 96 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) 97 | * ( 98 | 255.0 99 | / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) 100 | ) 101 | ) 102 | return torch.squeeze(quantized_embeddings) 103 | 104 | def forward(self, x): 105 | return self.postprocess(x) 106 | 107 | 108 | def make_layers(): 109 | layers = [] 110 | in_channels = 1 111 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: 112 | if v == "M": 113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 114 | else: 115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 116 | layers += [conv2d, nn.ReLU(inplace=True)] 117 | in_channels = v 118 | return nn.Sequential(*layers) 119 | 120 | 121 | def _vgg(): 122 | return VGG(make_layers()) 123 | 124 | 125 | # def _spectrogram(): 126 | # config = dict( 127 | # sr=16000, 128 | # n_fft=400, 129 | # n_mels=64, 130 | # hop_length=160, 131 | # window="hann", 132 | # center=False, 133 | # pad_mode="reflect", 134 | # htk=True, 135 | # fmin=125, 136 | # fmax=7500, 137 | # output_format='Magnitude', 138 | # # device=device, 139 | # ) 140 | # return Spectrogram.MelSpectrogram(**config) 141 | 142 | model_urls = { 143 | 'vggish': 'https://github.com/harritaylor/torchvggish/' 144 | 'releases/download/v0.1/vggish-10086976.pth', 145 | 'pca': 'https://github.com/harritaylor/torchvggish/' 146 | 'releases/download/v0.1/vggish_pca_params-970ea276.pth' 147 | } 148 | 149 | class VGGish(VGG): 150 | def __init__(self, urls=model_urls, device=None, pretrained=True, preprocess=False, postprocess=True, progress=True): 151 | super().__init__(make_layers()) 152 | if pretrained: 153 | state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress) 154 | super().load_state_dict(state_dict) 155 | 156 | if device is None: 157 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 158 | self.device = device 159 | self.preprocess = preprocess 160 | self.postprocess = postprocess 161 | if self.postprocess: 162 | self.pproc = Postprocessor() 163 | if pretrained: 164 | state_dict = hub.load_state_dict_from_url(urls['pca'], progress=progress) 165 | # TODO: Convert the state_dict to torch 166 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( 167 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float 168 | ) 169 | state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( 170 | state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float 171 | ) 172 | 173 | self.pproc.load_state_dict(state_dict) 174 | self.to(self.device) 175 | 176 | def forward(self, x, fs=None): 177 | if self.preprocess: 178 | x = self._preprocess(x, fs) 179 | x = x.to(self.device) 180 | x = VGG.forward(self, x) 181 | if self.postprocess: 182 | x = self._postprocess(x) 183 | return x 184 | 185 | def _preprocess(self, x, fs): 186 | if isinstance(x, np.ndarray): 187 | x = vggish_input.waveform_to_examples(x, fs) 188 | elif isinstance(x, str): 189 | x = vggish_input.wavfile_to_examples(x) 190 | else: 191 | raise AttributeError 192 | return x 193 | 194 | def _postprocess(self, x): 195 | return self.pproc(x) 196 | 
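A minimal usage sketch for the VGGish wrapper defined above (illustrative only, not part of vggish.py); the waveform length, sample rate, batch size, and import path are assumptions, and the pretrained weights are fetched via torch.hub on first use:

import numpy as np
import torch
from regnn.models.torchvggish.vggish import VGGish  # assumed import path, run from the repository root

# Mode 1: raw audio in, with on-the-fly preprocessing (numpy waveform plus its sample rate).
model = VGGish(preprocess=True, postprocess=False)
waveform = np.random.uniform(-1.0, 1.0, 24000 * 3)      # assumed 3 s of mono audio at 24 kHz
embeddings = model(waveform, fs=24000)                   # -> [num_examples, 128]

# Mode 2: precomputed log-mel patches in (how AudioProcessor in perceptual.py drives it);
# patches of 96 frames x 64 bands match the 512 * 4 * 6 input of the embedding head.
patches = torch.randn(4, 1, 96, 64)
embeddings = VGGish(preprocess=False, postprocess=False)(patches)   # -> [4, 128]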
-------------------------------------------------------------------------------- /regnn/models/torchvggish/vggish_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute input examples for VGGish from audio waveform.""" 17 | 18 | # Modification: Return torch tensors rather than numpy arrays 19 | import torch 20 | 21 | import numpy as np 22 | import resampy 23 | 24 | from . import mel_features 25 | from . import vggish_params 26 | 27 | import soundfile as sf 28 | 29 | 30 | def waveform_to_examples(data, sample_rate, return_tensor=True): 31 | """Converts audio waveform into an array of examples for VGGish. 32 | 33 | Args: 34 | data: np.array of either one dimension (mono) or two dimensions 35 | (multi-channel, with the outer dimension representing channels). 36 | Each sample is generally expected to lie in the range [-1.0, +1.0], 37 | although this is not required. 38 | sample_rate: Sample rate of data. 39 | return_tensor: Return data as a Pytorch tensor ready for VGGish 40 | 41 | Returns: 42 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents 43 | a sequence of examples, each of which contains a patch of log mel 44 | spectrogram, covering num_frames frames of audio and num_bands mel frequency 45 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. 46 | 47 | """ 48 | # Convert to mono. 49 | if len(data.shape) > 1: 50 | data = np.mean(data, axis=1) 51 | # Resample to the rate assumed by VGGish. 52 | if sample_rate != vggish_params.SAMPLE_RATE: 53 | # print('sample_rate:', sample_rate) 54 | # print('SAMPLE_RATE:', vggish_params.SAMPLE_RATE) 55 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) 56 | 57 | # Compute log mel spectrogram features. 58 | log_mel = mel_features.log_mel_spectrogram( 59 | data, 60 | audio_sample_rate=vggish_params.SAMPLE_RATE, 61 | log_offset=vggish_params.LOG_OFFSET, 62 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, 63 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, 64 | num_mel_bins=vggish_params.NUM_MEL_BINS, 65 | lower_edge_hertz=vggish_params.MEL_MIN_HZ, 66 | upper_edge_hertz=vggish_params.MEL_MAX_HZ) 67 | 68 | # Frame features into examples. 
69 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS 70 | example_window_length = int(round( 71 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) 72 | example_hop_length = int(round( 73 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) 74 | log_mel_examples = mel_features.frame( 75 | log_mel, 76 | window_length=example_window_length, 77 | hop_length=example_hop_length) 78 | 79 | if return_tensor: 80 | log_mel_examples = torch.tensor( 81 | log_mel_examples, requires_grad=True)[:, None, :, :].float() 82 | 83 | return log_mel_examples 84 | 85 | 86 | def wavfile_to_examples(wav_file, return_tensor=True): 87 | """Convenience wrapper around waveform_to_examples() for a common WAV format. 88 | 89 | Args: 90 | wav_file: String path to a file, or a file-like object. The file 91 | is assumed to contain WAV audio data with signed 16-bit PCM samples. 92 | return_tensor: Return data as a Pytorch tensor ready for VGGish 93 | 94 | Returns: 95 | See waveform_to_examples. 96 | """ 97 | wav_data, sr = sf.read(wav_file, dtype='int16') 98 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype 99 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] 100 | return waveform_to_examples(samples, sr, return_tensor) 101 | -------------------------------------------------------------------------------- /regnn/models/torchvggish/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | # Architectural constants. 22 | NUM_FRAMES = 750 # Frames in input mel-spectrogram patch. 23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 24 | EMBEDDING_SIZE = 128 # Size of embedding layer. 25 | 26 | # Hyperparameters used in feature and example generation. 27 | SAMPLE_RATE = 24000 28 | # STFT window length and hop length, in seconds. 29 | STFT_WINDOW_LENGTH_SECONDS = 0.025 30 | STFT_HOP_LENGTH_SECONDS = 0.010 / 24 31 | NUM_MEL_BINS = NUM_BANDS 32 | MEL_MIN_HZ = 125 33 | MEL_MAX_HZ = 7500 34 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 35 | EXAMPLE_WINDOW_SECONDS = 0.04 # Each example covers 0.04 s (96 frames at the hop length above) 36 | EXAMPLE_HOP_SECONDS = 0.04 # with zero overlap. 37 | 38 | # Parameters used for embedding postprocessing. 39 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 40 | PCA_MEANS_NAME = 'pca_means' 41 | QUANTIZE_MIN_VAL = -2.0 42 | QUANTIZE_MAX_VAL = +2.0 43 | 44 | # Hyperparameters used in training. 45 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 46 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 47 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 48 | 49 | # Names of ops, tensors, and features. 
50 | INPUT_OP_NAME = 'vggish/input_features' 51 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 52 | OUTPUT_OP_NAME = 'vggish/embedding' 53 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 54 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 55 | -------------------------------------------------------------------------------- /regnn/scripts/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python train.py \ 4 | --test \ 5 | --model-pth "Gmm-logs/mhp-epoch95-seed1.pth" \ 6 | --neighbor-pattern 'nearest' \ 7 | --seed 1 -------------------------------------------------------------------------------- /regnn/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python train.py \ 4 | --logs-dir 'Gmm-logs' \ 5 | --lr 0.0001 \ 6 | --gamma 0.1 \ 7 | --warmup-factor 0.01 \ 8 | --milestones 9 \ 9 | --batch-size 128 \ 10 | --layers 2 \ 11 | --act 'ELU' \ 12 | --seed 1 \ 13 | --train-iters 100 \ 14 | --norm \ 15 | --neighbor-pattern 'all' \ 16 | --convert-type 'direct' \ 17 | --loss-mid -------------------------------------------------------------------------------- /regnn/tool/matrix_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | import os 5 | 6 | 7 | def parse_arg(): 8 | parser = argparse.ArgumentParser(description='PyTorch Training') 9 | parser.add_argument('--dataset-path', default="./data", type=str, help="dataset path") 10 | parser.add_argument('--partition', default="val", type=str, help="dataset partition") 11 | args = parser.parse_args() 12 | return args 13 | 14 | args = parse_arg() 15 | root_path = args.dataset_path 16 | data_path = pd.read_csv(os.path.join(root_path, 'data_indices.csv'), header=None, delimiter=',') 17 | data_path = data_path.drop(0) 18 | all_matrix = np.load(os.path.join(root_path, 'Approprirate_facial_reaction.npy')) 19 | 20 | all_matrix_w = all_matrix.shape[0] 21 | print("all_matrix shape:", all_matrix.shape) 22 | 23 | val_path = pd.read_csv(os.path.join(root_path, str(args.partition)+'.csv'), header=None, delimiter=',') 24 | val_path = val_path.drop(0) 25 | 26 | index_list = [] 27 | _sum = 0 28 | for index, row in val_path.iterrows(): 29 | _sum += 1 30 | flag = 0 31 | items = row[1].split('/') 32 | if len(items)>4: 33 | continue # Skip the UDIVA 34 | else: 35 | path_val_item = os.path.join(items[0].upper(),items[1],items[3]) 36 | for index_2, row2 in data_path.iterrows(): 37 | path_data_index = os.path.join(row2[0].upper(), row2[1], row2[2]) 38 | if path_val_item == path_data_index: 39 | index_list.append(index_2-1) 40 | flag = 1 41 | 42 | l_m = len(index_list) 43 | new_matrix = np.zeros((l_m*2, l_m*2)) 44 | for index, item in enumerate(index_list): 45 | for j, item_j in enumerate(index_list): 46 | new_matrix[index, j] = all_matrix[item, item_j] 47 | new_matrix[index, j+l_m] = all_matrix[item, item_j + all_matrix_w//2] 48 | new_matrix[index+l_m, j] = all_matrix[item + all_matrix_w//2, item_j] 49 | new_matrix[index+l_m, j+l_m] = all_matrix[item + all_matrix_w//2, item_j + all_matrix_w//2] 50 | 51 | 52 | np.save(os.path.join(root_path, 'neighbour_emotion_' + str(args.partition) + '.npy'), new_matrix) 53 | print(new_matrix.shape) -------------------------------------------------------------------------------- /regnn/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
print_function, absolute_import 2 | import os 3 | import sys 4 | import json 5 | import torch 6 | import random 7 | import argparse 8 | import numpy as np 9 | import os.path as osp 10 | import pandas as pd 11 | from trainers import Trainer 12 | from datasets import ActionData 13 | from utils.logging import Logger 14 | from torch.backends import cudnn 15 | from utils.meters import AverageMeter 16 | from torch.utils.data import DataLoader 17 | from utils.lr_scheduler import WarmupMultiStepLR 18 | from models import CognitiveProcessor, PercepProcessor, MHP, LipschitzGraph 19 | 20 | def set_seed(seed): 21 | if seed == 0: 22 | return 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | cudnn.deterministic = True 27 | 28 | def train(args): 29 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 30 | print("==========\nArgs:{}\n==========".format(args)) 31 | 32 | set_seed(args.seed) 33 | 34 | num_frames = args.num_frames # 50 35 | stride = args.stride # 25 36 | edge_dim = args.edge_dim # 8 37 | num_neighbor = args.neighbors # 6 38 | 39 | Cog = CognitiveProcessor(input_dim=64, convert_type=args.convert_type, num_features=num_frames, 40 | n_channels=edge_dim, k=num_neighbor) 41 | Per = PercepProcessor(only_fuse=True) 42 | Mot = LipschitzGraph(edge_channel=edge_dim, n_layers=args.layers, act_type=args.act, 43 | num_features=num_frames, norm=args.norm, get_logdets=args.get_logdets) 44 | model = MHP(p=Per, c=Cog, m=Mot, no_inverse=args.no_inverse, neighbor_pattern=args.neighbor_pattern) 45 | model = model.cuda() 46 | 47 | train_path = pd.read_csv(os.path.join(args.data_dir, 'train.csv'), header=None, delimiter=',') 48 | train_path = train_path.drop(0) 49 | speaker_path = [path for path in list(train_path.values[:, 1])] + [path for path in list(train_path.values[:, 2])] 50 | listener_path = [path for path in list(train_path.values[:, 2])] + [path for path in list(train_path.values[:, 1])] 51 | 52 | train_neighbour_path = os.path.join(args.data_dir, 'neighbour_emotion_train.npy') 53 | train_neighbour = np.load(train_neighbour_path) 54 | 55 | neighbors = { 56 | 'speaker_path': speaker_path, 57 | 'listener_path': listener_path, 58 | 'neighbors': train_neighbour 59 | } 60 | 61 | dataset = ActionData(root=args.data_dir, data_type='train', neighbors=None, 62 | neighbor_pattern=args.neighbor_pattern, num_frames=num_frames, stride=stride) 63 | 64 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 65 | 66 | trainer = Trainer(model=model, neighbors=neighbors, loss_name=args.loss_name, 67 | no_inverse=args.no_inverse, neighbor_pattern=args.neighbor_pattern, 68 | num_frames=num_frames, stride=stride, loss_mid=args.loss_mid, 69 | cal_logdets=args.get_logdets) 70 | 71 | params = [] 72 | for key, value in model.named_parameters(): 73 | if not value.requires_grad: 74 | continue 75 | params += [{"params": [value], "lr": args.lr, "weight_decay": args.weight_decay}] 76 | 77 | optimizer = torch.optim.Adam(params) 78 | lr_scheduler = WarmupMultiStepLR(optimizer, gamma=args.gamma, warmup_factor=args.warmup_factor, 79 | milestones=args.milestones, warmup_iters=args.warmup_step) 80 | 81 | for epoch in range(100): 82 | lr_scheduler.step(epoch) 83 | print('Epoch [{}] LR [{:.6f}]'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr'])) 84 | trainer.train(epoch=epoch, dataloader=dataloader, optimizer=optimizer, train_iters=args.train_iters) 85 | if epoch > 0 and epoch % 5 == 0: 86 | torch.save(model, osp.join(args.logs_dir, 
"mhp-epoch{0}-seed{1}.pth").format(epoch, args.seed)) 87 | 88 | def test(): 89 | sys.stdout = Logger(osp.join(args.logs_dir, 'test.txt')) 90 | print("==========\nArgs:{}\n==========".format(args)) 91 | 92 | set_seed(args.seed) 93 | 94 | num_frames = args.num_frames 95 | stride = args.stride 96 | 97 | test_batch = len([i for i in range(0, 750 - num_frames + 1, stride)]) 98 | 99 | model_pth = args.model_pth 100 | 101 | save_base = os.path.join(args.data_dir, 'outputs', model_pth.split('/')[-2]) 102 | if not os.path.isdir(save_base): 103 | os.makedirs(save_base) 104 | 105 | testset = ActionData(root=args.data_dir, data_type='test', neighbors=None, 106 | neighbor_pattern=args.neighbor_pattern, num_frames=num_frames, stride=stride) 107 | testloader = DataLoader(testset, batch_size=test_batch, shuffle=False, ) 108 | 109 | model = torch.load(model_pth) 110 | model = model.cuda() 111 | 112 | val_path = pd.read_csv(os.path.join(args.data_dir, 'test.csv'), header=None, delimiter=',') 113 | val_path = val_path.drop(0) 114 | speaker_path = [path for path in list(val_path.values[:, 1])] + [path for path in list(val_path.values[:, 2])] 115 | listener_path = [path for path in list(val_path.values[:, 2])] + [path for path in list(val_path.values[:, 1])] 116 | 117 | val_neighbour_path = os.path.join(args.data_dir, 'neighbour_emotion_test.npy') 118 | val_neighbour = np.load(val_neighbour_path) 119 | 120 | neighbors = { 121 | 'speaker_path': speaker_path, 122 | 'listener_path': listener_path, 123 | 'neighbors': val_neighbour 124 | } 125 | 126 | trainer = Trainer(model=model, neighbors=neighbors, neighbor_pattern=args.neighbor_pattern, 127 | no_inverse=args.no_inverse) 128 | 129 | trainer.threshold = 0.06 130 | trainer.test(testloader, modify=args.modify, save_base=save_base) 131 | 132 | 133 | 134 | 135 | if __name__ == '__main__': 136 | parser = argparse.ArgumentParser(description="Actiton Generation") 137 | # pattern 138 | parser.add_argument('--test', action='store_true') 139 | # data 140 | parser.add_argument('-b', '--batch-size', type=int, default=64) 141 | parser.add_argument('-j', '--workers', type=int, default=8) 142 | # model 143 | parser.add_argument('--norm', action='store_true') 144 | parser.add_argument('--layers', type=int, default=6) 145 | parser.add_argument('--act', type=str, default='ReLU') 146 | parser.add_argument('--no-inverse', action='store_true') 147 | parser.add_argument('--convert-type', type=str, default='indirect') 148 | parser.add_argument('--edge-dim', type=int, default=8) 149 | parser.add_argument('--neighbors', type=int, default=6) 150 | # optimizer 151 | parser.add_argument('--warmup-step', type=int, default=0) 152 | parser.add_argument('--gamma', type=float, default=0.1) 153 | parser.add_argument('--milestones', nargs='+', type=int, default=[10, 15]) 154 | parser.add_argument('--warmup-factor', type=float, default=0.01) 155 | parser.add_argument('--lr', type=float, default=0.00035, 156 | help="learning rate of new parameters, for pretrained " 157 | "parameters it is 10 times smaller than this") 158 | parser.add_argument('--momentum', type=float, default=0.9) 159 | parser.add_argument('--alpha', type=float, default=0.999) 160 | parser.add_argument('--weight-decay', type=float, default=5e-4) 161 | parser.add_argument('--epochs', type=int, default=40) 162 | # training configs 163 | parser.add_argument('--seed', type=int, default=1) 164 | parser.add_argument('--print-freq', type=int, default=10) 165 | parser.add_argument('--loss-name', type=str, default='MSE') 166 | 
parser.add_argument('--train-iters', type=int, default=100) 167 | parser.add_argument('--get-logdets', action='store_true') 168 | parser.add_argument('--loss-mid', action='store_true') 169 | parser.add_argument('--neighbor-pattern', type=str, default='nearest', choices=['nearest', 'pair', 'all']) 170 | parser.add_argument('--num-frames', type=int, default=50) 171 | parser.add_argument('--stride', type=int, default=25) 172 | # testing configs 173 | parser.add_argument('--modify', action='store_true') 174 | parser.add_argument('--model-pth', type=str, metavar='PATH', default=' ') 175 | # path 176 | working_dir = osp.dirname(osp.abspath(__file__)) 177 | parser.add_argument('--data-dir', type=str, metavar='PATH', 178 | default=osp.join(working_dir, '../data/react_clean')) 179 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 180 | default=osp.join(working_dir, 'logs')) 181 | 182 | args = parser.parse_args() 183 | if args.test: 184 | test() 185 | else: 186 | train(args) -------------------------------------------------------------------------------- /regnn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactmultimodalchallenge/baseline_react2024/0c3ce587d36f05b26b604eea89c5261544b72847/regnn/utils/__init__.py -------------------------------------------------------------------------------- /regnn/utils/compute_distance_fun.py: -------------------------------------------------------------------------------- 1 | from tslearn.metrics import dtw 2 | import numpy as np 3 | import torch 4 | 5 | def compute_distance(a, b): 6 | if not type(a) == np.ndarray: 7 | a = a.clone().detach().cpu().numpy() 8 | b = b.clone().detach().cpu().numpy() 9 | a = a.reshape(-1, 25) 10 | b = b.reshape(-1, 25) 11 | res = 0 12 | for st, ed, weight in [(0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8)]: 13 | res += weight * dtw(a[ : , st : ed], b[ : , st : ed]) 14 | return res 15 | 16 | 17 | def compute_total_distance(preds, targets): 18 | # preds: 10 750 25 19 | # targets: N 750 25 20 | if not type(preds) == np.ndarray: 21 | preds = preds.clone().detach().cpu().numpy() 22 | if not type(targets) == np.ndarray: 23 | targets = targets.clone().detach().cpu().numpy() 24 | 25 | dtw_lists = [] 26 | for i in range(preds.shape[0]): 27 | for j in range(targets.shape[0]): 28 | pred = preds[i] 29 | target = targets[j] 30 | res = 0. 
31 | for st, ed, weight in [(0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8)]: 32 | res += weight * dtw(pred[ : , st : ed], target[ : , st : ed]) 33 | dtw_lists.append(res) 34 | 35 | return min(dtw_lists) 36 | -------------------------------------------------------------------------------- /regnn/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from numpy import array, zeros, full, argmin, inf, ndim 4 | from scipy.spatial.distance import cdist 5 | from torch.nn import functional as F 6 | from torchmetrics.functional import concordance_corrcoef 7 | 8 | def CCC(ground_truthes, predictions): 9 | length = predictions.shape[1] 10 | ccces = 0.0 11 | count = 0 12 | for i in range(length): 13 | ground_truth = ground_truthes[:, i].reshape(-1, 1) 14 | prediction = predictions[:, i].reshape(-1, 1) 15 | mean_gt = torch.mean(ground_truth) 16 | mean_pred = torch.mean(prediction) 17 | var_gt = torch.var(ground_truth) 18 | var_pred = torch.var(prediction) 19 | v_pred = prediction - mean_pred 20 | v_gt = ground_truth - mean_gt 21 | cor = torch.sum(v_pred * v_gt) / (torch.sqrt(torch.sum(v_pred ** 2)) * torch.sqrt(torch.sum(v_gt ** 2)) + 1e-8) 22 | sd_gt = torch.std(ground_truth) 23 | sd_pred = torch.std(prediction) 24 | numerator = 2 * cor * sd_gt * sd_pred 25 | denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2 26 | ccc = numerator / denominator 27 | ccces += torch.abs(ccc) 28 | if torch.abs(ccc) > 0.: 29 | count += 1 30 | 31 | return ccces / count 32 | 33 | def ccc(target, pred): 34 | # for i in range(25): 35 | # c = torch.abs(concordance_corrcoef(target[:, i], pred[:, i])) 36 | # print(c) 37 | # print(target[:, i], pred[:, i]) 38 | # print('='*10) 39 | return torch.nanmean(torch.abs(concordance_corrcoef(target, pred))) 40 | 41 | def person_r(labels, preds): 42 | length = labels.shape[1] 43 | assert length == 25, "Length != 25" 44 | pcc = 0. 
45 | for i in range(length): 46 | label, pred = labels[:, i], preds[:, i] 47 | mean_label = torch.mean(label, dim=-1, keepdims=True) 48 | mean_pred = torch.mean(pred, dim=-1, keepdims=True) 49 | new_label, new_pred = label - mean_label, pred - mean_pred 50 | # label_norm = F.normalize(new_label, p=2, dim=1) 51 | # pred_norm = F.normalize(new_pred, p=2, dim=1) 52 | 53 | value = F.cosine_similarity(new_label, new_pred, dim=-1) 54 | pcc += value 55 | 56 | return pcc.mean() 57 | 58 | def pcc(x, y): 59 | assert x.shape == y.shape 60 | x = x.transpose(1, 0) 61 | y = y.transpose(1, 0) 62 | centered_x = x - x.mean(dim=-1, keepdim=True) 63 | centered_y = y - y.mean(dim=-1, keepdim=True) 64 | 65 | covariance = (centered_x * centered_y).sum(dim=-1, keepdim=True) 66 | 67 | bessel_corrected_covariance = covariance / (x.shape[-1] - 1) 68 | 69 | x_std = x.std(dim=-1, keepdim=True) 70 | y_std = y.std(dim=-1, keepdim=True) 71 | 72 | corr = bessel_corrected_covariance / ((x_std * y_std) + 1e-8) 73 | return torch.abs(corr).sum() / torch.nonzero(corr).shape[0] 74 | 75 | def accelerated_dtw(x, y, dist, warp=1): 76 | assert len(x) 77 | assert len(y) 78 | if np.ndim(x) == 1: 79 | x = x.reshape(-1, 1) 80 | if np.ndim(y) == 1: 81 | y = y.reshape(-1, 1) 82 | r, c = len(x), len(y) 83 | D0 = np.zeros((r + 1, c + 1)) 84 | D0[0, 1:] = inf 85 | D0[1:, 0] = inf 86 | D1 = D0[1:, 1:] 87 | D0[1:, 1:] = cdist(x, y, dist) 88 | C = D1.copy() 89 | for i in range(r): 90 | for j in range(c): 91 | min_list = [D0[i, j]] 92 | for k in range(1, warp + 1): 93 | min_list += [D0[min(i + k, r), j], 94 | D0[i, min(j + k, c)]] 95 | D1[i, j] += min(min_list) 96 | 97 | return D1[-1, -1] 98 | 99 | def tlcc(x, y, lag): 100 | """ 101 | Compute the Time Lagged Cross-Correlation (TLCC) between time series x and y. 102 | 103 | Args: 104 | x: torch.Tensor of shape (n,), a time series. 105 | y: torch.Tensor of shape (n,), a time series. 106 | lag: int, the time lag; a positive lag means y lags behind x, a negative lag means y leads x. 107 | 108 | Returns: 109 | The TLCC value. 110 | """ 111 | # Means of x and y 112 | mean_x = torch.mean(x) 113 | mean_y = torch.mean(y) 114 | 115 | # Standard deviations of x and y 116 | std_x = torch.std(x) 117 | std_y = torch.std(y) 118 | 119 | # Standardize x and y 120 | x_norm = (x - mean_x) / std_x 121 | y_norm = (y - mean_y) / std_y 122 | 123 | # Shift y_norm to the right by lag steps 124 | if lag > 0: 125 | y_norm = torch.cat([torch.zeros(lag), y_norm[:-lag]]) 126 | elif lag < 0: 127 | y_norm = torch.cat([y_norm[-lag:], torch.zeros(-lag)]) 128 | 129 | # TLCC of x_norm and y_norm 130 | tlcc_value = torch.corrcoef(torch.stack([x_norm, y_norm], dim=0))[0, 1] 131 | return tlcc_value 132 | 133 | 134 | def s_mse(preds): 135 | # preds: (10, 750, 25) 136 | preds_ = preds.reshape(preds.shape[0], -1) 137 | dist = torch.pow(torch.cdist(preds_, preds_), 2) 138 | dist = torch.sum(dist) / (preds.shape[0] * (preds.shape[0] - 1)) 139 | return dist / preds_.shape[1] 140 | 141 | 142 | def FRVar(preds): 143 | if len(preds.shape) == 3: 144 | # preds: (10, 750, 25) 145 | var = torch.var(preds, dim=1) 146 | return torch.mean(var) 147 | elif len(preds.shape) == 4: 148 | # preds: (N, 10, 750, 25) 149 | var = torch.var(preds, dim=2) 150 | return torch.mean(var) 151 | 152 | 153 | def FRDvs(preds): 154 | # preds: (N, 10, 750, 25) 155 | preds_ = preds.reshape(preds.shape[0], preds.shape[1], -1) 156 | preds_ = preds_.transpose(0, 1) 157 | # preds_: (10, N, 750*25) 158 | dist = torch.pow(torch.cdist(preds_, preds_), 2) 159 | # dist: (10, N, N) 160 | dist = torch.sum(dist) / (preds.shape[0] * (preds.shape[0] - 1) * preds.shape[1]) 161 | return dist / preds_.shape[-1] 162 | 163 | 164 | def crosscorr(datax, datay, lag=0, dim=25): 
pcc_list = [] 166 | for i in range(dim): 167 | # print(datax.shape, datay.shape) 168 | cn_1, cn_2 = shift(datax[:, i], datay[:, i], lag) 169 | pcc_i = torch.corrcoef(torch.stack([cn_1, cn_2], dim=0))[0, 1] 170 | pcc_list.append(pcc_i.item()) 171 | return torch.mean(torch.Tensor(pcc_list)) 172 | 173 | 174 | def calculate_tlcc(pred, sp, seconds=2, fps=25): 175 | rs = [crosscorr(pred, sp, lag, sp.shape[-1]) for lag in range(-int(seconds * fps - 1), int(seconds * fps))] 176 | peak = max(rs) 177 | center = rs[len(rs) // 2] 178 | offset = len(rs) // 2 - torch.argmax(torch.Tensor(rs)) 179 | return peak, center, offset 180 | 181 | def TLCC(pred, speaker): 182 | # pred: N 10 750 25 183 | # speaker: N 750 25 184 | offset_list = [] 185 | for k in range(speaker.shape[0]): 186 | pred_item = pred[k] 187 | sp_item = speaker[k] 188 | for i in range(pred_item.shape[0]): 189 | peak, center, offset = calculate_tlcc(pred_item[i].float(), sp_item.float()) 190 | offset_list.append(torch.abs(offset).item()) 191 | return torch.mean(torch.Tensor(offset_list)).item() 192 | 193 | 194 | def SingleTLCC(pred, speaker): 195 | # pred: 10 750 25 196 | # speaker: 750 25 197 | offset_list = [] 198 | for i in range(pred.shape[0]): 199 | peak, center, offset = calculate_tlcc(pred[i].float(), speaker.float()) 200 | offset_list.append(torch.abs(offset).item()) 201 | return torch.mean(torch.Tensor(offset_list)).item() 202 | 203 | 204 | def shift(x, y, lag): 205 | if lag > 0: 206 | return x[lag:], y[:-lag] 207 | elif lag < 0: 208 | return x[:lag], y[-lag:] 209 | else: 210 | return x, y 211 | 212 | 213 | import pandas as pd 214 | 215 | 216 | def shift_elements(arr, num, fill_value=np.nan): 217 | result = np.empty_like(arr) 218 | if num > 0: 219 | result[:num] = fill_value 220 | result[num:] = arr[:-num] 221 | elif num < 0: 222 | result[num:] = fill_value 223 | result[:num] = arr[-num:] 224 | else: 225 | result[:] = arr 226 | return result 227 | 228 | 229 | def crosscorr_(datax, datay, lag=0, dim=25, wrap=False): 230 | pcc_list = [] 231 | for i in range(25): 232 | cn_1 = pd.Series(datax[:, i].reshape(-1)) 233 | cn_2 = pd.Series(datay[:, i].reshape(-1)) 234 | pcc_i = cn_1.corr(cn_2.shift(lag)) 235 | pcc_list.append(pcc_i) 236 | return np.mean(pcc_list) 237 | 238 | 239 | def calculate_tlcc_(pred, sp, seconds=2, fps=25): 240 | rs = [crosscorr_(pred, sp, lag, sp.shape[-1]) for lag in range(-int(seconds * fps - 1), int(seconds * fps))] 241 | peak = max(rs) 242 | center = rs[len(rs) // 2] 243 | offset = len(rs) // 2 - np.argmax(rs) 244 | return peak, center, offset 245 | 246 | 247 | def compute_TLCC(pred, speaker): 248 | offset_list = [] 249 | for k in range(speaker.shape[0]): 250 | pred_item = pred[k] 251 | sp_item = speaker[k] 252 | for i in range(pred_item.shape[0]): 253 | peak, center, offset = calculate_tlcc_(pred_item[i].numpy().astype(np.float32), 254 | sp_item.numpy().astype(np.float32)) 255 | offset_list.append(np.abs(offset)) 256 | return np.mean(offset_list) 257 | 258 | 259 | def compute_singel_tlcc(pred, speaker): 260 | pred = pred.detach().cpu().numpy() 261 | speaker = speaker.detach().cpu().numpy() 262 | peak, center, offset = calculate_tlcc_(pred.astype(np.float32), 263 | speaker.astype(np.float32)) 264 | 265 | return np.abs(offset) 266 | 267 | 268 | 269 | 270 | -------------------------------------------------------------------------------- /regnn/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdir_if_not_exist(pth): 4 | if not os.path.isdir(pth): 
5 | os.makedirs(pth) -------------------------------------------------------------------------------- /regnn/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno 5 | 6 | 7 | def mkdir_if_missing(dir_path): 8 | try: 9 | os.makedirs(dir_path) 10 | except OSError as e: 11 | if e.errno != errno.EEXIST: 12 | raise 13 | 14 | 15 | class Logger(object): 16 | def __init__(self, fpath=None): 17 | self.console = sys.stdout 18 | self.file = None 19 | if fpath is not None: 20 | mkdir_if_missing(os.path.dirname(fpath)) 21 | self.file = open(fpath, 'w') 22 | 23 | def __del__(self): 24 | self.close() 25 | 26 | def __enter__(self): 27 | pass 28 | 29 | def __exit__(self, *args): 30 | self.close() 31 | 32 | def write(self, msg): 33 | self.console.write(msg) 34 | if self.file is not None: 35 | self.file.write(msg) 36 | 37 | def flush(self): 38 | self.console.flush() 39 | if self.file is not None: 40 | self.file.flush() 41 | os.fsync(self.file.fileno()) 42 | 43 | def close(self): 44 | self.console.close() 45 | if self.file is not None: 46 | self.file.close() 47 | -------------------------------------------------------------------------------- /regnn/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.distributions.multivariate_normal import MultivariateNormal 6 | # from sdtw.distance import SquaredEuclidean 7 | # from .softdtw import SoftDTW 8 | from utils.evaluate import person_r 9 | from torchmetrics.functional import pearson_corrcoef 10 | 11 | class PCC(nn.Module): 12 | def __init__(self, ): 13 | super(PCC, self).__init__() 14 | 15 | def forward(self, labels, preds): 16 | batch_size = labels.shape[0] 17 | pcc = 0 18 | for i in range(batch_size): 19 | pcc -= person_r(labels[i], preds[i]) 20 | 21 | return pcc / batch_size 22 | 23 | class PearsonCC(nn.Module): 24 | def __init__(self): 25 | super(PearsonCC, self).__init__() 26 | 27 | def forward(self, preds, targets): 28 | batch_size = preds.shape[0] 29 | pcces = 0. 30 | for i in range(batch_size): 31 | pcces += (1 - torch.nanmean(torch.abs(pearson_corrcoef(preds[i], targets[i])))) 32 | 33 | return pcces / batch_size 34 | 35 | class DistributionLoss(nn.Module): 36 | def __init__(self, dis_type='Gaussian'): 37 | super().__init__() 38 | self.dis_type = dis_type 39 | 40 | def cal_MultiGaussian(self, feature, means, var): 41 | # var = var / var.sum() 42 | k = means.shape[0] 43 | # V = torch.eye(var.shape[0], device='cuda:0') + var @ var.transpose(1, 0) 44 | V = torch.eye(var.shape[0], device='cuda:0') 45 | log_probs = 0. 
46 | for i in range(k): 47 | distribution = MultivariateNormal(loc=means[i], covariance_matrix=V) 48 | log_prob = distribution.log_prob(feature[i] if len(feature.shape) == 2 else feature[:, i]) 49 | log_probs -= log_prob 50 | 51 | return log_probs 52 | 53 | def forward(self, latent_feature, params): 54 | if self.dis_type == 'MultiGaussian': 55 | means, vars = params 56 | vars = torch.diag_embed(vars) 57 | dist = MultivariateNormal(means, vars) 58 | log_probs = dist.log_prob(latent_feature) 59 | return -torch.sum(log_probs) 60 | 61 | elif self.dis_type == 'Gaussian': 62 | mean, var = params 63 | n_elements, frame_size = latent_feature.shape[1], latent_feature.shape[2] 64 | distribution = MultivariateNormal(loc=mean, covariance_matrix=var) 65 | # loss = -distribution.log_prob(latent_feature) 66 | latent_feature = latent_feature.reshape(-1, n_elements * frame_size) 67 | loss = torch.mean((latent_feature - mean)**2 )\ 68 | # / (var**2 + 1e-6) 69 | # ) 70 | loss_2 = torch.exp(distribution.log_prob(latent_feature)) 71 | # loss = torch.exp(loss) 72 | return loss, loss_2 73 | 74 | elif self.dis_type == 'Gmm': 75 | mean, var, muk = params 76 | log_probs = 0.0 77 | return log_probs 78 | 79 | 80 | class AllMseLoss(nn.Module): 81 | def __init__(self, ): 82 | super(AllMseLoss, self).__init__() 83 | 84 | def forward(self, preds, targets, lengthes): 85 | preds = torch.repeat_interleave(preds, dim=0, repeats=torch.tensor(lengthes, device=preds.device)) 86 | assert preds.shape == targets.shape, "Not equal shape" 87 | return F.mse_loss(preds, targets) 88 | 89 | 90 | class ThreMseLoss(nn.Module): 91 | def __init__(self, ): 92 | super(ThreMseLoss, self).__init__() 93 | 94 | def forward(self, preds, targets, lengthes, threshold): 95 | batch = preds.shape[0] 96 | if not lengthes: 97 | all_mse = F.mse_loss(preds, targets, reduction='none').mean(dim=(1, 2)) 98 | if threshold is not None: 99 | mask = torch.where(torch.abs(all_mse) >= threshold, 100 | torch.tensor(1, device=preds.device), torch.tensor(0, device=preds.device)) 101 | all_mse = all_mse * mask.float() 102 | 103 | return all_mse.mean() 104 | 105 | preds = torch.repeat_interleave(preds, dim=0, repeats=torch.tensor(lengthes, device=preds.device)) 106 | all_mse = F.mse_loss(preds, targets, reduction='none').mean(dim=(1, 2)) 107 | results = 0. 
108 | start = 0 109 | for length in lengthes: 110 | cur_mse = torch.min(all_mse[start:start+length]) 111 | if threshold is None or cur_mse > threshold: 112 | results += cur_mse 113 | 114 | start += length 115 | return results / batch 116 | 117 | # This loss is for training the cognitive processor 118 | class AllThreMseLoss(nn.Module): 119 | def __init__(self, cal_type='min'): 120 | super(AllThreMseLoss, self).__init__() 121 | assert cal_type in {'min', 'all'} 122 | self.cal_type = cal_type 123 | 124 | def forward(self, preds, targets, lengths, threshold): 125 | # Repeat the prediction tensor according to lengths 126 | preds = torch.repeat_interleave(preds, dim=0, repeats=torch.tensor(lengths, device=preds.device)) 127 | 128 | # Compute the MSE between each sample and its target 129 | # print(preds) 130 | # print(targets) 131 | all_mse = F.mse_loss(preds, targets, reduction='none').mean(dim=(1, 2)) 132 | 133 | # Split the per-sample MSEs by lengths and take the minimum within each chunk 134 | if self.cal_type == 'min': 135 | seqs = torch.split(all_mse, lengths) 136 | mse = torch.stack([torch.min(seq) for seq in seqs]) 137 | else: 138 | mse = all_mse 139 | 140 | # Keep only samples whose MSE exceeds the threshold, then average the remaining MSEs 141 | if threshold is not None: 142 | mse = torch.where(mse > threshold, mse, 143 | torch.tensor([0.0], dtype=torch.float, device=mse.device) 144 | ) 145 | 146 | if self.cal_type == 'min': 147 | return torch.mean(mse) 148 | else: 149 | seqs = torch.split(mse, lengths) 150 | mse = torch.stack([torch.mean(seq) for seq in seqs]) 151 | return torch.mean(mse) 152 | 153 | 154 | # def forward(self, preds, targets, lengthes, threshold): 155 | # batch = preds.shape[0] 156 | # preds = torch.repeat_interleave(preds, dim=0, repeats=torch.tensor(lengthes, device=preds.device)) 157 | # assert preds.shape == targets.shape, "Not equal shape" 158 | # 159 | # all_mse = F.mse_loss(preds, targets, reduction='none').mean(dim=(1, 2)) 160 | # seqs = torch.split(all_mse, lengthes) 161 | # min_mse = torch.stack([torch.min(seq, dim=0) for seq in seqs]) 162 | # results = 0. 163 | # start = 0 164 | # for length in lengthes: 165 | # cur_mse = torch.min(all_mse[start:start+length]) 166 | # if threshold is None or cur_mse > threshold: 167 | # results += cur_mse 168 | # 169 | # start += length 170 | # return results / batch 171 | 172 | # This loss is for training the REGNN. It is self-supervised. 173 | class MidLoss(nn.Module): 174 | def __init__(self, loss_type='L2'): 175 | super(MidLoss, self).__init__() 176 | if loss_type == 'L2': 177 | self.loss = nn.MSELoss() 178 | elif loss_type == 'L1': 179 | self.loss = nn.L1Loss() 180 | else: 181 | raise AttributeError 182 | 183 | def forward(self, inputs, lengths): 184 | # Split the input tensor into chunks according to lengths 185 | seqs = torch.split(inputs, lengths) 186 | 187 | # Compute the mean of each chunk and stack the results along dim 0 188 | means = torch.stack([torch.mean(seq, dim=0) for seq in seqs]) 189 | 190 | means = torch.repeat_interleave(means, dim=0, repeats=torch.tensor(lengths, device=means.device)) 191 | # Expand the means to the shape of the input and compute the loss 192 | loss = self.loss(means, inputs) 193 | 194 | # Return the average loss 195 | return loss.mean() 196 | 197 | # start = 0 198 | # results = 0. 
199 |     # batch_size = len(lengthes)
200 |     # for length in lengthes:
201 |     #     sub_inputs = inputs[start:start+length]
202 |     #     mean_val = torch.mean(sub_inputs, dim=1, keepdim=True)
203 |     #     start += length
204 |     #
205 |     #     results += self.loss(mean_val.expand_as(sub_inputs), sub_inputs)
206 |     #
207 |     #     return results / batch_size
208 | 
209 | class S_MSE(nn.Module):
210 |     def __init__(self, loss_type='L2'):
211 |         super(S_MSE, self).__init__()
212 |         self.loss_type = loss_type
213 | 
214 |     def forward(self, inputs, lengths):
215 |         # split the input tensor into per-sequence chunks according to lengths
216 |         seqs = torch.split(inputs, lengths)
217 |         results = []
218 |         for seq in seqs:
219 |             results.append(self.cal(seq))
220 | 
221 |         results = torch.stack(results, 0)
222 | 
223 |         return torch.nanmean(results)
224 | 
225 |     def cal(self, data):
226 |         data_ = data.reshape(data.shape[0], -1)
227 |         if self.loss_type == 'L1':
228 |             dist = torch.pow(torch.cdist(data_, data_), 1)
229 |         else:
230 |             dist = torch.pow(torch.cdist(data_, data_), 2)
231 |         dist = torch.sum(dist) / (data.shape[0] * (data.shape[0] - 1))
232 |         return dist / data_.shape[1]
233 | 
234 | 
--------------------------------------------------------------------------------
/regnn/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 | from torch.optim.lr_scheduler import *
9 | 
10 | 
11 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
12 | # separating MultiStepLR with WarmupLR
13 | # but the current LRScheduler design doesn't allow it
14 | 
15 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
16 |     def __init__(
17 |         self,
18 |         optimizer,
19 |         milestones=None,
20 |         gamma=0.1,
21 |         warmup_factor=1.0 / 3,
22 |         warmup_iters=500,
23 |         warmup_method="linear",
24 |         last_epoch=-1,
25 |     ):
26 |         # if not list(milestones) == sorted(milestones):
27 |         #     raise ValueError(
28 |         #         "Milestones should be a list of" " increasing integers. Got {}",
29 |         #         milestones,
30 |         #     )
31 | 
32 |         if warmup_method not in ("constant", "linear"):
33 |             raise ValueError(
34 |                 "Only 'constant' or 'linear' warmup_method accepted, "
35 |                 "got {}".format(warmup_method)
36 |             )
37 |         self.milestones = milestones
38 |         self.gamma = gamma
39 |         self.warmup_factor = warmup_factor
40 |         self.warmup_iters = warmup_iters
41 |         self.warmup_method = warmup_method
42 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
43 | 
44 |     def get_lr(self):
45 |         warmup_factor = 1
46 |         if self.last_epoch < self.warmup_iters:
47 |             if self.warmup_method == "constant":
48 |                 warmup_factor = self.warmup_factor
49 |             elif self.warmup_method == "linear":
50 |                 alpha = float(self.last_epoch) / float(self.warmup_iters)
51 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
52 |         return [
53 |             base_lr
54 |             * warmup_factor
55 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
56 |             for base_lr in self.base_lrs
57 |         ]
58 | 
59 | 
60 | import logging
61 | import math
62 | import torch
63 | 
64 | # from .scheduler import Scheduler
65 | 
66 | 
67 | _logger = logging.getLogger(__name__)
68 | # NOTE: CosineLRScheduler below is written against a timm-style Scheduler base class (see the commented-out import
69 | # above): it passes param_group_field/noise kwargs to super().__init__ and uses self.base_values / update_groups,
class CosineLRScheduler is declared on the next line against torch's _LRScheduler, which does not provide these.
70 | class CosineLRScheduler(torch.optim.lr_scheduler._LRScheduler):
71 |     """
72 |     Cosine decay with restarts.
73 |     This is described in the paper https://arxiv.org/abs/1608.03983.
74 | Inspiration from 75 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 76 | """ 77 | 78 | def __init__(self, 79 | optimizer: torch.optim.Optimizer, 80 | t_initial: int, 81 | t_mul: float = 1., 82 | lr_min: float = 0., 83 | decay_rate: float = 1., 84 | warmup_t=0, 85 | warmup_lr_init=0, 86 | warmup_prefix=False, 87 | cycle_limit=0, 88 | t_in_epochs=True, 89 | noise_range_t=None, 90 | noise_pct=0.67, 91 | noise_std=1.0, 92 | noise_seed=42, 93 | initialize=True) -> None: 94 | super().__init__( 95 | optimizer, param_group_field="lr", 96 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 97 | initialize=initialize) 98 | 99 | assert t_initial > 0 100 | assert lr_min >= 0 101 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 102 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 103 | "rate since t_initial = t_mul = eta_mul = 1.") 104 | self.t_initial = t_initial 105 | self.t_mul = t_mul 106 | self.lr_min = lr_min 107 | self.decay_rate = decay_rate 108 | self.cycle_limit = cycle_limit 109 | self.warmup_t = warmup_t 110 | self.warmup_lr_init = warmup_lr_init 111 | self.warmup_prefix = warmup_prefix 112 | self.t_in_epochs = t_in_epochs 113 | if self.warmup_t: 114 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 115 | super().update_groups(self.warmup_lr_init) 116 | else: 117 | self.warmup_steps = [1 for _ in self.base_values] 118 | 119 | def _get_lr(self, t): 120 | if t < self.warmup_t: 121 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 122 | else: 123 | if self.warmup_prefix: 124 | t = t - self.warmup_t 125 | 126 | if self.t_mul != 1: 127 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 128 | t_i = self.t_mul ** i * self.t_initial 129 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 130 | else: 131 | i = t // self.t_initial 132 | t_i = self.t_initial 133 | t_curr = t - (self.t_initial * i) 134 | 135 | gamma = self.decay_rate ** i 136 | lr_min = self.lr_min * gamma 137 | lr_max_values = [v * gamma for v in self.base_values] 138 | 139 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 140 | lrs = [ 141 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 142 | ] 143 | else: 144 | lrs = [self.lr_min for _ in self.base_values] 145 | 146 | return lrs 147 | 148 | def get_epoch_values(self, epoch: int): 149 | if self.t_in_epochs: 150 | return self._get_lr(epoch) 151 | else: 152 | return None 153 | 154 | def get_update_values(self, num_updates: int): 155 | if not self.t_in_epochs: 156 | return self._get_lr(num_updates) 157 | else: 158 | return None 159 | 160 | def get_cycle_length(self, cycles=0): 161 | if not cycles: 162 | cycles = self.cycle_limit 163 | cycles = max(1, cycles) 164 | if self.t_mul == 1.0: 165 | return self.t_initial * cycles 166 | else: 167 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 168 | -------------------------------------------------------------------------------- /regnn/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | from torchvision import transforms
6 | from skimage.io import imsave
7 | import skvideo.io
8 | from pathlib import Path
9 | from tqdm import auto
10 | import argparse
11 | import cv2
12 | 
13 | from utils import torch_img_to_np, _fix_image, torch_img_to_np2
14 | from external.FaceVerse import get_faceverse
15 | from external.PIRender import FaceGenerator
16 | 
17 | 
18 | def obtain_seq_index(index, num_frames, semantic_radius = 13):
19 |     seq = list(range(index - semantic_radius, index + semantic_radius + 1))
20 |     seq = [min(max(item, 0), num_frames - 1) for item in seq]
21 |     return seq
22 | 
23 | 
24 | def transform_semantic(semantic):
25 |     semantic_list = []
26 |     for i in range(semantic.shape[0]):
27 |         index = obtain_seq_index(i, semantic.shape[0])
28 |         semantic_item = semantic[index, :].unsqueeze(0)
29 |         semantic_list.append(semantic_item)
30 |     semantic = torch.cat(semantic_list, dim = 0)
31 |     return semantic.transpose(1,2)
32 | 
33 | 
34 | 
35 | class Render(object):
36 |     """Renders listener reaction videos from 3DMM coefficients: FaceVerse 3D reconstructions plus PIRender 2D reenactments"""
37 | 
38 |     def __init__(self, device = 'cpu'):
39 |         self.faceverse, _ = get_faceverse(device=device, img_size=224)
40 |         self.faceverse.init_coeff_tensors()
41 |         self.id_tensor = torch.from_numpy(np.load('external/FaceVerse/reference_full.npy')).float().view(1,-1)[:,:150]
42 |         self.pi_render = FaceGenerator().to(device)
43 |         self.pi_render.eval()
44 |         checkpoint = torch.load('external/PIRender/cur_model_fold.pth')
45 |         self.pi_render.load_state_dict(checkpoint['state_dict'])
46 | 
47 |         self.mean_face = torch.FloatTensor(
48 |             np.load('external/FaceVerse/mean_face.npy').astype(np.float32)).view(1, 1, -1).to(device)
49 |         self.std_face = torch.FloatTensor(
50 |             np.load('external/FaceVerse/std_face.npy').astype(np.float32)).view(1, 1, -1).to(device)
51 | 
52 |         self._reverse_transform_3dmm = transforms.Lambda(lambda e: e + self.mean_face)
53 | 
54 |     def rendering(self, path, ind, listener_vectors, speaker_video_clip, listener_reference):
55 | 
56 |         # 3D video
57 |         T = listener_vectors.shape[0]
58 |         listener_vectors = self._reverse_transform_3dmm(listener_vectors)[0]
59 | 
60 |         self.faceverse.batch_size = T
61 |         self.faceverse.init_coeff_tensors()
62 | 
63 |         self.faceverse.exp_tensor = listener_vectors[:,:52].view(T,-1).to(listener_vectors.get_device())
64 |         self.faceverse.rot_tensor = listener_vectors[:,52:55].view(T, -1).to(listener_vectors.get_device())
65 |         self.faceverse.trans_tensor = listener_vectors[:,55:].view(T, -1).to(listener_vectors.get_device())
66 |         self.faceverse.id_tensor = self.id_tensor.view(1,150).repeat(T,1).view(T,150).to(listener_vectors.get_device())
67 | 
68 | 
69 |         pred_dict = self.faceverse(self.faceverse.get_packed_tensors(), render=True, texture=False)
70 |         rendered_img_r = pred_dict['rendered_img']
71 |         rendered_img_r = np.clip(rendered_img_r.cpu().numpy(), 0, 255)
72 |         rendered_img_r = rendered_img_r[:, :, :, :3].astype(np.uint8)
73 | 
74 | 
75 |         # 2D video
76 |         # listener_vectors = torch.cat((listener_exp.view(T,-1), listener_trans.view(T, -1), listener_rot.view(T, -1)))
77 | 
semantics = transform_semantic(listener_vectors.detach()).to(listener_vectors.get_device()) 78 | C, H, W = listener_reference.shape 79 | output_dict_list = [] 80 | duration = listener_vectors.shape[0] // 20 81 | listener_reference_frames = listener_reference.repeat(listener_vectors.shape[0], 1, 1).view( 82 | listener_vectors.shape[0], C, H, W) 83 | 84 | for i in range(20): 85 | if i != 19: 86 | listener_reference_copy = listener_reference_frames[i * duration:(i + 1) * duration] 87 | semantics_copy = semantics[i * duration:(i + 1) * duration] 88 | else: 89 | listener_reference_copy = listener_reference_frames[i * duration:] 90 | semantics_copy = semantics[i * duration:] 91 | output_dict = self.pi_render(listener_reference_copy, semantics_copy) 92 | fake_videos = output_dict['fake_image'] 93 | fake_videos = torch_img_to_np2(fake_videos) 94 | output_dict_list.append(fake_videos) 95 | 96 | listener_videos = np.concatenate(output_dict_list, axis=0) 97 | speaker_video_clip = torch_img_to_np2(speaker_video_clip) 98 | 99 | out = cv2.VideoWriter(os.path.join(path, ind + "_val.avi"), cv2.VideoWriter_fourcc(*"MJPG"), 25, (672, 224)) 100 | for i in range(rendered_img_r.shape[0]): 101 | combined_img = np.zeros((224, 672, 3), dtype=np.uint8) 102 | combined_img[0:224, 0:224] = speaker_video_clip[i] 103 | combined_img[0:224, 224:448] = rendered_img_r[i] 104 | combined_img[0:224, 448:] = listener_videos[i] 105 | out.write(combined_img) 106 | out.release() 107 | 108 | 109 | 110 | def rendering_for_fid(self, path, ind, listener_vectors, speaker_video_clip, listener_reference, listener_video_clip): 111 | # 3D video 112 | T = listener_vectors.shape[0] 113 | listener_vectors = self._reverse_transform_3dmm(listener_vectors)[0] 114 | 115 | self.faceverse.batch_size = T 116 | self.faceverse.init_coeff_tensors() 117 | 118 | self.faceverse.exp_tensor = listener_vectors[:, :52].view(T, -1).to(listener_vectors.get_device()) 119 | self.faceverse.rot_tensor = listener_vectors[:, 52:55].view(T, -1).to(listener_vectors.get_device()) 120 | self.faceverse.trans_tensor = listener_vectors[:, 55:].view(T, -1).to(listener_vectors.get_device()) 121 | self.faceverse.id_tensor = self.id_tensor.view(1, 150).repeat(T, 1).view(T, 150).to(listener_vectors.get_device()) 122 | 123 | pred_dict = self.faceverse(self.faceverse.get_packed_tensors(), render=True, texture=False) 124 | rendered_img_r = pred_dict['rendered_img'] 125 | rendered_img_r = np.clip(rendered_img_r.cpu().numpy(), 0, 255) 126 | rendered_img_r = rendered_img_r[:, :, :, :3].astype(np.uint8) 127 | 128 | # 2D video 129 | # listener_vectors = torch.cat((listener_exp.view(T,-1), listener_trans.view(T, -1), listener_rot.view(T, -1))) 130 | semantics = transform_semantic(listener_vectors.detach()).to(listener_vectors.get_device()) 131 | C, H, W = listener_reference.shape 132 | output_dict_list = [] 133 | duration = listener_vectors.shape[0] // 20 134 | listener_reference_frames = listener_reference.repeat(listener_vectors.shape[0], 1, 1).view( 135 | listener_vectors.shape[0], C, H, W) 136 | 137 | for i in range(20): 138 | if i != 19: 139 | listener_reference_copy = listener_reference_frames[i * duration:(i + 1) * duration] 140 | semantics_copy = semantics[i * duration:(i + 1) * duration] 141 | else: 142 | listener_reference_copy = listener_reference_frames[i * duration:] 143 | semantics_copy = semantics[i * duration:] 144 | output_dict = self.pi_render(listener_reference_copy, semantics_copy) 145 | fake_videos = output_dict['fake_image'] 146 | fake_videos = 
torch_img_to_np2(fake_videos)
147 |             output_dict_list.append(fake_videos)
148 | 
149 |         listener_videos = np.concatenate(output_dict_list, axis=0)
150 |         speaker_video_clip = torch_img_to_np2(speaker_video_clip)
151 | 
152 |         if not os.path.exists(os.path.join(path, 'results_videos')):
153 |             os.makedirs(os.path.join(path, 'results_videos'))
154 |         out = cv2.VideoWriter(os.path.join(path, 'results_videos', ind + "_val.avi"), cv2.VideoWriter_fourcc(*"MJPG"), 25, (672, 224))
155 |         for i in range(rendered_img_r.shape[0]):
156 |             combined_img = np.zeros((224, 672, 3), dtype=np.uint8)
157 |             combined_img[0:224, 0:224] = speaker_video_clip[i]
158 |             combined_img[0:224, 224:448] = rendered_img_r[i]
159 |             combined_img[0:224, 448:] = listener_videos[i]
160 |             out.write(combined_img)
161 |         out.release()
162 | 
163 |         listener_video_clip = torch_img_to_np2(listener_video_clip)
164 | 
165 |         path_real = os.path.join(path, 'fid', 'real')
166 |         if not os.path.exists(path_real):
167 |             os.makedirs(path_real)
168 |         path_fake = os.path.join(path, 'fid', 'fake')
169 |         if not os.path.exists(path_fake):
170 |             os.makedirs(path_fake)
171 | 
172 |         for i in range(0, rendered_img_r.shape[0], 30):
173 | 
174 |             cv2.imwrite(os.path.join(path_fake, ind+'_'+str(i+1)+'.png'), listener_videos[i])
175 |             cv2.imwrite(os.path.join(path_real, ind+'_'+str(i+1)+'.png'), listener_video_clip[i])
176 | 
177 | 
178 | 
179 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv_contrib_python
3 | opencv_python
4 | opencv_python_headless
5 | pandas
6 | pillow
7 | scikit_image
8 | scikit_video
9 | scipy
10 | soundfile
11 | torch
12 | torchaudio
13 | torchvision
14 | tqdm
15 | tslearn
16 | omegaconf
17 | pytorch-fid
18 | einops
--------------------------------------------------------------------------------
/run_baselines.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | from tqdm import tqdm
5 | from dataset import get_dataloader
6 | from metric import *
7 | import multiprocessing as mp
8 | 
9 | baselines = ["GT",
10 |              "Random", "Mime",
11 |              "MeanSeq", "MeanFr", ]
12 | training_mean = None
13 | training_mean_single = None
14 | 
15 | def parse_args():
16 |     parser = argparse.ArgumentParser(description='Running baselines')
17 |     # Param
18 |     parser.add_argument('--dataset-path', default="/home/luocheng/Datasets/S-L", type=str, help="dataset path")
19 |     parser.add_argument('--split', default="val", type=str, help="split of dataset", choices=["val", "test"])
20 |     parser.add_argument('-b', '--batch-size', default=4, type=int, metavar='N', help='mini-batch size (default: 4)')
21 |     parser.add_argument('-j', '--num_workers', default=8, type=int, metavar='N',
22 |                         help='number of data loading workers (default: 8)')
23 |     parser.add_argument('--weight-decay', '-wd', default=5e-4, type=float, metavar='W',
24 |                         help='weight decay (default: 5e-4)')
25 |     parser.add_argument('--img-size', default=256, type=int, help="size of train/test image data")
26 |     parser.add_argument('--crop-size', default=224, type=int, help="crop size of train/test image data")
27 |     parser.add_argument('-max-seq-len', default=751, type=int, help="max length of clip")
28 |     parser.add_argument('--clip-length', default=751, type=int, help="len of video clip")
29 |     args = parser.parse_args()
30 |     return args
31 | 
32 | def get_baseline(cfg, baseline,
num_pred=10, speaker_emotion=None, listener_emotion=None): 33 | batch_size = speaker_emotion.shape[0] 34 | if baseline == "MeanSeq" or baseline == "MeanFr": 35 | # This baseline predicts the sequence/frame mean of the training data for each emotion dimension. 36 | global training_mean, training_mean_single 37 | if training_mean is None or training_mean_single is None: 38 | train_loader = get_dataloader(cfg, "train", load_emotion_s=True, load_emotion_l=True) 39 | train_loader._split = "val" # to avoid data augmentation 40 | all_tr_emotion_list = [] 41 | for batch_idx, (_, _, speaker_emotion, _, _, _, listener_emotion, _, _) in enumerate(tqdm(train_loader)): 42 | all_tr_emotion_list.append(speaker_emotion.cpu()) 43 | all_tr_emotion_list.append(listener_emotion.cpu()) 44 | 45 | all_tr_emotion = torch.cat(all_tr_emotion_list, dim = 0) 46 | # average over all training data 47 | all_tr_emotion = all_tr_emotion.mean(dim=0) 48 | single_tr_emotion = all_tr_emotion.mean(dim=0) 49 | # repeat to match the number of predictions 50 | training_mean = all_tr_emotion[None, None, ...].repeat(batch_size, num_pred, 1, 1) 51 | training_mean_single = single_tr_emotion[None, None, ...].repeat(batch_size, num_pred, training_mean.shape[2], 1) 52 | 53 | return training_mean[:batch_size] if baseline == "MeanSeq" else training_mean_single[:batch_size] 54 | elif baseline == "Random": 55 | # predict listener emotion as random values between 0 and 1 56 | return torch.rand(batch_size, num_pred, *speaker_emotion.shape[1:]) 57 | elif baseline == "Mime": 58 | # predict listener emotion as speaker emotion (mime) 59 | return speaker_emotion[:, None, ...].repeat(1, num_pred, 1, 1) 60 | elif baseline == "GT": 61 | # predict listener emotion as ground truth 62 | return listener_emotion[:, None, ...].repeat(1, num_pred, 1, 1) 63 | else: 64 | raise NotImplementedError("Baseline {} not implemented".format(baseline)) 65 | 66 | 67 | # Train 68 | def val(cfg): 69 | assert cfg.split in ["val", "test"], "split must be in [val, test]" 70 | dataloader = get_dataloader(cfg, cfg.split, load_emotion_s=True, load_emotion_l=True, load_audio=False, load_video_s=False, load_video_l=False, load_3dmm_s=False, load_3dmm_l=False, load_ref=False) 71 | print("Dataset: {}, Split: {}, Baselines: {}".format(cfg.dataset_path, cfg.split, baselines)) 72 | print("Number of samples: {}".format(len(dataloader.dataset))) 73 | 74 | for i, baseline in enumerate(baselines): 75 | listener_emotion_pred_list = [] 76 | listener_emotion_gt_list = [] 77 | speaker_emotion_list = [] 78 | for batch_idx, (_, _, speaker_emotion, _, _, _, listener_emotion, _, _) in enumerate(tqdm(dataloader)): 79 | 80 | prediction = get_baseline(cfg, baseline, num_pred=10, speaker_emotion=speaker_emotion, listener_emotion=listener_emotion) 81 | 82 | listener_emotion_pred_list.append(prediction) 83 | listener_emotion_gt_list.append(listener_emotion) 84 | speaker_emotion_list.append(speaker_emotion) 85 | 86 | all_pred_listener_emotion = torch.cat(listener_emotion_pred_list, dim = 0) 87 | all_speaker_emotion = torch.cat(speaker_emotion_list, dim = 0) 88 | all_listener_gt_emotion = torch.cat(listener_emotion_gt_list, dim = 0) 89 | 90 | 91 | assert all_speaker_emotion.shape[0] == all_pred_listener_emotion.shape[0], "Number of predictions and number of speaker emotions must match ({} vs. 
{})".format(all_pred_listener_emotion.shape[0], all_speaker_emotion.shape[0]) 92 | #print("-----------------Evaluating Metric-----------------") 93 | 94 | p=64 95 | np.seterr(divide='ignore', invalid='ignore') 96 | # If you have problems running function compute_FRC_mp, please replace this function with function compute_FRC 97 | FRC = compute_FRC_mp(cfg, all_pred_listener_emotion, all_listener_gt_emotion, p=p, val_test=cfg.split) 98 | #FRC = compute_FRC(cfg, all_pred_listener_emotion, all_listener_gt_emotion, val_test=cfg.split) 99 | 100 | # If you have problems running function compute_FRD_mp, please replace this function with function compute_FRD 101 | FRD = compute_FRD_mp(cfg, all_pred_listener_emotion, all_listener_gt_emotion, p=p, val_test=cfg.split) 102 | #FRD = compute_FRD(cfg, all_pred_listener_emotion, all_listener_gt_emotion, val_test=cfg.split) 103 | 104 | FRDvs = compute_FRDvs(all_pred_listener_emotion) 105 | FRVar = compute_FRVar(all_pred_listener_emotion) 106 | smse = compute_s_mse(all_pred_listener_emotion) 107 | TLCC = compute_TLCC_mp(all_pred_listener_emotion, all_speaker_emotion, p=p) 108 | 109 | # print all results in one line 110 | print("[{}/{}] Baseline: {}, FRC: {:.5f} | FRD: {:.5f} | S-MSE: {:.5f} | FRVar: {:.5f} | FRDvs: {:.5f} | TLCC: {:.5f}".format(i+1, 111 | len(baselines), 112 | baseline, 113 | FRC, 114 | FRD, 115 | smse, 116 | FRVar, 117 | FRDvs, 118 | TLCC)) 119 | 120 | print("Latex-friendly --> B\\_{} & {:.2f} & {:.2f} & {:.4f} & {:.4f} & {:.4f} & - & {:.2f} \\\\".format(baseline, FRC, FRD, smse, FRVar, FRDvs, TLCC)) 121 | 122 | 123 | 124 | def main(): 125 | args = parse_args() 126 | val(args) 127 | 128 | 129 | # --------------------------------------------------------------------------------- 130 | 131 | 132 | if __name__=="__main__": 133 | main() 134 | 135 | -------------------------------------------------------------------------------- /tool/audio_visual_clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | from tqdm import tqdm 4 | from moviepy.editor import * 5 | import os.path as opt 6 | 7 | 8 | FPS = 25 9 | STRIDE = 15 10 | LENGTH = 30 11 | 12 | 13 | def get_clip(reader): 14 | """get images of 15 seconds""" 15 | res = [] 16 | for i in range(STRIDE * FPS): 17 | __, img = reader.read() 18 | if img is None: 19 | return None 20 | res.append(img) 21 | return res 22 | 23 | 24 | def cut_single_video(in_path, out_path, indices): 25 | clip = VideoFileClip(in_path) 26 | for idx in indices: 27 | #st, et = (idx - 1) * STRIDE, (idx + 1) * STRIDE # for 30s clips with 15s of overlap 28 | st, et = (idx - 1) * LENGTH, (idx) * LENGTH # for 30s clips without overlap 29 | segment = clip.subclip(st, et) 30 | segment.write_videofile(opt.join(out_path, f"{idx}.mp4")) 31 | 32 | 33 | def cut_udiva_videos(in_base_path, out_base_path, session="animal"): 34 | with open("./udiva_list.csv", newline="") as f: 35 | reader = csv.DictReader(f) 36 | video_to_indices = {} 37 | for row in reader: 38 | video_to_indices.setdefault((row["subject"], row["topic"]), set()).add(int(row["index"])) 39 | session_mark = session[0].upper() 40 | for (subject, topic), indices in tqdm(video_to_indices.items()): 41 | if topic == session: 42 | for part in ["FC1", "FC2"]: 43 | in_path = os.path.join(os.path.join(in_base_path, subject), f"{part}_{session_mark}.mp4") 44 | out_path = os.path.join(os.path.join(out_base_path, subject), part) 45 | os.makedirs(out_path, exist_ok=True) 46 | cut_single_video(in_path, out_path, indices) 47 | 48 | 
def extract_udiva_videos(in_base_path, out_base_path, option="audio"): 49 | with open("./udiva_list.csv", newline="") as f: 50 | reader = csv.DictReader(f) 51 | video_to_indices = {} 52 | for row in reader: 53 | if len(row["subject"]) == 4: 54 | subject = "00" + row["subject"] 55 | elif len(row["subject"]) == 5: 56 | subject = "0" + row["subject"] 57 | else: 58 | subject = row["subject"] 59 | filename = row["topic"] + "/" + subject 60 | video_to_indices.setdefault(filename, set()).add(int(row["index"])) 61 | for filename, indices in tqdm(video_to_indices.items()): 62 | for part in ["FC1", "FC2"]: 63 | in_path = os.path.join(os.path.join(in_base_path, filename), part) 64 | out_path = os.path.join(os.path.join(out_base_path, filename), part) 65 | os.makedirs(out_path, exist_ok=True) 66 | if option == "audio": 67 | extract_audio(in_path, out_path, indices) 68 | elif option == "video": 69 | extract_video(in_path, out_path, indices) 70 | else: 71 | print("OPTION CAN NOT BE FOUND") 72 | return 73 | 74 | def cut_noxi_videos(in_base_path, out_base_path): 75 | with open("./noxi_list.csv", newline="") as f: 76 | reader = csv.DictReader(f) 77 | video_to_indices = {} 78 | for row in reader: 79 | minute = int(float(row["time"])) 80 | second = float(row["time"]) - minute 81 | number_of_clips = minute*2 + (second > 0.3) 82 | for idx in range(number_of_clips): 83 | video_to_indices.setdefault(row["filename"], set()).add(idx+1) 84 | for filename, indices in tqdm(video_to_indices.items()): 85 | for part in ["Expert_video", "Novice_video"]: 86 | in_path = os.path.join(os.path.join(in_base_path, filename), f"{part}.mp4") 87 | out_path = os.path.join(os.path.join(out_base_path, filename), part) 88 | os.makedirs(out_path, exist_ok=True) 89 | cut_single_video(in_path, out_path, indices) 90 | 91 | def extract_noxi_videos(in_base_path, out_base_path, option="audio"): 92 | """ 93 | Extract the audio and video files and split them. 
94 | """ 95 | with open("./noxi_list.csv", newline="") as f: 96 | reader = csv.DictReader(f) 97 | video_to_indices = {} 98 | for row in reader: 99 | minute = int(float(row["time"])) 100 | second = float(row["time"]) - minute 101 | number_of_clips = minute*2 + (second > 0.3) 102 | for idx in range(number_of_clips): 103 | video_to_indices.setdefault(row["filename"], set()).add(idx+1) 104 | for filename, indices in tqdm(video_to_indices.items()): 105 | for part in ["Expert_video", "Novice_video"]: 106 | in_path = os.path.join(os.path.join(in_base_path, filename), part) 107 | out_path = os.path.join(os.path.join(out_base_path, filename), part) 108 | os.makedirs(out_path, exist_ok=True) 109 | if option == "audio": 110 | extract_audio(in_path, out_path, indices) 111 | elif option == "video": 112 | extract_video(in_path, out_path, indices) 113 | else: 114 | print("OPTION CAN NOT BE FOUND") 115 | return 116 | 117 | def cut_recola_videos(in_base_path, out_base_path): 118 | with open("./recola_list.csv", newline="") as f: 119 | reader = csv.DictReader(f) 120 | video_to_indices = {} 121 | for row in reader: 122 | number_of_clips = 10 123 | for idx in range(number_of_clips): 124 | video_to_indices.setdefault(row["filename"], set()).add(idx+1) 125 | for filename, indices in tqdm(video_to_indices.items()): 126 | in_path = os.path.join(in_base_path, filename + ".mp4") 127 | out_path = os.path.join(out_base_path, filename) 128 | os.makedirs(out_path, exist_ok=True) 129 | cut_single_video(in_path, out_path, indices) 130 | 131 | def extract_recola_videos(in_base_path, out_base_path, option="audio"): 132 | with open("./recola_list.csv", newline="") as f: 133 | reader = csv.DictReader(f) 134 | video_to_indices = {} 135 | for row in reader: 136 | number_of_clips = 10 137 | for idx in range(number_of_clips): 138 | video_to_indices.setdefault(row["filename"], set()).add(idx+1) 139 | for filename, indices in tqdm(video_to_indices.items()): 140 | in_path = os.path.join(in_base_path, filename) 141 | out_path = os.path.join(out_base_path, filename) 142 | os.makedirs(out_path, exist_ok=True) 143 | if option == "audio": 144 | extract_audio(in_path, out_path, indices) 145 | elif option == "video": 146 | extract_video(in_path, out_path, indices) 147 | else: 148 | print("OPTION CAN NOT BE FOUND") 149 | return 150 | 151 | def extract_audio(in_path, out_path, indices): 152 | for idx in indices: 153 | clip = VideoFileClip(opt.join(in_path, f"{idx}.mp4")) 154 | clip.audio.write_audiofile(opt.join(out_path, f"{idx}.wav")) 155 | 156 | def extract_video(in_path, out_path, indices): 157 | for idx in indices: 158 | clip = VideoFileClip(opt.join(in_path, f"{idx}.mp4")) 159 | new_clip = clip.without_audio() 160 | new_clip.write_videofile(opt.join(out_path, f"{idx}.mp4")) 161 | 162 | if __name__ == '__main__': 163 | #cut_recola_videos("/home/batubal/Desktop/RECOLA", "recola_videos") 164 | extract_recola_videos("/home/batubal/Desktop/recola_videos", "recola_video_files", option="video") 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /tool/matrix_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | import os 5 | 6 | 7 | def parse_arg(): 8 | parser = argparse.ArgumentParser(description='PyTorch Training') 9 | parser.add_argument('--dataset-path', default="./data", type=str, help="dataset path") 10 | parser.add_argument('--partition', default="val", type=str, 
help="dataset partition") 11 | args = parser.parse_args() 12 | return args 13 | 14 | args = parse_arg() 15 | root_path = args.dataset_path 16 | data_path = pd.read_csv(os.path.join(root_path, 'data_indices.csv'), header=None, delimiter=',') 17 | data_path = data_path.drop(0) 18 | all_matrix = np.load(os.path.join(root_path, 'Approprirate_facial_reaction.npy')) 19 | 20 | all_matrix_w = all_matrix.shape[0] 21 | print("all_matrix shape:", all_matrix.shape) 22 | 23 | val_path = pd.read_csv(os.path.join(root_path, str(args.partition)+'.csv'), header=None, delimiter=',') 24 | val_path = val_path.drop(0) 25 | 26 | index_list = [] 27 | _sum = 0 28 | for index, row in val_path.iterrows(): 29 | _sum += 1 30 | flag = 0 31 | items = row[1].split('/') 32 | if len(items)>4: 33 | path_val_item = os.path.join(items[0].upper(), items[2]+'_'+items[1], items[4]) 34 | else: 35 | path_val_item = os.path.join(items[0].upper(),items[1],items[3]) 36 | for index_2, row2 in data_path.iterrows(): 37 | path_data_index = os.path.join(row2[0].upper(), row2[1], row2[2]) 38 | if path_val_item == path_data_index: 39 | index_list.append(index_2-1) 40 | flag = 1 41 | 42 | l_m = len(index_list) 43 | new_matrix = np.zeros((l_m*2, l_m*2)) 44 | for index, item in enumerate(index_list): 45 | for j, item_j in enumerate(index_list): 46 | new_matrix[index, j] = all_matrix[item, item_j] 47 | new_matrix[index, j+l_m] = all_matrix[item, item_j + all_matrix_w//2] 48 | new_matrix[index+l_m, j] = all_matrix[item + all_matrix_w//2, item_j] 49 | new_matrix[index+l_m, j+l_m] = all_matrix[item + all_matrix_w//2, item_j + all_matrix_w//2] 50 | 51 | 52 | np.save(os.path.join(root_path, 'neighbour_emotion_' + str(args.partition) + '.npy'), new_matrix) 53 | print(new_matrix.shape) 54 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi 2 | import torch 3 | from torchvision import transforms 4 | from PIL import Image 5 | import torch.nn as nn 6 | import cv2 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from omegaconf import OmegaConf 10 | import os 11 | import yaml 12 | 13 | 14 | def load_config(config_path=None): 15 | cli_conf = OmegaConf.from_cli() 16 | model_conf = OmegaConf.load(cli_conf.pop('config') if config_path is None else config_path) 17 | return OmegaConf.merge(model_conf, cli_conf) 18 | 19 | def load_config_from_file(path): 20 | return OmegaConf.load(path) 21 | 22 | def store_config(config): 23 | # store config to directory 24 | dir = config.trainer.out_dir 25 | os.makedirs(dir, exist_ok=True) 26 | with open(os.path.join(dir, "config.yaml"), "w") as f: 27 | yaml.dump(OmegaConf.to_container(config), f) 28 | 29 | 30 | def torch_img_to_np(img): 31 | return img.detach().cpu().numpy().transpose(0, 2, 3, 1) 32 | 33 | 34 | def torch_img_to_np2(img): 35 | img = img.detach().cpu().numpy() 36 | # img = img * np.array([0.229, 0.224, 0.225]).reshape(1,-1,1,1) 37 | # img = img + np.array([0.485, 0.456, 0.406]).reshape(1,-1,1,1) 38 | img = img * np.array([0.5, 0.5, 0.5]).reshape(1,-1,1,1) 39 | img = img + np.array([0.5, 0.5, 0.5]).reshape(1,-1,1,1) 40 | img = img.transpose(0, 2, 3, 1) 41 | img = img * 255.0 42 | img = np.clip(img, 0, 255).astype(np.uint8)[:, :, :, [2, 1, 0]] 43 | 44 | return img 45 | 46 | 47 | def _fix_image(image): 48 | if image.max() < 30.: 49 | image = image * 255. 
50 | image = np.clip(image, 0, 255).astype(np.uint8)[:, :, :, [2, 1, 0]] 51 | return image 52 | 53 | class AverageMeter(object): 54 | """Computes and stores the average and current value""" 55 | 56 | def __init__(self): 57 | self.reset() 58 | 59 | def reset(self): 60 | self.val = 0 61 | self.avg = 0 62 | self.sum = 0 63 | self.count = 0 64 | 65 | def update(self, val, n=1): 66 | self.val = val 67 | self.sum += val * n 68 | self.count += n 69 | self.avg = self.sum / self.count 70 | 71 | 72 | --------------------------------------------------------------------------------
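A minimal usage sketch for the load_config / store_config helpers defined in utils.py above. The script name, the config path, and the trainer.epochs override are illustrative assumptions, and the loaded YAML is assumed to define trainer.out_dir:

# assumed invocation:  python your_train_script.py config=path/to/config.yaml trainer.epochs=500
from utils import load_config, store_config

cfg = load_config()           # pops "config=" from the CLI and merges the remaining key=value overrides onto the YAML
print(cfg.trainer.out_dir)    # merged values are accessed as attributes via OmegaConf
store_config(cfg)             # writes the merged config to <trainer.out_dir>/config.yaml for reproducibility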