├── resources ├── PFED5.pdf ├── PFED5.png ├── pipeline.pdf ├── pipeline.png └── pd_results.png ├── dataset ├── PFED5_Request_Form.docx ├── generate_landmark_heatmap.py └── PAIN2landmarks_insightface.py ├── models ├── ST_Former.py ├── evaluator.py ├── slowonly.py ├── T_Former.py ├── vit_decoder_two.py ├── S_Former.py └── resnet3d.py ├── dataloader ├── ibmse │ ├── balancedMSE.py │ └── preprocess_gmm_PD.py ├── Parkinson_landmarkheatmap.py └── video_transform.py ├── README.md └── main_PFED5.py /resources/PFED5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuchaoduan/QAFE-Net/HEAD/resources/PFED5.pdf -------------------------------------------------------------------------------- /resources/PFED5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuchaoduan/QAFE-Net/HEAD/resources/PFED5.png -------------------------------------------------------------------------------- /resources/pipeline.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuchaoduan/QAFE-Net/HEAD/resources/pipeline.pdf -------------------------------------------------------------------------------- /resources/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuchaoduan/QAFE-Net/HEAD/resources/pipeline.png -------------------------------------------------------------------------------- /resources/pd_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuchaoduan/QAFE-Net/HEAD/resources/pd_results.png -------------------------------------------------------------------------------- /dataset/PFED5_Request_Form.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuchaoduan/QAFE-Net/HEAD/dataset/PFED5_Request_Form.docx -------------------------------------------------------------------------------- /models/ST_Former.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.S_Former import spatial_transformer 4 | from models.T_Former import temporal_transformer 5 | 6 | 7 | class GenerateModel(nn.Module): 8 | def __init__(self,cls_num=7): 9 | super().__init__() 10 | self.s_former = spatial_transformer() 11 | self.t_former = temporal_transformer() 12 | # self.fc = nn.Linear(512, cls_num) 13 | 14 | 15 | def forward(self, x): 16 | n_batch, frames, _, _, _ = x.shape 17 | n_clips = int(frames/16) 18 | # split video sequence into n segments and pack them 19 | if frames>16: 20 | data_pack = torch.cat([x[:,i:i+16] for i in range(0, frames-1, 16)]) 21 | out_s= self.s_former(data_pack) 22 | else: 23 | # out_s = self.s_former(x) 24 | out_s= self.s_former(x)# [] 25 | out_t = self.t_former(out_s) 26 | return out_t 27 | 28 | 29 | if __name__ == '__main__': 30 | img = torch.randn((2, 80, 3, 224, 224)) 31 | model = GenerateModel(cls_num=5) 32 | model(img) 33 | -------------------------------------------------------------------------------- /models/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | 5 | class MLP_block(nn.Module): 6 | 7 | def __init__(self, output_dim): 8 | super(MLP_block, self).__init__() 9 | self.activation = nn.ReLU() 10 | # 
self.softmax = nn.Softmax(dim=-1) 11 | 12 | self.layer1 = nn.Linear(512, 256) 13 | self.layer2 = nn.Linear(256, 128) 14 | self.layer3 = nn.Linear(128, output_dim) 15 | 16 | def forward(self, x): 17 | x = self.activation(self.layer1(x)) 18 | x = self.activation(self.layer2(x)) 19 | # output = self.softmax(self.layer3(x)) 20 | output = self.layer3(x) 21 | return output 22 | 23 | class single_fc(nn.Module): 24 | 25 | def __init__(self, output_dim): 26 | super(single_fc, self).__init__() 27 | self.fc = nn.Linear(512, output_dim) 28 | 29 | def forward(self, x): 30 | output = self.fc(x) 31 | return output 32 | 33 | class two_fc(nn.Module): 34 | 35 | def __init__(self, output_dim): 36 | super(two_fc, self).__init__() 37 | self.activation = nn.ReLU() 38 | self.layer1 = nn.Linear(512, 256) 39 | self.layer2 = nn.Linear(256, output_dim) 40 | 41 | def forward(self, x): 42 | x = self.activation(self.layer1(x)) 43 | output = self.layer2(x) 44 | return output 45 | 46 | 47 | class Evaluator(nn.Module): 48 | 49 | def __init__(self, output_dim, model_type='MLP'): 50 | super(Evaluator, self).__init__() 51 | 52 | self.model_type = model_type 53 | 54 | if model_type == 'MLP': 55 | self.evaluator = MLP_block(output_dim=output_dim) 56 | else: # classification 57 | self.evaluator = single_fc(output_dim=output_dim) 58 | 59 | 60 | def forward(self, feats_avg): # data: NCTHW 61 | 62 | probs = self.evaluator(feats_avg) # Nxoutput_dim 63 | 64 | return probs 65 | 66 | -------------------------------------------------------------------------------- /models/slowonly.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.resnet3d import ResNet3d 6 | 7 | 8 | class ResNet3dSlowOnly(nn.Module): 9 | """SlowOnly backbone based on ResNet3d. 10 | Args: 11 | conv1_kernel (tuple[int]): Kernel size of the first conv layer. Default: (1, 7, 7). 12 | inflate (tuple[int]): Inflate Dims of each block. Default: (0, 0, 1, 1). 13 | **kwargs (keyword arguments): Other keywords arguments for 'ResNet3d'. 
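    Shape note: with the defaults below, a heatmap clip of shape (N, 3, 16, 56, 56) is reduced to a single (N, 512) clip-level feature by the final (16, 7, 7) average pool and reshape (see the __main__ check at the bottom of this file).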
14 | """ 15 | 16 | def __init__(self, conv1_kernel=(1, 7, 7), inflate= (0, 1, 1), 17 | in_channels = 3, base_channels = 32, num_stages = 3, 18 | out_indices = (2,), stage_blocks = (4, 6, 3), conv1_stride = (1, 1), 19 | pool1_stride=(1, 1), spatial_strides=(2, 2, 2), 20 | temporal_strides=(1, 1, 1), **kwargs): 21 | super().__init__() 22 | self.restnet3d = ResNet3d(conv1_kernel=conv1_kernel, inflate=inflate, 23 | in_channels=in_channels, base_channels=base_channels, num_stages=num_stages, 24 | out_indices=out_indices, stage_blocks=stage_blocks, conv1_stride=conv1_stride, 25 | pool1_stride=pool1_stride, spatial_strides=spatial_strides, 26 | temporal_strides=temporal_strides, **kwargs) 27 | self.avg_pool = torch.nn.AvgPool3d((16,7,7)) 28 | 29 | def forward(self, x): 30 | batch = x.shape[0] 31 | feature = self.restnet3d(x) 32 | out = self.avg_pool(feature) 33 | out = out.reshape(batch, 512) 34 | return out 35 | 36 | def load_weights(self): 37 | weight = torch.load('./models/pose_only.pth', map_location='cpu') 38 | load_dict = {k[9:]: v for k,v in weight['state_dict'].items() } 39 | load_dict.pop('conv1.conv.weight') 40 | load_dict.pop('fc_cls.bias') 41 | load_dict.pop('fc_cls.weight') 42 | self.restnet3d.load_state_dict(load_dict, strict=False) 43 | print('slowonly weights loaded') 44 | 45 | 46 | if __name__=='__main__': 47 | x = torch.randn((2,3,16,56,56)) 48 | model = ResNet3dSlowOnly() 49 | model.load_weights() 50 | y = model(x) 51 | print(y) 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /dataset/generate_landmark_heatmap.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def gen_an_aug(results, sigma=1): 6 | all_kps = results['keypoint'] 7 | kp_shape = all_kps.shape 8 | all_kpscores = np.ones(kp_shape[:-1], dtype=np.float32) 9 | img_h, img_w = results['img_shape'] 10 | num_frame = kp_shape[0] 11 | 12 | imgs = [] 13 | for i in range(num_frame): 14 | sigma = sigma 15 | kps = all_kps[i, :] 16 | kpscores = all_kpscores[i,:] 17 | max_values = np.ones(kpscores.shape, dtype=np.float32) 18 | hmap = generate_heatmap(img_h, img_w, kps, sigma, max_values) 19 | combined_map = np.sum(hmap, axis=0) 20 | # Apply Gaussian blur to smooth the combined heatmap 21 | smoothed_heatmap = cv2.GaussianBlur(combined_map, (7, 7), sigmaX=0) 22 | # Normalize the smoothed heatmap 23 | smoothed_heatmap /= np.max(smoothed_heatmap) 24 | # Convert the smoothed heatmap to an image 25 | smoothed_heatmap = cv2.applyColorMap(np.uint8(255 * smoothed_heatmap), cv2.COLORMAP_JET) 26 | # Display the heatmap 27 | # cv2.imshow('Heatmap', blurred_heatmap) 28 | # cv2.waitKey(0) 29 | # cv2.destroyAllWindows() 30 | imgs.append(smoothed_heatmap) 31 | return imgs 32 | 33 | def generate_heatmap(img_h, img_w, kps, sigma, max_values, with_kp =True): 34 | #Generate pseudo heatmap for all keypoints in one frame 35 | heatmaps = [] 36 | if with_kp: 37 | num_kp = kps.shape[0] 38 | 39 | for i in range(1, num_kp): 40 | heatmap = generate_a_heatmap(img_h, img_w, kps[i, :], 41 | sigma, max_values[i]) 42 | heatmaps.append(heatmap) 43 | 44 | return np.stack(heatmaps, axis=0) 45 | 46 | def generate_a_heatmap(img_h, img_w, centers, sigma, max_values,): 47 | heatmap = np.zeros([img_h, img_w], dtype=np.float32) 48 | if len(centers) == 2: 49 | mu_x, mu_y = centers[0].astype(np.float32), centers[1].astype(np.float32) 50 | st_x = max(int(mu_x - 5), 0) 51 | ed_x = min(int(mu_x + 5) + 1, img_w) 52 | st_y = max(int(mu_y - 5), 
0) 53 | ed_y = min(int(mu_y + 5) + 1, img_h) 54 | x = np.arange(st_x, ed_x, 1, np.float32) 55 | y = np.arange(st_y, ed_y, 1, np.float32) 56 | y = y[:, None] 57 | 58 | patch = np.exp(-((x - mu_x) ** 2 + (y - mu_y) ** 2) / 2 / sigma ** 2) 59 | patch = patch * max_values 60 | a = heatmap[st_y:ed_y, st_x:ed_x] 61 | b = np.maximum(a, patch) 62 | heatmap[st_y:ed_y, st_x:ed_x] = b 63 | 64 | return heatmap / np.max(heatmap) 65 | 66 | if __name__=='__main__': 67 | item = {} 68 | item['keypoint'] = 'landmark coordinates' 69 | item['img_shape'] = (256, 256) 70 | heatmap_set = gen_an_aug(item, sigma=1) -------------------------------------------------------------------------------- /dataloader/ibmse/balancedMSE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.modules.loss import _Loss 4 | import joblib 5 | 6 | 7 | class GAILoss(_Loss): 8 | def __init__(self, init_noise_sigma, gmm): 9 | super(GAILoss, self).__init__() 10 | self.gmm = joblib.load(gmm) 11 | self.gmm = {k: torch.tensor(self.gmm[k]).cuda() for k in self.gmm} 12 | self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma, device="cuda")) 13 | 14 | def forward(self, pred, target): 15 | noise_var = self.noise_sigma ** 2 16 | loss = gai_loss(pred, target, self.gmm, noise_var) 17 | return loss 18 | 19 | 20 | def gai_loss(pred, target, gmm, noise_var): # GMM-assisted balanced-MSE term called by GAILoss.forward above 21 | gmm = {k: gmm[k].reshape(1, -1).expand(pred.shape[0], -1) for k in gmm} 22 | mse_term = F.mse_loss(pred, target, reduction='none') / 2 / noise_var + 0.5 * noise_var.log() 23 | sum_var = gmm['variances'] + noise_var 24 | balancing_term = - 0.5 * sum_var.log() - 0.5 * (pred - gmm['means']).pow(2) / sum_var + gmm['weights'].log() 25 | balancing_term = torch.logsumexp(balancing_term, dim=-1, keepdim=True) 26 | loss = mse_term + balancing_term 27 | loss = loss * (2 * noise_var).detach() 28 | 29 | return loss.mean() 30 | 31 | 32 | class BMCLoss(_Loss): 33 | def __init__(self, init_noise_sigma): 34 | super(BMCLoss, self).__init__() 35 | self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma, device="cuda")) 36 | 37 | def forward(self, pred, target): 38 | noise_var = self.noise_sigma ** 2 39 | loss = bmc_loss(pred, target, noise_var) 40 | return loss 41 | 42 | 43 | def bmc_loss(pred, target, noise_var): 44 | logits = - 0.5 * (pred - target.T).pow(2) / noise_var 45 | loss = F.cross_entropy(logits, torch.arange(pred.shape[0]).cuda()) 46 | loss = loss * (2 * noise_var).detach() 47 | 48 | return loss 49 | 50 | 51 | 52 | class FocalRLoss(_Loss): 53 | def __init__(self, weights=None, activate='sigmoid', beta=.2, gamma=1): 54 | super(FocalRLoss, self).__init__() 55 | self.weights = weights 56 | self.activate = activate 57 | self.beta = beta 58 | self.gamma = gamma 59 | 60 | def forward(self, pred, target): 61 | loss = weighted_focal_mse_loss(pred, target, self.weights, self.activate, self.beta, self.gamma) 62 | return loss 63 | 64 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1): 65 | loss = (inputs - targets) ** 2 66 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 67 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 68 | if weights is not None: 69 | loss *= weights.expand_as(loss) 70 | loss = torch.mean(loss) 71 | return loss 72 | 73 | def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1): 74 | loss =
F.l1_loss(inputs, targets, reduction='none') 75 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 76 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 77 | if weights is not None: 78 | loss *= weights.expand_as(loss) 79 | loss = torch.mean(loss) 80 | return loss 81 | 82 | if __name__=='__main__': 83 | loss = GAILoss(init_noise_sigma=1., gmm='./dataloader/gmm_PD_0.pkl') 84 | pred= torch.randn((2,1)).cuda() 85 | target = torch.randn((2,1)).cuda() 86 | p = loss(pred, target) 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QAFE-Net 2 | 3 | This is the PyTorch implementation of QAFE-Net: Quality Assessment of Facial Expressions with Landmark Heatmaps. 4 | [arXiv version](https://arxiv.org/abs/2312.00856) 5 | 6 | We propose a novel landmark-guided approach, QAFE-Net, that combines temporal landmark heatmaps with RGB data to capture small facial muscle movements that are encoded and mapped to severity scores. 7 | ![pipeline](resources/pipeline.png) 8 | 9 | ## PFED5 Dataset 10 | 11 | 12 | 13 | 14 | PFED5 is a Parkinson’s disease (PD) dataset for facial expression quality assessment. Videos were recorded with a single RGB camera in clinical settings while 41 PD patients performed five facial expressions: sitting at rest, smiling, frowning, squeezing the eyes tightly, and clenching the teeth. A trained rater assigned each expression a score between 0 and 4, following the MDS-UPDRS protocol, according to the level of severity. 15 | 16 | ## Get Started 17 | ### Requirements 18 | pytorch >= 1.3.0, mmcv = 1.x, tensorboardX, cv2, scipy, einops, [torch_videovision](https://github.com/hassony2/torch_videovision) 19 | 20 | ### Data Download 21 | 22 | 1: PFED5 dataset. To access the PFED5 dataset, please complete and sign the [PFED5 request form](dataset/PFED5_Request_Form.docx) and forward it to shuchao.duan@bristol.ac.uk. By submitting your application, you acknowledge and confirm that you have read and understood the relevant notice. Upon receiving your request, we will promptly respond with the necessary link and guidelines. Please note that ONLY faculty members can request access to the dataset on behalf of their team. 23 | 24 | 2: [UNBC-McMaster](https://www.jeffcohn.net/Resources/) dataset. 25 | 26 | ### Data preparation 27 | 28 | 1: We adopt SCRFD from [InsightFace](https://insightface.ai) for face detection and landmark estimation, 29 | and the [Albumentations library](https://albumentations.ai) for normalising the landmark positions to the cropped face regions. 30 | 31 | 2: Generate landmark heatmaps using Gaussian weights for the corresponding video clips. 32 | 33 | Please refer to the [code](dataset/) for more details. 34 | 35 | ### Training and Testing on PFED5 36 | run ```python main_PFED5.py --gpu 0,1 --batch_size 4 --epoch 100``` 37 | 38 | ### Pretrained Weights 39 | Download the pretrained weights (RGB encoder and heatmap encoder) from [Google Drive](https://drive.google.com/drive/folders/1tq1s3uoXiV8ZohrUtuHgJgaZwfGbjWY-?usp=drive_link). Put the entire `pretrain` folder under the `models` folder.
40 | 41 | ``` 42 | - models/pretrain/ 43 | FormerDFER-DFEWset1-model_best.pth 44 | pose_only.pth 45 | ``` 46 | 47 | ### Evaluation Results 48 | Comparative Spearman's Rank Correlation results of QAFE-Net with SOTA AQA methods on PFED5 49 | ![results_pd](resources/pd_results.png) 50 | 51 | ## Citations 52 | If you find our work useful in your research, please consider giving it a star ⭐ and citing our paper in your work: 53 | 54 | ``` 55 | @misc{duan2023qafenet, 56 | title={QAFE-Net: Quality Assessment of Facial Expressions with Landmark Heatmaps}, 57 | author={Shuchao Duan and Amirhossein Dadashzadeh and Alan Whone and Majid Mirmehdi}, 58 | year={2023}, 59 | eprint={2312.00856}, 60 | archivePrefix={arXiv}, 61 | primaryClass={cs.CV} 62 | } 63 | 64 | ``` 65 | 66 | ## Acknowlegement 67 | We would like to gratefully acknowledge the contribution of the Parkinson’s study participants and extend special appreciation to Tom Whone for his additional labelling efforts. The clinical trial from which the video data of the people with Parkinson’s was sourced was funded by Parkinson’s UK (Grant J-1102), with support from Cure Parkinson’s. 68 | 69 | Our implementation and experiments are built on top of [Former-DFER](https://github.com/zengqunzhao/Former-DFER). We thank the authors who made their code public, which tremendously accelerated our project progress. 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /models/T_Former.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from torch import nn, einsum 4 | import math 5 | 6 | 7 | class GELU(nn.Module): 8 | def forward(self, x): 9 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 10 | 11 | 12 | class Residual(nn.Module): 13 | def __init__(self, fn): 14 | super().__init__() 15 | self.fn = fn 16 | 17 | def forward(self, x, **kwargs): 18 | return self.fn(x, **kwargs) + x 19 | 20 | 21 | class PreNorm(nn.Module): 22 | def __init__(self, dim, fn): 23 | super().__init__() 24 | self.norm = nn.LayerNorm(dim) 25 | self.fn = fn 26 | 27 | def forward(self, x, **kwargs): 28 | return self.fn(self.norm(x), **kwargs) 29 | 30 | 31 | class FeedForward(nn.Module): 32 | def __init__(self, dim, hidden_dim, dropout=0.): 33 | super().__init__() 34 | self.net = nn.Sequential(nn.Linear(dim, hidden_dim), 35 | GELU(), 36 | nn.Dropout(dropout), 37 | nn.Linear(hidden_dim, dim), 38 | nn.Dropout(dropout)) 39 | 40 | def forward(self, x): 41 | return self.net(x) 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 46 | super().__init__() 47 | inner_dim = dim_head * heads 48 | project_out = not (heads == 1 and dim_head == dim) 49 | self.heads = heads 50 | self.scale = dim_head ** -0.5 51 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 52 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), 53 | nn.Dropout(dropout)) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | b, n, _, h = *x.shape, self.heads 57 | qkv = self.to_qkv(x).chunk(3, dim=-1) 58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) 59 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 60 | attn = dots.softmax(dim=-1) 61 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 62 | out = rearrange(out, 'b h n d -> b n (h d)') 63 | out = self.to_out(out) 64 | return out 65 | 66 | 67 | class 
Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))), 73 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))])) 74 | 75 | def forward(self, x): 76 | for attn, ff in self.layers: 77 | x = attn(x) 78 | x = ff(x) 79 | return x 80 | 81 | 82 | class TFormer(nn.Module): 83 | def __init__(self, num_patches=16, dim=512, depth=3, heads=8, mlp_dim=1024, dim_head=64, dropout=0.0): 84 | super().__init__() 85 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 86 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim)) 87 | self.temporal_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 88 | 89 | def forward(self, x): 90 | x = x.contiguous().view(-1, 16, 512) 91 | b, n, _ = x.shape 92 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 93 | x = torch.cat((cls_tokens, x), dim=1) 94 | x = x + self.pos_embedding[:, :(n+1)] 95 | x = self.temporal_transformer(x) 96 | x = x[:, 0] 97 | 98 | return x 99 | 100 | 101 | def temporal_transformer(): 102 | return TFormer() 103 | 104 | 105 | if __name__ == '__main__': 106 | img = torch.randn((1, 16, 3, 112, 112)) 107 | model = temporal_transformer() 108 | model(img) 109 | -------------------------------------------------------------------------------- /models/vit_decoder_two.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from timm.models.layers import DropPath, trunc_normal_ 7 | import numpy as np 8 | 9 | 10 | class Mlp(nn.Module): 11 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 12 | super().__init__() 13 | out_features = out_features or in_features 14 | hidden_features = hidden_features or in_features 15 | self.fc1 = nn.Linear(in_features, hidden_features) 16 | self.act = act_layer() 17 | self.fc2 = nn.Linear(hidden_features, out_features) 18 | self.drop = nn.Dropout(drop) 19 | 20 | def forward(self, x): 21 | x = self.fc1(x) 22 | x = self.act(x) 23 | x = self.drop(x) 24 | x = self.fc2(x) 25 | x = self.drop(x) 26 | return x 27 | 28 | 29 | 30 | 31 | class CrossAttention(nn.Module): 32 | def __init__(self, dim, out_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 33 | super().__init__() 34 | self.num_heads = num_heads 35 | self.dim = dim 36 | self.out_dim = out_dim 37 | head_dim = out_dim // num_heads 38 | self.scale = qk_scale or head_dim ** -0.5 39 | 40 | self.q_map = nn.Linear(dim, out_dim, bias=qkv_bias) 41 | self.k_map = nn.Linear(dim, out_dim, bias=qkv_bias) 42 | self.v_map = nn.Linear(dim, out_dim, bias=qkv_bias) 43 | self.attn_drop = nn.Dropout(attn_drop) 44 | 45 | self.proj = nn.Linear(out_dim, out_dim) 46 | self.proj_drop = nn.Dropout(proj_drop) 47 | 48 | def forward(self, q, v): 49 | B, N, _ = q.shape 50 | C = self.out_dim 51 | k = v 52 | NK = k.size(1) 53 | 54 | q = self.q_map(q).view(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 55 | k = self.k_map(k).view(B, NK, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 56 | v = self.v_map(v).view(B, NK, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 57 | 58 | attn = (q @ k.transpose(-2, -1)) * self.scale 59 | attn = attn.softmax(dim=-1) 60 | 
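# the softmax-normalised cross-attention map has shape (B, num_heads, N, NK); dropout is applied to it below before it weights the value tokens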
attn = self.attn_drop(attn) 61 | 62 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 63 | x = self.proj(x) 64 | x = self.proj_drop(x) 65 | return x 66 | 67 | 68 | 69 | class DecoderBlock(nn.Module): 70 | def __init__(self, dim, num_heads, dim_q=None, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 71 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 72 | super().__init__() 73 | dim_q = dim_q or dim 74 | self.img_dim = 5 75 | self.norm_q = norm_layer(dim_q) 76 | self.norm_v = norm_layer(dim) 77 | self.attn = CrossAttention( 78 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 79 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 80 | self.norm2 = norm_layer(dim) 81 | mlp_hidden_dim = int(dim * mlp_ratio) 82 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 83 | 84 | def forward(self, x): 85 | x_1 = x[:, :self.img_dim, :] 86 | x_2 = x[:, self.img_dim:, :] 87 | 88 | x_frame = x_1 + self.drop_path(self.attn(self.norm_q(x_1), self.norm_v(x_2))) 89 | x_frame = x_frame + self.drop_path(self.mlp(self.norm2(x_frame))) 90 | 91 | x_heatmap = x_2 + self.drop_path(self.attn(self.norm_q(x_2), self.norm_v(x_1))) 92 | x_heatmap = x_heatmap + self.drop_path(self.mlp(self.norm2(x_heatmap))) 93 | x = torch.cat((x_frame, x_heatmap), dim=1) 94 | 95 | return x 96 | 97 | 98 | class decoder_fuser(nn.Module): 99 | def __init__(self, dim, num_heads, num_layers, drop_rate): 100 | super(decoder_fuser, self).__init__() 101 | model_list = [] 102 | for i in range(num_layers): 103 | model_list.append(DecoderBlock(dim, num_heads)) 104 | self.model = nn.ModuleList(model_list) 105 | self.pos_embed = nn.Parameter(torch.randn(1, 5, dim)) 106 | self.pos_drop = nn.Dropout(drop_rate) 107 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 108 | self.norm = norm_layer(dim) 109 | self.conv = nn.Conv1d(5*2, 5*2, 1) 110 | 111 | def forward(self, x_frame, x_heatmap): 112 | x_frame = self.pos_drop(x_frame+self.pos_embed) 113 | x_comb = torch.cat((x_frame, x_heatmap), dim=1) 114 | 115 | for _layer in self.model: 116 | x_comb = _layer(x_comb) 117 | x_comb = self.conv(x_comb) 118 | 119 | return x_comb 120 | 121 | if __name__=='__main__': 122 | x = torch.randn((2,5, 512)) 123 | model = decoder_fuser(dim=512, num_heads=8, num_layers=3, drop_rate=0.) 
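# quick shape check (sketch): the fuser cross-attends two (B, 5, 512) token sets (here the same tensor stands in for the RGB-clip and heatmap-clip features) and returns their concatenation of shape (B, 10, 512)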
124 | y = model(x,x) 125 | print(y) 126 | -------------------------------------------------------------------------------- /dataloader/ibmse/preprocess_gmm_PD.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import DataLoader 3 | import pandas as pd 4 | import argparse 5 | import os 6 | import time 7 | import joblib 8 | import torch 9 | from sklearn.mixture import GaussianMixture 10 | from tqdm import tqdm 11 | 12 | from dataloader.Parkinson_multi import train_data_loader 13 | from main_PD_expression import parser 14 | 15 | # 16 | # parser = argparse.ArgumentParser(description='') 17 | # # Default args 18 | # # training/optimization related 19 | # parser.add_argument('--dataset', type=str, default='imdb_wiki', choices=['imdb_wiki', 'agedb'], help='dataset name') 20 | # parser.add_argument('--data_dir', type=str, default='./data', help='data directory') 21 | # parser.add_argument('--model', type=str, default='resnet50', help='model name') 22 | # parser.add_argument('--store_root', type=str, default='checkpoint', help='root path for storing checkpoints, logs') 23 | # parser.add_argument('--store_name', type=str, default='', help='experiment store name') 24 | # parser.add_argument('--gpu', type=int, default=None) 25 | # parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='optimizer type') 26 | # parser.add_argument('--loss', type=str, default='l1', choices=['mse', 'l1', 'focal_l1', 'focal_mse', 'huber'], help='training loss type') 27 | # parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate') 28 | # parser.add_argument('--epoch', type=int, default=90, help='number of epochs to train') 29 | # parser.add_argument('--momentum', type=float, default=0.9, help='optimizer momentum') 30 | # parser.add_argument('--weight_decay', type=float, default=1e-4, help='optimizer weight decay') 31 | # parser.add_argument('--schedule', type=int, nargs='*', default=[60, 80], help='lr schedule (when to drop lr by 10x)') 32 | # parser.add_argument('--batch_size', type=int, default=256, help='batch size') 33 | # parser.add_argument('--print_freq', type=int, default=10, help='logging frequency') 34 | # parser.add_argument('--img_size', type=int, default=224, help='image size used in training') 35 | # parser.add_argument('--workers', type=int, default=32, help='number of workers used in data loading') 36 | # 37 | # parser.add_argument('--reweight', type=str, default='none', choices=['none', 'inverse', 'sqrt_inv'], 38 | # help='cost-sensitive reweighting scheme') 39 | # # LDS 40 | # parser.add_argument('--lds', action='store_true', default=False, help='whether to enable LDS') 41 | # parser.add_argument('--lds_kernel', type=str, default='gaussian', 42 | # choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type') 43 | # parser.add_argument('--lds_ks', type=int, default=5, help='LDS kernel size: should be odd number') 44 | # parser.add_argument('--lds_sigma', type=float, default=1, help='LDS gaussian/laplace kernel sigma') 45 | 46 | # Args for GMM 47 | parser.add_argument('--K', type=int, default=2, help='GMM number of components')# follow IMDB-WIKI_DIR whose test sei is imbalanced 48 | 49 | 50 | def prerpocess_gmm(): 51 | args = parser.parse_args() 52 | # Data 53 | end_time = time.time() 54 | print('Getting Train Loader...') 55 | train_data = train_data_loader(args) 56 | 57 | train_loader = torch.utils.data.DataLoader(train_data, 58 | batch_size=args.batch_size, 59 | shuffle=True, 
60 | num_workers=args.workers, 61 | pin_memory=True) 62 | # df = pd.read_csv(os.path.join(args.data_dir, f"{args.dataset}.csv")) 63 | # df_train = df[df['split'] == 'train'] 64 | # train_dataset = IMDBWIKI(data_dir=args.data_dir, df=df_train, img_size=args.img_size, split='train', 65 | # reweight=args.reweight, lds=args.lds, lds_kernel=args.lds_kernel, lds_ks=args.lds_ks, lds_sigma=args.lds_sigma) 66 | # train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 67 | # num_workers=args.workers, pin_memory=True, drop_last=False) 68 | print(time.time() - end_time) 69 | end_time = time.time() 70 | print('Training Loader Done.') 71 | print('Curate training labels...') 72 | all_labels = [] 73 | for _, data in tqdm(enumerate(train_loader)): 74 | targets = data['final_score'] 75 | all_labels.append(targets) 76 | all_labels = torch.cat(all_labels).reshape(1, -1) 77 | print('All labels shape: ', all_labels.shape) 78 | print(time.time() - end_time) 79 | end_time = time.time() 80 | print('Training labels curated') 81 | print('Fitting GMM...') 82 | gmm = GaussianMixture(n_components=args.K, random_state=0, verbose=2).fit( 83 | all_labels.reshape(-1, 1).cpu().numpy()) 84 | print(time.time() - end_time) 85 | end_time = time.time() 86 | print('GMM fiited') 87 | print("Dumping...") 88 | gmm_dict = {} 89 | gmm_dict['means'] = gmm.means_ 90 | gmm_dict['weights'] = gmm.weights_ 91 | gmm_dict['variances'] = gmm.covariances_ 92 | gmm_path = './gmm_PD_{}.pkl'.format(args.class_idx) 93 | joblib.dump(gmm_dict, gmm_path) 94 | print(time.time() - end_time) 95 | print('Dumped at {}'.format(gmm_path)) 96 | 97 | 98 | if __name__ == '__main__': 99 | prerpocess_gmm() -------------------------------------------------------------------------------- /dataset/PAIN2landmarks_insightface.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | import glob 4 | import os 5 | import sys 6 | import urllib.request 7 | import urllib.error 8 | import time 9 | 10 | import insightface 11 | import matplotlib.pyplot as plt 12 | from insightface.app import FaceAnalysis 13 | from insightface.data import get_image as ins_get_image 14 | import albumentations as A 15 | import cv2 16 | import dlib 17 | import numpy as np 18 | from PIL.Image import Image 19 | from tqdm import tqdm 20 | 21 | from opencv_zoo.models.face_detection_yunet import detect 22 | import numpy as np 23 | 24 | class Logger(object): 25 | def __init__(self, filename): 26 | self.terminal = sys.stdout 27 | self.log = open(filename, 'a') 28 | def write(self, message): 29 | self.terminal.write(message) 30 | self.log.write(message) 31 | def flush(self): 32 | pass 33 | 34 | type = sys.getfilesystemencoding() 35 | sys.stdout = Logger('relative_distance_over_first frame.txt') 36 | app = FaceAnalysis(allowed_modules=['detection', 'landmark_2d_106']) 37 | app.prepare(ctx_id=0, det_size=(640, 640)) 38 | 39 | transform = A.Compose([A.Resize(256, 256) 40 | ], keypoint_params=A.KeypointParams(format='xy')) 41 | 42 | # ids = ['042-ll042', '049-bm049', '066-mg066 ', '096-bg096', '106-nm106', 43 | # ' 115-jy115 ', '124-dn124', '043-jh043', ' 052-dr052', '080-bn080', 44 | # ' 097-gf097' , '107-hs107 ', '120-kz120', '047-jl047', '059-fn059 ', 45 | # ' 092-ch092', '101-mg101 ', '108-th108', '121-vw121','048-aa048 ' , 46 | # ' 064-ak064 ', '095-tv095', '103-jk103', '109-ib109', '123-jh123' 47 | # ] 48 | ids = [ '095-tv095' 49 | ] 50 | 51 | root = r"/path/dataset/PAIN/Images" 52 | crop_root = 
r"/path/dataset/PAIN/Images_crop" 53 | visual_root = r"/path/dataset/PAIN/Images_landmarks_clear" 54 | feature_root = r"/path/dataset/PAIN/Images_landmarks_feature" 55 | color = (0 ,255, 0) 56 | 57 | for id in tqdm(ids): 58 | summary_info = [] 59 | vid_folder = os.path.join(root, id) 60 | for dirpath, dirnames, filenames in os.walk(vid_folder): 61 | for vid in tqdm(dirnames): 62 | vid_path = os.path.join(dirpath, vid) 63 | img_list = glob.glob(os.path.join(vid_path, '*.png')) 64 | crop_out_dir = os.path.join(crop_root, id, vid) 65 | visual_out_dir = os.path.join(visual_root, id, vid) 66 | feature_save_file = '{}-{}.npy'.format(vid, id) 67 | feature_out_dir = os.path.join(feature_root, id) 68 | 69 | full_feature = [] 70 | if not os.path.exists(feature_out_dir): 71 | os.makedirs(feature_out_dir) 72 | if not os.path.exists(crop_out_dir): 73 | os.makedirs(crop_out_dir) 74 | if not os.path.exists(visual_out_dir): 75 | os.makedirs(visual_out_dir) 76 | 77 | for img_path in sorted(img_list): 78 | visual_save_path = os.path.join(visual_out_dir, img_path.split('/')[-1]) 79 | crop_save_path = os.path.join(crop_out_dir, img_path.split('/')[-1]) 80 | sub_feature = [] 81 | 82 | # #insightface 83 | img = cv2.imread(img_path) 84 | faces = app.get(img) 85 | if faces is not None: 86 | for i, face in enumerate(faces): 87 | if i >0: 88 | continue 89 | lmk = face.landmark_2d_106 90 | fb = np.round(face.bbox).astype(np.int32) 91 | top, left, bottom, right = fb[0],fb[1],fb[2],fb[3] 92 | if top < 0 and left >= 0: 93 | img_crop, top = img[left:right, 0:bottom], 0 94 | elif top >= 0 and left < 0: 95 | img_crop, left = img[0:right, top:bottom], 0 96 | elif top < 0 and left < 0: 97 | img_crop, top, left = img[0:right, 0:bottom], 0 , 0 98 | else: 99 | img_crop = img[left:right, top:bottom] 100 | 101 | lmk = np.round(lmk).astype(np.int32) 102 | 103 | new_lmk = [] # landmarks on cropped face 104 | for i in range(33, lmk.shape[0]): # 106 landmarks remove the contour part 105 | new_coor = np.round((lmk[i] - [top, left])) 106 | if new_coor[0] < 0: 107 | new_coor[0] = 0 108 | if new_coor[1] < 0: 109 | new_coor[1] = 0 110 | if new_coor[0] >= img_crop.shape[1]: 111 | new_coor[0] = img_crop.shape[1] - 1 112 | if new_coor[1] >= img_crop.shape[0]: 113 | new_coor[1] = img_crop.shape[0] - 1 114 | new_lmk.append(tuple(new_coor)) 115 | try: 116 | transformed = transform(image=img_crop, keypoints=new_lmk) # resize to (256,256) 117 | transformed_image = transformed['image'] 118 | transformed_keypoints = transformed['keypoints'] 119 | transformed_keypoints = np.round(transformed_keypoints).astype(np.int32) 120 | except: 121 | print(img_path,fb) 122 | continue 123 | img_crop = cv2.resize(img_crop, (256, 256)) 124 | 125 | for i in range(0, len(new_lmk)): # 73 landmarks 126 | p = tuple(transformed_keypoints[i]) 127 | cv2.circle(transformed_image, p, 1, color, 2) 128 | 129 | # Convert the facial landmarks into a feature vector 130 | sub_feature.append(np.around(np.array([transformed_keypoints[i][0], transformed_keypoints[i][1]]), 2)) 131 | # feature_vector = np.array([shape.part(i).x, shape.part(i).y for i in range(68)]) 132 | # if len(faces)==0 and len(full_feature)>0: 133 | # sub_feature = full_feature[-1] 134 | cv2.imwrite(visual_save_path, transformed_image) 135 | cv2.imwrite(crop_save_path, img_crop) 136 | full_feature.append(sub_feature) 137 | try: 138 | np.save(os.path.join(feature_out_dir,feature_save_file), 139 | full_feature) 140 | except: 141 | print(feature_save_file, len(full_feature), full_feature) 142 | 143 | 
summary_info.append([vid, id, len(img_list)]) 144 | 145 | print(' {} done'.format(id)) 146 | # compute the distance between centroid and facial landmarks 147 | print('done') -------------------------------------------------------------------------------- /dataloader/Parkinson_landmarkheatmap.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | 5 | import cv2 6 | import scipy 7 | import torch 8 | import pandas as pd 9 | from PIL import Image 10 | 11 | from scipy import stats 12 | import numpy as np 13 | import glob 14 | import torchvision 15 | from torchvideotransforms import video_transforms, volume_transforms 16 | 17 | 18 | class Parkinson_former(torch.utils.data.Dataset): 19 | def __init__(self, args, subset): 20 | 21 | self.subset = subset 22 | # self.transform = transform 23 | 24 | self.class_idx = args.class_idx # sport class index(from 0 begin) 25 | 26 | self.args = args 27 | self.clip_len = args.clip_len 28 | self.data_root = args.data_root 29 | self.landmarks_root = args.landmarks_root 30 | self.split_path = os.path.join(self.data_root, 'PFED5_train.csv') 31 | self.split_data = pd.read_csv(self.split_path) 32 | self.split = np.array(self.split_data) 33 | self.split = self.split[self.split[:, 3] == self.class_idx].tolist() # stored nums are in str 34 | 35 | 36 | if self.subset == 'test': 37 | self.split_path_test = os.path.join(self.data_root, 'PFED5_test.csv') 38 | self.split_test = pd.read_csv(self.split_path_test) 39 | self.split_test = np.array(self.split_test) 40 | self.split_test = self.split_test[self.split_test[:, 3] == self.class_idx].tolist() 41 | 42 | 43 | if self.subset == 'test': 44 | self.dataset = self.split_test.copy() 45 | else: # sample 5 clips with different start frame for each video (Augmentation) 46 | self.dataset = self.split.copy()*5 47 | 48 | def __getitem__(self, index): 49 | sample_1 = self.dataset[index] 50 | id_v, label, cls, num_frame = sample_1[0], sample_1[1], int(sample_1[3]), int(sample_1[2]) 51 | patient_id, video_id = id_v.split('_')[-1], id_v[:11] 52 | id_path = os.path.join(self.data_root, patient_id, video_id) 53 | 54 | data = {} 55 | data['video'], frame_index = self.load_video(id_path, num_frame) 56 | if self.args.use_landmark: # only load the features corresponding to the sampled frames 57 | landmarks_path = os.path.join(self.landmarks_root, patient_id, video_id) 58 | data['landmark_heatmap'] = self.load_video(landmarks_path, num_frame, frame_index) 59 | data['final_score'] = label 60 | # data['video_id'] = id_path 61 | data['class'] = cls 62 | return data 63 | 64 | def __len__(self): 65 | return len(self.dataset) 66 | 67 | def load_short_clips(self, video_frames_list, clip_len, num_frames): 68 | video_clip = [] 69 | idx = 0 70 | start_frame = 1 71 | sample_rate = 1 72 | frame_index = [] 73 | for i in range(clip_len): 74 | cur_img_index = start_frame + idx * sample_rate 75 | # cur_img_path = os.path.join( 76 | # video_dir, 77 | # "img_" + "{:05}.jpg".format(start_frame + idx * sample_rate)) 78 | # print(cur_img_path) 79 | # img = cv2.imread(cur_img_path) 80 | # video_clip.append(img) 81 | frame_index.append(cur_img_index) 82 | if (start_frame + (idx + 1) * sample_rate) > num_frames: 83 | start_frame = 1 84 | idx = 0 85 | else: 86 | idx += 1 87 | imgs = [Image.open(video_frames_list[i-1]).convert('RGB') for i in frame_index] 88 | video_clip.extend(imgs) 89 | return video_clip, frame_index 90 | 91 | def load_long_clips(self, video_frames_list, clip_len, 
num_frames): 92 | video_clip = [] 93 | 94 | if self.subset == 'train': 95 | start_frame = random.randint(1, num_frames - clip_len) 96 | frame_index = [i for i in range(start_frame, start_frame + clip_len)] 97 | # print(num_frames, 'index:', frame_index) 98 | imgs = [Image.open(video_frames_list[i-1]).convert('RGB') for i in 99 | range(start_frame, start_frame + clip_len)] 100 | elif self.subset == 'test': # sample evenly spaced frames across the sequence for inference 101 | frame_partition = np.linspace(0, num_frames - 1, num=clip_len, dtype=np.int32) 102 | frame_index =[i+1 for i in frame_partition] 103 | # print(num_frames, 'index:', frame_index) 104 | imgs = [Image.open(video_frames_list[i]).convert('RGB') for i in frame_partition] 105 | else: 106 | assert f"subset must be train or test" 107 | video_clip.extend(imgs) 108 | return video_clip, frame_index 109 | 110 | def load_video(self, path, num_frame, frame_index=None): 111 | video_frames_list = sorted((glob.glob(os.path.join(path, '*.jpg')))) 112 | assert video_frames_list != None, f"check the video dir" 113 | assert len( 114 | video_frames_list) == num_frame, f"the number of imgs:{len(video_frames_list)} in {path} must be equal to num_frames:{num_frame}" 115 | if frame_index is None: 116 | # if clip length <= input length 117 | if len(video_frames_list) <= self.clip_len: 118 | video, frame_index = self.load_short_clips(video_frames_list, self.clip_len, num_frame) 119 | else: 120 | video, frame_index = self.load_long_clips(video_frames_list, self.clip_len, num_frame) 121 | return self.transform(video), frame_index 122 | else: 123 | video = [] 124 | imgs = [Image.open(video_frames_list[i-1]).convert('RGB') for i in frame_index] 125 | video.extend(imgs) 126 | return self.transform(video, use_landmark=True) 127 | def transform(self, video, use_landmark=False): 128 | trans = [] 129 | if use_landmark: 130 | if self.subset == 'train': 131 | trans = video_transforms.Compose([ 132 | video_transforms.RandomHorizontalFlip(), 133 | video_transforms.Resize((64, 64)), 134 | video_transforms.RandomCrop(56), 135 | volume_transforms.ClipToTensor(), 136 | ]) 137 | elif self.subset == 'test': 138 | trans = video_transforms.Compose([ 139 | video_transforms.Resize((64, 64)), 140 | video_transforms.CenterCrop(56), 141 | volume_transforms.ClipToTensor(), 142 | ]) 143 | else: 144 | if self.subset == 'train': 145 | trans = video_transforms.Compose([ 146 | video_transforms.RandomHorizontalFlip(), 147 | # video_transforms.Resize((256, 256)), 148 | video_transforms.RandomCrop(224), 149 | volume_transforms.ClipToTensor(), 150 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 151 | ]) 152 | elif self.subset == 'test': 153 | trans = video_transforms.Compose([ 154 | # video_transforms.Resize((256, 256)), 155 | video_transforms.CenterCrop(224), 156 | volume_transforms.ClipToTensor(), 157 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 158 | ]) 159 | return trans(video) 160 | 161 | def train_data_loader(args): 162 | train_data = Parkinson_former(args,subset='train') 163 | return train_data 164 | 165 | def test_data_loader(args): 166 | test_data = Parkinson_former(args,subset='test') 167 | return test_data 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /models/S_Former.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from torch 
import nn, einsum 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | class GELU(nn.Module): 9 | def forward(self, x): 10 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 11 | 12 | 13 | class Residual(nn.Module): 14 | def __init__(self, fn): 15 | super().__init__() 16 | self.fn = fn 17 | 18 | def forward(self, x, **kwargs): 19 | return self.fn(x, **kwargs) + x 20 | 21 | 22 | class PreNorm(nn.Module): 23 | def __init__(self, dim, fn): 24 | super().__init__() 25 | self.norm = nn.LayerNorm(dim) 26 | self.fn = fn 27 | 28 | def forward(self, x, **kwargs): 29 | return self.fn(self.norm(x), **kwargs) 30 | 31 | 32 | class FeedForward(nn.Module): 33 | def __init__(self, dim, hidden_dim, dropout=0.): 34 | super().__init__() 35 | self.net = nn.Sequential( 36 | nn.Linear(dim, hidden_dim), 37 | GELU(), 38 | nn.Dropout(dropout), 39 | nn.Linear(hidden_dim, dim), 40 | nn.Dropout(dropout) 41 | ) 42 | 43 | def forward(self, x): 44 | return self.net(x) 45 | 46 | 47 | class Attention(nn.Module): 48 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 49 | super().__init__() 50 | inner_dim = dim_head * heads 51 | project_out = not (heads == 1 and dim_head == dim) 52 | 53 | self.heads = heads 54 | self.scale = dim_head ** -0.5 55 | 56 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 57 | 58 | self.to_out = nn.Sequential( 59 | nn.Linear(inner_dim, dim), 60 | nn.Dropout(dropout) 61 | ) if project_out else nn.Identity() 62 | 63 | def forward(self, x, mask=None): 64 | b, n, _, h = *x.shape, self.heads 65 | qkv = self.to_qkv(x).chunk(3, dim=-1) 66 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) 67 | 68 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 69 | mask_value = -torch.finfo(dots.dtype).max 70 | 71 | if mask is not None: 72 | mask = F.pad(mask.flatten(1), (1, 0), value=True) 73 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 74 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j') 75 | dots.masked_fill_(~mask, mask_value) 76 | del mask 77 | 78 | attn = dots.softmax(dim=-1) 79 | 80 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 81 | out = rearrange(out, 'b h n d -> b n (h d)') 82 | out = self.to_out(out) 83 | return out 84 | 85 | 86 | class Transformer(nn.Module): 87 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): 88 | super().__init__() 89 | self.layers = nn.ModuleList([]) 90 | for _ in range(depth): 91 | self.layers.append(nn.ModuleList([ 92 | Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))), 93 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))) 94 | ])) 95 | 96 | def forward(self, x, mask=None): 97 | for attn, ff in self.layers: 98 | x = attn(x, mask=mask) 99 | x = ff(x) 100 | return x 101 | 102 | 103 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 104 | """3x3 convolution with padding""" 105 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 106 | padding=dilation, groups=groups, bias=False, dilation=dilation) 107 | 108 | 109 | def conv1x1(in_planes, out_planes, stride=1): 110 | """1x1 convolution""" 111 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 112 | 113 | 114 | class BasicBlock(nn.Module): 115 | expansion = 1 116 | 117 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 118 | base_width=64, dilation=1, norm_layer=None): 119 | 
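# torchvision-style residual block: two 3x3 conv + BN layers with ReLU, and an optional downsample branch on the identity path when the stride or channel count changes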
super(BasicBlock, self).__init__() 120 | if norm_layer is None: 121 | norm_layer = nn.BatchNorm2d 122 | if groups != 1 or base_width != 64: 123 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 124 | if dilation > 1: 125 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 126 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 127 | self.conv1 = conv3x3(inplanes, planes, stride) 128 | self.bn1 = norm_layer(planes) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.conv2 = conv3x3(planes, planes) 131 | self.bn2 = norm_layer(planes) 132 | self.downsample = downsample 133 | self.stride = stride 134 | 135 | def forward(self, x): 136 | identity = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | 145 | if self.downsample is not None: 146 | identity = self.downsample(x) 147 | 148 | out += identity 149 | out = self.relu(out) 150 | 151 | return out 152 | 153 | 154 | class ResNet(nn.Module): # S-Former after stage3 155 | 156 | def __init__(self, block, layers, zero_init_residual=False, 157 | groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, 158 | num_patches=14*14, dim=256, depth=1, heads=8, mlp_dim=512, dim_head=32, dropout=0.0): 159 | super(ResNet, self).__init__() 160 | if norm_layer is None: 161 | norm_layer = nn.BatchNorm2d 162 | self._norm_layer = norm_layer 163 | 164 | self.inplanes = 64 165 | self.dilation = 1 166 | if replace_stride_with_dilation is None: 167 | replace_stride_with_dilation = [False, False, False] 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError("replace_stride_with_dilation should be None or" 170 | " a 3-element tuple, got {}".format(replace_stride_with_dilation)) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 174 | self.bn1 = norm_layer(self.inplanes) 175 | self.relu = nn.ReLU(inplace=True) 176 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 177 | self.layer1 = self._make_layer(block, 64, layers[0]) 178 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 179 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 180 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 182 | 183 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) 184 | self.spatial_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 185 | 186 | for m in self.modules(): 187 | if isinstance(m, nn.Conv2d): 188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 189 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 190 | nn.init.constant_(m.weight, 1) 191 | nn.init.constant_(m.bias, 0) 192 | 193 | if zero_init_residual: 194 | for m in self.modules(): 195 | nn.init.constant_(m.bn2.weight, 0) 196 | 197 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 198 | norm_layer = self._norm_layer 199 | downsample = None 200 | previous_dilation = self.dilation 201 | if dilate: 202 | self.dilation *= stride 203 | stride = 1 204 | if stride != 1 or self.inplanes != planes * block.expansion: 205 | downsample = nn.Sequential( 206 | conv1x1(self.inplanes, planes * 
block.expansion, stride), 207 | norm_layer(planes * block.expansion)) 208 | layers = [] 209 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 210 | self.base_width, previous_dilation, norm_layer)) 211 | self.inplanes = planes * block.expansion 212 | for _ in range(1, blocks): 213 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, 214 | dilation=self.dilation, norm_layer=norm_layer)) 215 | 216 | return nn.Sequential(*layers) 217 | 218 | def forward(self, x): 219 | 220 | x = x.contiguous().view(-1, 3, 224, 224) 221 | 222 | x = self.conv1(x) 223 | x = self.bn1(x) 224 | x = self.relu(x) 225 | x = self.maxpool(x) # torch.Size([1, 64, 28, 28]) 226 | x = self.layer1(x) # torch.Size([1, 64, 28, 28]) 227 | x = self.layer2(x) # torch.Size([1, 128, 14, 14]) 228 | x = self.layer3(x) # torch.Size([1, 256, 7, 7]) 229 | b_l, c, h, w = x.shape 230 | x = x.reshape((b_l, c, h*w)) 231 | x = x.permute(0, 2, 1) 232 | b, n, _ = x.shape 233 | x = x + self.pos_embedding[:, :n] 234 | x = self.spatial_transformer(x) 235 | x = x.permute(0, 2, 1) 236 | x = x.reshape((b, c, h, w)) 237 | x = self.layer4(x) # torch.Size([1, 512, 4, 4]) 238 | x = self.avgpool(x) 239 | x = torch.flatten(x, 1) 240 | 241 | return x 242 | 243 | 244 | def spatial_transformer(): 245 | return ResNet(BasicBlock, [2, 2, 2, 2]) 246 | 247 | 248 | if __name__ == '__main__': 249 | img = torch.randn((1, 16, 3, 112, 112)) 250 | model = spatial_transformer() 251 | model(img) 252 | -------------------------------------------------------------------------------- /dataloader/video_transform.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 47 | """ 48 | def __init__(self, is_flow=False): 49 | self.is_flow = is_flow 50 | 51 | def __call__(self, img_group, is_flow=False): 52 | v = random.random() 53 | if v < 0.5: 54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 55 | if self.is_flow: 56 | for i in range(0, len(ret), 2): 57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 58 | return ret 59 | else: 60 | return img_group 61 | 62 | 63 | class GroupNormalize(object): 64 | def __init__(self, mean, std): 65 | self.mean = mean 66 | self.std = std 67 | 68 | def __call__(self, tensor): 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 70 | rep_std = self.std * 
(tensor.size()[0]//len(self.std)) 71 | 72 | # TODO: make efficient 73 | for t, m, s in zip(tensor, rep_mean, rep_std): 74 | t.sub_(m).div_(s) 75 | 76 | return tensor 77 | 78 | 79 | class GroupScale(object): 80 | """ Rescales the input PIL.Image to the given 'size'. 81 | 'size' will be the size of the smaller edge. 82 | For example, if height > width, then image will be 83 | rescaled to (size * height / width, size) 84 | size: size of the smaller edge 85 | interpolation: Default: PIL.Image.BILINEAR 86 | """ 87 | 88 | def __init__(self, size, interpolation=Image.BILINEAR): 89 | self.worker = torchvision.transforms.Resize(size, interpolation) 90 | 91 | def __call__(self, img_group): 92 | return [self.worker(img) for img in img_group] 93 | 94 | 95 | class GroupOverSample(object): 96 | def __init__(self, crop_size, scale_size=None): 97 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 98 | 99 | if scale_size is not None: 100 | self.scale_worker = GroupScale(scale_size) 101 | else: 102 | self.scale_worker = None 103 | 104 | def __call__(self, img_group): 105 | 106 | if self.scale_worker is not None: 107 | img_group = self.scale_worker(img_group) 108 | 109 | image_w, image_h = img_group[0].size 110 | crop_w, crop_h = self.crop_size 111 | 112 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 113 | oversample_group = list() 114 | for o_w, o_h in offsets: 115 | normal_group = list() 116 | flip_group = list() 117 | for i, img in enumerate(img_group): 118 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 119 | normal_group.append(crop) 120 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 121 | 122 | if img.mode == 'L' and i % 2 == 0: 123 | flip_group.append(ImageOps.invert(flip_crop)) 124 | else: 125 | flip_group.append(flip_crop) 126 | 127 | oversample_group.extend(normal_group) 128 | oversample_group.extend(flip_group) 129 | return oversample_group 130 | 131 | 132 | class GroupMultiScaleCrop(object): 133 | 134 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 135 | self.scales = scales if scales is not None else [1, .875, .75, .66] 136 | self.max_distort = max_distort 137 | self.fix_crop = fix_crop 138 | self.more_fix_crop = more_fix_crop 139 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 140 | self.interpolation = Image.BILINEAR 141 | 142 | def __call__(self, img_group): 143 | 144 | im_size = img_group[0].size 145 | 146 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 147 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 148 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 149 | for img in crop_img_group] 150 | return ret_img_group 151 | 152 | def _sample_crop_size(self, im_size): 153 | image_w, image_h = im_size[0], im_size[1] 154 | 155 | # find a crop size 156 | base_size = min(image_w, image_h) 157 | crop_sizes = [int(base_size * x) for x in self.scales] 158 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 159 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 160 | 161 | pairs = [] 162 | for i, h in enumerate(crop_h): 163 | for j, w in enumerate(crop_w): 164 | if abs(i - j) <= self.max_distort: 165 | pairs.append((w, h)) 166 | 167 | crop_pair = random.choice(pairs) 168 | if not self.fix_crop: 
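# free cropping: sample the crop offset uniformly at random; otherwise pick one of the predefined grid offsets via _sample_fix_offset below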
169 | w_offset = random.randint(0, image_w - crop_pair[0]) 170 | h_offset = random.randint(0, image_h - crop_pair[1]) 171 | else: 172 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 173 | 174 | return crop_pair[0], crop_pair[1], w_offset, h_offset 175 | 176 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 177 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 178 | return random.choice(offsets) 179 | 180 | @staticmethod 181 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 182 | w_step = (image_w - crop_w) / 4 183 | h_step = (image_h - crop_h) / 4 184 | 185 | ret = list() 186 | ret.append((0, 0)) # upper left 187 | ret.append((4 * w_step, 0)) # upper right 188 | ret.append((0, 4 * h_step)) # lower left 189 | ret.append((4 * w_step, 4 * h_step)) # lower right 190 | ret.append((2 * w_step, 2 * h_step)) # center 191 | 192 | if more_fix_crop: 193 | ret.append((0, 2 * h_step)) # center left 194 | ret.append((4 * w_step, 2 * h_step)) # center right 195 | ret.append((2 * w_step, 4 * h_step)) # lower center 196 | ret.append((2 * w_step, 0 * h_step)) # upper center 197 | 198 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 199 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 200 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 201 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 202 | 203 | return ret 204 | 205 | 206 | class GroupResize(object): 207 | 208 | def __init__(self, size, interpolation=Image.BILINEAR): 209 | self.size = size 210 | self.interpolation = interpolation 211 | 212 | def __call__(self, img_group): 213 | out_group = list() 214 | for img in img_group: 215 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 216 | return out_group 217 | 218 | 219 | class GroupRandomSizedCrop(object): 220 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 221 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 222 | This is popularly used to train the Inception networks 223 | size: size of the smaller edge 224 | interpolation: Default: PIL.Image.BILINEAR 225 | """ 226 | def __init__(self, size, interpolation=Image.BILINEAR): 227 | self.size = size 228 | self.interpolation = interpolation 229 | 230 | def __call__(self, img_group): 231 | for attempt in range(10): 232 | area = img_group[0].size[0] * img_group[0].size[1] 233 | target_area = random.uniform(0.08, 1.0) * area 234 | aspect_ratio = random.uniform(3. / 4, 4.
/ 3) 235 | 236 | w = int(round(math.sqrt(target_area * aspect_ratio))) 237 | h = int(round(math.sqrt(target_area / aspect_ratio))) 238 | 239 | if random.random() < 0.5: 240 | w, h = h, w 241 | 242 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 243 | x1 = random.randint(0, img_group[0].size[0] - w) 244 | y1 = random.randint(0, img_group[0].size[1] - h) 245 | found = True 246 | break 247 | else: 248 | found = False 249 | x1 = 0 250 | y1 = 0 251 | 252 | if found: 253 | out_group = list() 254 | for img in img_group: 255 | img = img.crop((x1, y1, x1 + w, y1 + h)) 256 | assert(img.size == (w, h)) 257 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 258 | return out_group 259 | else: 260 | # Fallback 261 | scale = GroupScale(self.size, interpolation=self.interpolation) 262 | crop = GroupRandomCrop(self.size) 263 | return crop(scale(img_group)) 264 | 265 | 266 | class Stack(object): 267 | 268 | def __init__(self, roll=False): 269 | self.roll = roll 270 | 271 | def __call__(self, img_group): 272 | if img_group[0].mode == 'L' or img_group[0].mode == 'F': 273 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 274 | elif img_group[0].mode == 'RGB': 275 | if self.roll: 276 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 277 | else: 278 | return np.concatenate(img_group, axis=2) 279 | 280 | 281 | class ToTorchFormatTensor(object): 282 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 283 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 284 | def __init__(self, div=True): 285 | self.div = div 286 | 287 | def __call__(self, pic): 288 | if isinstance(pic, np.ndarray): 289 | # handle numpy array 290 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 291 | else: 292 | # handle PIL Image 293 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 294 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 295 | # put it from HWC to CHW format 296 | # yikes, this transpose takes 80% of the loading time/CPU 297 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 298 | return img.to(torch.float32).div(255) if self.div else img.to(torch.float32) 299 | 300 | 301 | class IdentityTransform(object): 302 | 303 | def __call__(self, data): 304 | return data 305 | -------------------------------------------------------------------------------- /main_PFED5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import shutil 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torch.utils.data 13 | import torch.utils.data.distributed 14 | from einops import rearrange 15 | from scipy import stats 16 | import matplotlib.pyplot as plt 17 | from tqdm import tqdm 18 | 19 | from dataloader.Parkinson_landmarkheatmap import train_data_loader, test_data_loader 20 | from models.slowonly import ResNet3dSlowOnly 21 | from models.ST_Former import GenerateModel 22 | from dataloader.ibmse.balancedMSE import GAILoss, BMCLoss, FocalRLoss 23 | import numpy as np 24 | import datetime 25 | 26 | from models.evaluator import Evaluator 27 | from models.vit_decoder_two import decoder_fuser 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--class_idx',type=int,default=0,choices=[0, 1, 2, 3, 4], help='class 
idx in PD-5') 32 | parser.add_argument('--clip_len',type=int,default=80, help='input length') 33 | parser.add_argument('--data_root',type=str,default='/path/PFED5/frames', help='data path') 34 | parser.add_argument('--landmarks_root',type=str,default='/path/PFED5/landmarks_heatmap', help='landmarks path') 35 | parser.add_argument('--exp_name',type=str,help='path to save tensorboard curve',default='test') 36 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers') 37 | parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run') 38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') 39 | parser.add_argument('-b', '--batch_size', default=1, type=int, metavar='N') 40 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', dest='lr') 41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M') 42 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', dest='weight_decay') 43 | parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency') 44 | parser.add_argument('--resume', default=None, type=str, metavar='PATH', help='path to latest checkpoint') 45 | parser.add_argument('--gpu', type=str, default='0') 46 | parser.add_argument('--noise_sigma', type=float, default=1.) 47 | parser.add_argument('--seed', type=int, default=0, help='random seed')  # consumed by init_seed(args) below 48 | 49 | 50 | 51 | args = parser.parse_args() 52 | now = datetime.datetime.now() 53 | time_str = now.strftime("[%m-%d]-[%H:%M]-") 54 | output_log_root = './log/' +'{}/'.format(args.exp_name) 55 | if not os.path.exists(output_log_root): 56 | os.makedirs(output_log_root) 57 | output_ckpt_root = './checkpoint/' +'{}/'.format(args.exp_name) 58 | if not os.path.exists(output_ckpt_root): 59 | os.makedirs(output_ckpt_root) 60 | log_txt_path = os.path.join(output_log_root, 'class_{}-log.txt'.format(args.class_idx)) 61 | log_curve_path = os.path.join(output_log_root, 'class_{}-loss.png'.format(args.class_idx)) 62 | checkpoint_path = os.path.join(output_ckpt_root, 'class_{}-model.pth'.format(args.class_idx)) 63 | best_checkpoint_path = os.path.join(output_ckpt_root, 'class_{}-model_best.pth'.format(args.class_idx)) 64 | model_pretrained_path = './models/FormerDFER-DFEWset1-model_best.pth' 65 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 66 | 67 | def init_seed(args): 68 | """ 69 | Set random seed for torch and numpy.
70 | """ 71 | random.seed(args.seed) 72 | np.random.seed(args.seed) 73 | torch.manual_seed(args.seed) 74 | torch.cuda.manual_seed_all(args.seed) 75 | torch.backends.cudnn.deterministic = True 76 | torch.backends.cudnn.benchmark = False 77 | 78 | init_seed(args) 79 | 80 | def main(): 81 | # print configuration 82 | print('=' * 40) 83 | for k, v in vars(args).items(): 84 | print(f'{k}: {v}') 85 | print('=' * 40) 86 | 87 | best_acc = 0 88 | best_epoch = 0 89 | recorder = RecorderMeter(args.epochs) 90 | print('The training time: ' + now.strftime("%m-%d %H:%M")) 91 | 92 | # create model and load pre_trained parameters 93 | model = GenerateModel().cuda() 94 | SlowOnly = ResNet3dSlowOnly().cuda() 95 | evaluator = Evaluator(output_dim=1, model_type='MLP').cuda() 96 | decoder = decoder_fuser(dim=512, num_heads=8, num_layers=3, drop_rate=0.).cuda() 97 | 98 | from collections import OrderedDict 99 | 100 | new_state_dict = OrderedDict() 101 | pre_trained_dict = torch.load(model_pretrained_path) 102 | for k, v in pre_trained_dict['state_dict'].items(): 103 | name = k[7:] 104 | if 't_former.spatial_transformer' in name: 105 | new_k = name.replace('t_former.spatial_transformer', 't_former.temporal_transformer') 106 | new_state_dict[new_k] = v 107 | else: 108 | new_state_dict[name] = v 109 | 110 | new_state_dict.pop('fc.weight') 111 | new_state_dict.pop('fc.bias') 112 | model.load_state_dict(new_state_dict) 113 | SlowOnly.load_weights() ## (S) 114 | 115 | 116 | if len(args.gpu.split(',')) > 1: 117 | model = nn.DataParallel(model) 118 | evaluator = nn.DataParallel(evaluator) 119 | SlowOnly = nn.DataParallel(SlowOnly) 120 | decoder = nn.DataParallel(decoder) 121 | 122 | # define loss function (criterion) and optimizer 123 | criterion = BMCLoss(init_noise_sigma=args.noise_sigma) 124 | 125 | optimizer = torch.optim.SGD([{'params': model.parameters()}, 126 | {'params': evaluator.parameters()}, 127 | {'params': decoder.parameters()}, 128 | {'params': SlowOnly.parameters()}], 129 | args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 130 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) 131 | 132 | # optionally resume from a checkpoint 133 | if args.resume: 134 | if os.path.isfile(args.resume): 135 | print("=> loading checkpoint '{}'".format(args.resume)) 136 | checkpoint = torch.load(args.resume) 137 | args.start_epoch = checkpoint['epoch'] 138 | best_acc = checkpoint['best_acc'] 139 | recorder = checkpoint['recorder'] 140 | best_acc = best_acc.cuda() 141 | SlowOnly.load_state_dict(checkpoint['SlowOnly']) 142 | decoder.load_state_dict(checkpoint['decoder']) 143 | model.load_state_dict(checkpoint['model']) 144 | evaluator.load_state_dict(checkpoint['evaluator']) 145 | optimizer.load_state_dict(checkpoint['optimizer']) 146 | 147 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 148 | else: 149 | print("=> no checkpoint found at '{}'".format(args.resume)) 150 | cudnn.benchmark = True 151 | 152 | # Data loading code 153 | train_data = train_data_loader(args) 154 | test_data = test_data_loader(args) 155 | 156 | train_loader = torch.utils.data.DataLoader(train_data, 157 | batch_size=args.batch_size, 158 | shuffle=True, 159 | num_workers=args.workers, 160 | pin_memory=True) 161 | val_loader = torch.utils.data.DataLoader(test_data, 162 | batch_size=args.batch_size, 163 | shuffle=False, 164 | num_workers=args.workers, 165 | pin_memory=True) 166 | 167 | for epoch in tqdm(range(args.start_epoch, args.epochs)): 168 | inf = 
'********************' + str(epoch) + '********************' 169 | start_time = time.time() 170 | current_learning_rate = optimizer.state_dict()['param_groups'][0]['lr'] 171 | 172 | with open(log_txt_path, 'a') as f: 173 | f.write(inf + '\n') 174 | f.write('Current learning rate: ' + str(current_learning_rate) + '\n') 175 | 176 | print(inf) 177 | print('Current learning rate: ', current_learning_rate) 178 | 179 | # train for one epoch 180 | train_acc, train_los = train(train_loader, model, evaluator, SlowOnly, decoder, criterion, optimizer, epoch, args) 181 | # evaluate on validation set 182 | val_acc, val_los = validate(val_loader, model, evaluator, SlowOnly, decoder, criterion, args) 183 | 184 | scheduler.step() 185 | 186 | # remember best acc and save checkpoint 187 | is_best = val_acc > best_acc 188 | best_acc = max(val_acc, best_acc) 189 | if is_best: 190 | best_epoch = epoch 191 | save_checkpoint({'epoch': epoch + 1, 192 | 'model': model.state_dict(), 193 | 'evaluator': evaluator.state_dict(), 194 | 'SlowOnly': SlowOnly.state_dict(), 195 | 'decoder': decoder.state_dict(), 196 | 'best_acc': best_acc, 197 | 'optimizer': optimizer.state_dict(), 198 | 'recorder': recorder}, is_best) 199 | 200 | # print and save log 201 | epoch_time = time.time() - start_time 202 | 203 | 204 | recorder.update(epoch, train_los, val_los) 205 | recorder.plot_curve(log_curve_path) 206 | 207 | print('The best rho: {:.5f} in epoch {}'.format(best_acc, best_epoch)) 208 | print('An epoch time: {:.1f}s'.format(epoch_time)) 209 | with open(log_txt_path, 'a') as f: 210 | f.write('The best rho: {:.5f}' + str(best_acc) + 'in {}'.format(best_epoch) + '\n') 211 | f.write('An epoch time: {:.1f}s' + str(epoch_time) + '\n') 212 | 213 | 214 | def train(train_loader, model, evaluator, SlowOnly, decoder, criterion, optimizer, epoch, args): 215 | losses = AverageMeter('Loss', ':.4f') 216 | progress = ProgressMeter(len(train_loader), 217 | [losses], 218 | prefix="Epoch: [{}]".format(epoch)) 219 | 220 | # switch to train mode 221 | model.train() 222 | SlowOnly.train() 223 | decoder.train() 224 | evaluator.train() 225 | true_scores = [] 226 | pred_scores = [] 227 | 228 | for idx, data in enumerate(train_loader): 229 | 230 | true_scores.extend(data['final_score'].numpy()) 231 | videos = data['video'].cuda() 232 | heatmaps = data['landmark_heatmap'].cuda() 233 | b = data['final_score'].unsqueeze_(1).type(torch.FloatTensor).cuda() 234 | 235 | # compute output 236 | data_pack = torch.cat( 237 | [videos[:, :, i:i + 16] for i in range(0, args.clip_len, 16)]).cuda() # 5xN, c, 16, h, w 238 | outputs_v = model(data_pack).reshape(5, len(videos), 512).transpose(0, 1) # N, 5, featdim 239 | heatmap_pack = torch.cat( 240 | [heatmaps[:, :, i:i + 16] for i in range(0, args.clip_len, 16)]).cuda() # 5xN, c, 16, h, w 241 | outputs_l = SlowOnly(heatmap_pack).reshape(5, len(videos), 512).transpose(0, 1) # [b, 5, 512] 242 | 243 | output_lv_map = decoder(outputs_l, outputs_v) # q, v 244 | probs = evaluator(output_lv_map) 245 | probs = probs.mean(1) 246 | preds=probs 247 | loss = criterion(preds, b) 248 | losses.update(loss.item(), videos.size(0)) 249 | 250 | pred_scores.extend([i.item() for i in preds]) 251 | 252 | # compute gradient and do SGD step 253 | optimizer.zero_grad() 254 | loss.backward() 255 | optimizer.step() 256 | 257 | # print loss and accuracy 258 | if idx % args.print_freq == 0 or idx == len(train_loader)-1: 259 | progress.display(idx) 260 | 261 | rho_v, p_v = stats.spearmanr(pred_scores, true_scores) 262 | print('[train] EPOCH: %d, 
correlation_v: %.4f, lr: %.4f' 263 | % (epoch, rho_v, optimizer.param_groups[0]['lr'])) 264 | 265 | if epoch == 2 or epoch == args.epochs-1: 266 | print('pred_v scores', pred_scores) 267 | print('true_scores', true_scores) 268 | 269 | return rho_v, losses.avg 270 | 271 | 272 | def validate(val_loader, model, evaluator, SlowOnly, decoder, criterion, args): 273 | losses = AverageMeter('Loss', ':.4f') 274 | progress = ProgressMeter(len(val_loader), 275 | [losses], 276 | prefix='Test: ') 277 | 278 | # switch to evaluate mode 279 | model.eval() 280 | evaluator.eval() 281 | SlowOnly.eval() 282 | decoder.eval() 283 | 284 | true_scores = [] 285 | pred_scores = [] 286 | 287 | with torch.no_grad(): 288 | for idx, data in enumerate(val_loader): 289 | true_scores.extend(data['final_score'].numpy()) 290 | videos = data['video'].cuda() 291 | heatmaps = data['landmark_heatmap'].cuda() 292 | b = data['final_score'].unsqueeze_(1).type(torch.FloatTensor).cuda() 293 | 294 | # compute output 295 | data_pack = torch.cat( 296 | [videos[:, :, i:i + 16] for i in range(0, args.clip_len, 16)]).cuda() # 5xN, c, 16, h, w 297 | outputs_v = model(data_pack).reshape(5, len(videos), 512).transpose(0, 1) # N, 5, featdim 298 | heatmap_pack = torch.cat( 299 | [heatmaps[:, :, i:i + 16] for i in range(0, args.clip_len, 16)]).cuda() # 5xN, c, 16, h, w 300 | outputs_l = SlowOnly(heatmap_pack).reshape(5, len(videos), 512).transpose(0, 1) # [b, 5, 512] 301 | 302 | output_lv_map = decoder(outputs_l, outputs_v) # q, v 303 | probs = evaluator(output_lv_map) 304 | probs = probs.mean(1) 305 | preds = probs 306 | loss = criterion(preds, b) 307 | 308 | pred_scores.extend([i.item() for i in preds]) 309 | 310 | losses.update(loss.item(), videos.size(0)) 311 | 312 | 313 | if idx % args.print_freq == 0 or idx == len(val_loader)-1: 314 | progress.display(idx) 315 | 316 | rho_v, p_v = stats.spearmanr(pred_scores, true_scores) 317 | 318 | # TODO: this should also be done with the ProgressMeter 319 | print('Current Accuracy: {rho_v:.6f}'.format(rho_v=rho_v)) 320 | print('Predicted visual scores: ', pred_scores) 321 | print('True scores: ', true_scores) 322 | with open(log_txt_path, 'a') as f: 323 | f.write('Current Accuracy: {rho_v:.6f}'.format(rho_v=rho_v) + '\n') 324 | return rho_v, losses.avg 325 | 326 | 327 | def save_checkpoint(state, is_best): 328 | torch.save(state, checkpoint_path) 329 | if is_best: 330 | shutil.copyfile(checkpoint_path, best_checkpoint_path) 331 | 332 | 333 | class AverageMeter(object): 334 | """Computes and stores the average and current value""" 335 | def __init__(self, name, fmt=':f'): 336 | self.name = name 337 | self.fmt = fmt 338 | self.reset() 339 | 340 | def reset(self): 341 | self.val = 0 342 | self.avg = 0 343 | self.sum = 0 344 | self.count = 0 345 | 346 | def update(self, val, n=1): 347 | self.val = val 348 | self.sum += val * n 349 | self.count += n 350 | self.avg = self.sum / self.count 351 | 352 | def __str__(self): 353 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 354 | return fmtstr.format(**self.__dict__) 355 | 356 | 357 | class ProgressMeter(object): 358 | def __init__(self, num_batches, meters, prefix=""): 359 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 360 | self.meters = meters 361 | self.prefix = prefix 362 | 363 | def display(self, batch): 364 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 365 | entries += [str(meter) for meter in self.meters] 366 | print_txt = '\t'.join(entries) 367 | print(print_txt) 368 | with open(log_txt_path, 'a') as f: 369 | 
f.write(print_txt + '\n') 370 | 371 | def _get_batch_fmtstr(self, num_batches): 372 | num_digits = len(str(num_batches // 1)) 373 | fmt = '{:' + str(num_digits) + 'd}' 374 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 375 | 376 | 377 | # def accuracy(output, target, topk=(1,)): 378 | # """Computes the accuracy over the k top predictions for the specified values of k""" 379 | # with torch.no_grad(): 380 | # maxk = max(topk) 381 | # batch_size = target.size(0) 382 | # _, pred = output.topk(maxk, 1, True, True) 383 | # pred = pred.t() 384 | # correct = pred.eq(target.view(1, -1).expand_as(pred)) 385 | # res = [] 386 | # for k in topk: 387 | # correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 388 | # res.append(correct_k.mul_(100.0 / batch_size)) 389 | # return res 390 | 391 | 392 | class RecorderMeter(object): 393 | """Computes and stores the minimum loss value and its epoch index""" 394 | def __init__(self, total_epoch): 395 | self.reset(total_epoch) 396 | 397 | def reset(self, total_epoch): 398 | self.total_epoch = total_epoch 399 | self.current_epoch = 0 400 | self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] 401 | 402 | 403 | def update(self, idx, train_los, val_los): 404 | self.epoch_losses[idx, 0] = train_los * 50 405 | self.epoch_losses[idx, 1] = val_los * 50 406 | 407 | self.current_epoch = idx + 1 408 | 409 | def plot_curve(self, save_path): 410 | 411 | title = 'the accuracy/loss curve of train/val' 412 | dpi = 80 413 | width, height = 1600, 800 414 | legend_fontsize = 10 415 | figsize = width / float(dpi), height / float(dpi) 416 | 417 | fig = plt.figure(figsize=figsize) 418 | x_axis = np.array([i for i in range(self.total_epoch)]) # epochs 419 | y_axis = np.zeros(self.total_epoch) 420 | 421 | plt.xlim(0, self.total_epoch) 422 | plt.ylim(0, 100) 423 | interval_y = 5 424 | interval_x = 1 425 | plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x)) 426 | plt.yticks(np.arange(0, 100 + interval_y, interval_y)) 427 | plt.grid() 428 | plt.title(title, fontsize=20) 429 | plt.xlabel('the training epoch', fontsize=16) 430 | plt.ylabel('accuracy', fontsize=16) 431 | 432 | y_axis[:] = self.epoch_losses[:, 0] 433 | plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-loss-x50', lw=2) 434 | plt.legend(loc=4, fontsize=legend_fontsize) 435 | 436 | y_axis[:] = self.epoch_losses[:, 1] 437 | plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-loss-x50', lw=2) 438 | plt.legend(loc=4, fontsize=legend_fontsize) 439 | 440 | if save_path is not None: 441 | fig.savefig(save_path, dpi=dpi, bbox_inches='tight') 442 | plt.close(fig) 443 | 444 | 445 | if __name__ == '__main__': 446 | main() 447 | -------------------------------------------------------------------------------- /models/resnet3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import warnings 4 | from mmcv.cnn import ConvModule, build_activation_layer, constant_init, kaiming_init 5 | from mmcv.runner import _load_checkpoint, load_checkpoint 6 | from mmcv.utils import _BatchNorm 7 | from torch.nn.modules.utils import _ntuple, _triple 8 | 9 | # from ...utils import cache_checkpoint, get_root_logger 10 | # from ..builder import BACKBONES 11 | 12 | 13 | class BasicBlock3d(nn.Module): 14 | """BasicBlock 3d block for ResNet3D. 15 | Args: 16 | inplanes (int): Number of channels for the input in first conv3d layer. 
17 | planes (int): Number of channels produced by some norm/conv3d layers. 18 | stride (tuple): Stride is a two element tuple (temporal, spatial). Default: (1, 1). 19 | downsample (nn.Module | None): Downsample layer. Default: None. 20 | inflate (bool): Whether to inflate kernel. Default: True. 21 | conv_cfg (dict): Config dict for convolution layer. Default: 'dict(type='Conv3d')'. 22 | norm_cfg (dict): Config for norm layers. required keys are 'type'. Default: 'dict(type='BN3d')'. 23 | act_cfg (dict): Config dict for activation layer. Default: 'dict(type='ReLU')'. 24 | """ 25 | expansion = 1 26 | 27 | def __init__(self, 28 | inplanes, 29 | planes, 30 | stride=(1, 1), 31 | downsample=None, 32 | inflate=True, 33 | inflate_style='3x3x3', 34 | conv_cfg=dict(type='Conv3d'), 35 | norm_cfg=dict(type='BN3d'), 36 | act_cfg=dict(type='ReLU')): 37 | super().__init__() 38 | assert inflate_style == '3x3x3' 39 | 40 | self.inplanes = inplanes 41 | self.planes = planes 42 | self.stride = stride 43 | self.inflate = inflate 44 | self.conv_cfg = conv_cfg 45 | self.norm_cfg = norm_cfg 46 | self.act_cfg = act_cfg 47 | 48 | self.conv1 = ConvModule( 49 | inplanes, 50 | planes, 51 | 3 if self.inflate else (1, 3, 3), 52 | stride=(self.stride[0], self.stride[1], self.stride[1]), 53 | padding=1 if self.inflate else (0, 1, 1), 54 | bias=False, 55 | conv_cfg=self.conv_cfg, 56 | norm_cfg=self.norm_cfg, 57 | act_cfg=self.act_cfg) 58 | 59 | self.conv2 = ConvModule( 60 | planes, 61 | planes * self.expansion, 62 | 3 if self.inflate else (1, 3, 3), 63 | stride=1, 64 | padding=1 if self.inflate else (0, 1, 1), 65 | bias=False, 66 | conv_cfg=self.conv_cfg, 67 | norm_cfg=self.norm_cfg, 68 | act_cfg=None) 69 | 70 | self.downsample = downsample 71 | self.relu = build_activation_layer(self.act_cfg) 72 | 73 | def forward(self, x): 74 | """Defines the computation performed at every call.""" 75 | 76 | def _inner_forward(x): 77 | """Forward wrapper for utilizing checkpoint.""" 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.conv2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out = out + identity 87 | return out 88 | 89 | out = _inner_forward(x) 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Bottleneck3d(nn.Module): 96 | """Bottleneck 3d block for ResNet3D. 97 | Args: 98 | inplanes (int): Number of channels for the input in first conv3d layer. 99 | planes (int): Number of channels produced by some norm/conv3d layers. 100 | stride (tuple): Stride is a two element tuple (temporal, spatial). Default: (1, 1). 101 | downsample (nn.Module | None): Downsample layer. Default: None. 102 | inflate (bool): Whether to inflate kernel. Default: True. 103 | inflate_style (str): '3x1x1' or '3x3x3'. which determines the kernel sizes and padding strides 104 | for conv1 and conv2 in each block. Default: '3x1x1'. 105 | conv_cfg (dict): Config dict for convolution layer. Default: 'dict(type='Conv3d')'. 106 | norm_cfg (dict): Config for norm layers. required keys are 'type'. Default: 'dict(type='BN3d')'. 107 | act_cfg (dict): Config dict for activation layer. Default: 'dict(type='ReLU')'. 
108 | """ 109 | expansion = 4 110 | 111 | def __init__(self, 112 | inplanes, 113 | planes, 114 | stride=(1, 1), 115 | downsample=None, 116 | inflate=True, 117 | inflate_style='3x1x1', 118 | conv_cfg=dict(type='Conv3d'), 119 | norm_cfg=dict(type='BN3d'), 120 | act_cfg=dict(type='ReLU')): 121 | super().__init__() 122 | assert inflate_style in ['3x1x1', '3x3x3'] 123 | 124 | self.inplanes = inplanes 125 | self.planes = planes 126 | self.stride = stride 127 | self.inflate = inflate 128 | self.inflate_style = inflate_style 129 | self.norm_cfg = norm_cfg 130 | self.conv_cfg = conv_cfg 131 | self.act_cfg = act_cfg 132 | 133 | mode = 'no_inflate' if not self.inflate else self.inflate_style 134 | conv1_kernel_size = {'no_inflate': 1, '3x1x1': (3, 1, 1), '3x3x3': 1} 135 | conv1_padding = {'no_inflate': 0, '3x1x1': (1, 0, 0), '3x3x3': 0} 136 | conv2_kernel_size = {'no_inflate': (1, 3, 3), '3x1x1': (1, 3, 3), '3x3x3': 3} 137 | conv2_padding = {'no_inflate': (0, 1, 1), '3x1x1': (0, 1, 1), '3x3x3': 1} 138 | 139 | self.conv1 = ConvModule( 140 | inplanes, 141 | planes, 142 | conv1_kernel_size[mode], 143 | stride=1, 144 | padding=conv1_padding[mode], 145 | bias=False, 146 | conv_cfg=self.conv_cfg, 147 | norm_cfg=self.norm_cfg, 148 | act_cfg=self.act_cfg) 149 | 150 | self.conv2 = ConvModule( 151 | planes, 152 | planes, 153 | conv2_kernel_size[mode], 154 | stride=(self.stride[0], self.stride[1], self.stride[1]), 155 | padding=conv2_padding[mode], 156 | bias=False, 157 | conv_cfg=self.conv_cfg, 158 | norm_cfg=self.norm_cfg, 159 | act_cfg=self.act_cfg) 160 | 161 | self.conv3 = ConvModule( 162 | planes, 163 | planes * self.expansion, 164 | 1, 165 | bias=False, 166 | conv_cfg=self.conv_cfg, 167 | norm_cfg=self.norm_cfg, 168 | # No activation in the third ConvModule for bottleneck 169 | act_cfg=None) 170 | 171 | self.downsample = downsample 172 | self.relu = build_activation_layer(self.act_cfg) 173 | 174 | def forward(self, x): 175 | """Defines the computation performed at every call.""" 176 | 177 | def _inner_forward(x): 178 | """Forward wrapper for utilizing checkpoint.""" 179 | identity = x 180 | 181 | out = self.conv1(x) 182 | out = self.conv2(out) 183 | out = self.conv3(out) 184 | 185 | if self.downsample is not None: 186 | identity = self.downsample(x) 187 | 188 | out = out + identity 189 | return out 190 | 191 | out = _inner_forward(x) 192 | out = self.relu(out) 193 | 194 | return out 195 | 196 | 197 | 198 | class ResNet3d(nn.Module): 199 | """ResNet 3d backbone. 200 | Args: 201 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. Default: 50. 202 | pretrained (str | None): Name of pretrained model. 203 | stage_blocks (tuple | None): Set number of stages for each res layer. Default: None. 204 | pretrained2d (bool): Whether to load pretrained 2D model. Default: True. 205 | in_channels (int): Channel num of input features. Default: 3. 206 | base_channels (int): Channel num of stem output features. Default: 64. 207 | out_indices (tuple[int]): Indices of output feature. Default: (3, ). 208 | num_stages (int): Resnet stages. Default: 4. 209 | spatial_strides (tuple[int]): Spatial strides of residual blocks of each stage. Default: (1, 2, 2, 2). 210 | temporal_strides (tuple[int]): Temporal strides of residual blocks of each stage. Default: (1, 1, 1, 1). 211 | conv1_kernel (tuple[int]): Kernel size of the first conv layer. Default: (3, 7, 7). 212 | conv1_stride (tuple[int]): Stride of the first conv layer (temporal, spatial). Default: (1, 2). 
213 | pool1_stride (tuple[int]): Stride of the first pooling layer (temporal, spatial). Default: (1, 2). 214 | advanced (bool): Flag indicating if an advanced design for downsample is adopted. Default: False. 215 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means not freezing any parameters. Default: -1. 216 | inflate (tuple[int]): Inflate Dims of each block. Default: (1, 1, 1, 1). 217 | inflate_style (str): '3x1x1' or '3x3x3'. which determines the kernel sizes and padding strides 218 | for conv1 and conv2 in each block. Default: '3x1x1'. 219 | conv_cfg (dict): Config for conv layers. required keys are 'type'. Default: 'dict(type='Conv3d')'. 220 | norm_cfg (dict): Config for norm layers. required keys are 'type' and 'requires_grad'. 221 | Default: 'dict(type='BN3d', requires_grad=True)'. 222 | act_cfg (dict): Config dict for activation layer. Default: 'dict(type='ReLU', inplace=True)'. 223 | norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze running stats (mean and var). 224 | Default: False. 225 | zero_init_residual (bool): Whether to use zero initialization for residual block. Default: True. 226 | """ 227 | 228 | arch_settings = { 229 | 18: (BasicBlock3d, (2, 2, 2, 2)), 230 | 34: (BasicBlock3d, (3, 4, 6, 3)), 231 | 50: (Bottleneck3d, (3, 4, 6, 3)), 232 | 101: (Bottleneck3d, (3, 4, 23, 3)), 233 | 152: (Bottleneck3d, (3, 8, 36, 3)) 234 | } 235 | 236 | def __init__(self, 237 | depth=50, 238 | pretrained=None, 239 | stage_blocks=None, 240 | pretrained2d=True, 241 | in_channels=3, 242 | num_stages=4, 243 | base_channels=64, 244 | out_indices=(3, ), 245 | spatial_strides=(1, 2, 2, 2), 246 | temporal_strides=(1, 1, 1, 1), 247 | conv1_kernel=(3, 7, 7), 248 | conv1_stride=(1, 2), 249 | pool1_stride=(1, 2), 250 | advanced=False, 251 | frozen_stages=-1, 252 | inflate=(1, 1, 1, 1), 253 | inflate_style='3x1x1', 254 | conv_cfg=dict(type='Conv3d'), 255 | norm_cfg=dict(type='BN3d', requires_grad=True), 256 | act_cfg=dict(type='ReLU', inplace=True), 257 | norm_eval=False, 258 | zero_init_residual=True): 259 | super().__init__() 260 | if depth not in self.arch_settings: 261 | raise KeyError(f'invalid depth {depth} for resnet') 262 | self.depth = depth 263 | self.pretrained = pretrained 264 | self.pretrained2d = pretrained2d 265 | self.in_channels = in_channels 266 | self.base_channels = base_channels 267 | self.num_stages = num_stages 268 | assert 1 <= num_stages <= 4 269 | self.stage_blocks = stage_blocks 270 | self.out_indices = out_indices 271 | assert max(out_indices) < num_stages 272 | self.spatial_strides = spatial_strides 273 | self.temporal_strides = temporal_strides 274 | assert len(spatial_strides) == len(temporal_strides) == num_stages 275 | if self.stage_blocks is not None: 276 | assert len(self.stage_blocks) == num_stages 277 | 278 | self.conv1_kernel = conv1_kernel 279 | self.conv1_stride = conv1_stride 280 | self.pool1_stride = pool1_stride 281 | self.advanced = advanced 282 | self.frozen_stages = frozen_stages 283 | self.stage_inflations = _ntuple(num_stages)(inflate) 284 | self.inflate_style = inflate_style 285 | self.conv_cfg = conv_cfg 286 | self.norm_cfg = norm_cfg 287 | self.act_cfg = act_cfg 288 | self.norm_eval = norm_eval 289 | self.zero_init_residual = zero_init_residual 290 | 291 | self.block, stage_blocks = self.arch_settings[depth] 292 | 293 | if self.stage_blocks is None: 294 | self.stage_blocks = stage_blocks[:num_stages] 295 | 296 | self.inplanes = self.base_channels 297 | 298 | self._make_stem_layer() 299 | self.res_layers = [] 300 | 
# This field can be utilized by ResNet3dPathway, and has not side effect. 301 | lateral_inplanes = getattr(self, 'lateral_inplanes', [0, 0, 0, 0]) 302 | 303 | for i, num_blocks in enumerate(self.stage_blocks): 304 | spatial_stride = spatial_strides[i] 305 | temporal_stride = temporal_strides[i] 306 | planes = self.base_channels * 2**i 307 | res_layer = self.make_res_layer( 308 | self.block, 309 | self.inplanes + lateral_inplanes[i], 310 | planes, 311 | num_blocks, 312 | stride=(temporal_stride, spatial_stride), 313 | norm_cfg=self.norm_cfg, 314 | conv_cfg=self.conv_cfg, 315 | act_cfg=self.act_cfg, 316 | advanced=self.advanced, 317 | inflate=self.stage_inflations[i], 318 | inflate_style=self.inflate_style) 319 | self.inplanes = planes * self.block.expansion 320 | layer_name = f'layer{i + 1}' 321 | self.add_module(layer_name, res_layer) 322 | self.res_layers.append(layer_name) 323 | 324 | self.feat_dim = self.block.expansion * self.base_channels * 2 ** (len(self.stage_blocks) - 1) 325 | 326 | @staticmethod 327 | def make_res_layer(block, 328 | inplanes, 329 | planes, 330 | blocks, 331 | stride=(1, 1), 332 | inflate=1, 333 | inflate_style='3x1x1', 334 | advanced=False, 335 | norm_cfg=None, 336 | act_cfg=None, 337 | conv_cfg=None): 338 | """Build residual layer for ResNet3D. 339 | Args: 340 | block (nn.Module): Residual module to be built. 341 | inplanes (int): Number of channels for the input feature in each block. 342 | planes (int): Number of channels for the output feature in each block. 343 | blocks (int): Number of residual blocks. 344 | stride (tuple[int]): Stride (temporal, spatial) in residual and conv layers. Default: (1, 1). 345 | inflate (int | tuple[int]): Determine whether to inflate for each block. Default: 1. 346 | inflate_style (str): '3x1x1' or '3x3x3'. which determines the kernel sizes and padding strides 347 | for conv1 and conv2 in each block. Default: '3x1x1'. 348 | conv_cfg (dict | None): Config for norm layers. Default: None. 349 | norm_cfg (dict | None): Config for norm layers. Default: None. 350 | act_cfg (dict | None): Config for activate layers. Default: None. 351 | Returns: 352 | nn.Module: A residual layer for the given config. 
353 | """ 354 | inflate = inflate if not isinstance(inflate, int) else (inflate, ) * blocks 355 | assert len(inflate) == blocks 356 | downsample = None 357 | if stride[1] != 1 or inplanes != planes * block.expansion: 358 | if advanced: 359 | conv = ConvModule( 360 | inplanes, 361 | planes * block.expansion, 362 | kernel_size=1, 363 | stride=1, 364 | bias=False, 365 | conv_cfg=conv_cfg, 366 | norm_cfg=norm_cfg, 367 | act_cfg=None) 368 | pool = nn.AvgPool3d( 369 | kernel_size=(stride[0], stride[1], stride[1]), 370 | stride=(stride[0], stride[1], stride[1]), 371 | ceil_mode=True) 372 | downsample = nn.Sequential(conv, pool) 373 | else: 374 | downsample = ConvModule( 375 | inplanes, 376 | planes * block.expansion, 377 | kernel_size=1, 378 | stride=(stride[0], stride[1], stride[1]), 379 | bias=False, 380 | conv_cfg=conv_cfg, 381 | norm_cfg=norm_cfg, 382 | act_cfg=None) 383 | 384 | layers = [] 385 | layers.append( 386 | block( 387 | inplanes, 388 | planes, 389 | stride=stride, 390 | downsample=downsample, 391 | inflate=(inflate[0] == 1), 392 | inflate_style=inflate_style, 393 | norm_cfg=norm_cfg, 394 | conv_cfg=conv_cfg, 395 | act_cfg=act_cfg)) 396 | inplanes = planes * block.expansion 397 | for i in range(1, blocks): 398 | layers.append( 399 | block( 400 | inplanes, 401 | planes, 402 | stride=(1, 1), 403 | inflate=(inflate[i] == 1), 404 | inflate_style=inflate_style, 405 | norm_cfg=norm_cfg, 406 | conv_cfg=conv_cfg, 407 | act_cfg=act_cfg)) 408 | 409 | return nn.Sequential(*layers) 410 | 411 | @staticmethod 412 | def _inflate_conv_params(conv3d, state_dict_2d, module_name_2d, inflated_param_names): 413 | """Inflate a conv module from 2d to 3d. 414 | Args: 415 | conv3d (nn.Module): The destination conv3d module. 416 | state_dict_2d (OrderedDict): The state dict of pretrained 2d model. 417 | module_name_2d (str): The name of corresponding conv module in the 2d model. 418 | inflated_param_names (list[str]): List of parameters that have been inflated. 419 | """ 420 | weight_2d_name = module_name_2d + '.weight' 421 | 422 | conv2d_weight = state_dict_2d[weight_2d_name] 423 | kernel_t = conv3d.weight.data.shape[2] 424 | 425 | new_weight = conv2d_weight.data.unsqueeze(2).expand_as(conv3d.weight) / kernel_t 426 | conv3d.weight.data.copy_(new_weight) 427 | inflated_param_names.append(weight_2d_name) 428 | 429 | if getattr(conv3d, 'bias') is not None: 430 | bias_2d_name = module_name_2d + '.bias' 431 | conv3d.bias.data.copy_(state_dict_2d[bias_2d_name]) 432 | inflated_param_names.append(bias_2d_name) 433 | 434 | @staticmethod 435 | def _inflate_bn_params(bn3d, state_dict_2d, module_name_2d, 436 | inflated_param_names): 437 | """Inflate a norm module from 2d to 3d. 438 | Args: 439 | bn3d (nn.Module): The destination bn3d module. 440 | state_dict_2d (OrderedDict): The state dict of pretrained 2d model. 441 | module_name_2d (str): The name of corresponding bn module in the 2d model. 442 | inflated_param_names (list[str]): List of parameters that have been inflated. 443 | """ 444 | for param_name, param in bn3d.named_parameters(): 445 | param_2d_name = f'{module_name_2d}.{param_name}' 446 | param_2d = state_dict_2d[param_2d_name] 447 | if param.data.shape != param_2d.shape: 448 | warnings.warn(f'The parameter of {module_name_2d} is not loaded due to incompatible shapes. 
') 449 | return 450 | 451 | param.data.copy_(param_2d) 452 | inflated_param_names.append(param_2d_name) 453 | 454 | for param_name, param in bn3d.named_buffers(): 455 | param_2d_name = f'{module_name_2d}.{param_name}' 456 | # some buffers like num_batches_tracked may not exist in old checkpoints 457 | if param_2d_name in state_dict_2d: 458 | param_2d = state_dict_2d[param_2d_name] 459 | param.data.copy_(param_2d) 460 | inflated_param_names.append(param_2d_name) 461 | 462 | @staticmethod 463 | def _inflate_weights(self, logger): 464 | """Inflate the resnet2d parameters to resnet3d. 465 | The differences between resnet3d and resnet2d mainly lie in an extra 466 | axis of conv kernel. To utilize the pretrained parameters in 2d model, 467 | the weight of conv2d models should be inflated to fit in the shapes of 468 | the 3d counterpart. 469 | Args: 470 | logger (logging.Logger): The logger used to print 471 | debugging information. 472 | """ 473 | 474 | state_dict_r2d = _load_checkpoint(self.pretrained) 475 | if 'state_dict' in state_dict_r2d: 476 | state_dict_r2d = state_dict_r2d['state_dict'] 477 | 478 | inflated_param_names = [] 479 | for name, module in self.named_modules(): 480 | if isinstance(module, ConvModule): 481 | # we use a ConvModule to wrap conv+bn+relu layers, thus the name mapping is needed 482 | if 'downsample' in name: 483 | # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0 484 | original_conv_name = name + '.0' 485 | # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1 486 | original_bn_name = name + '.1' 487 | else: 488 | # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n} 489 | original_conv_name = name 490 | # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n} 491 | original_bn_name = name.replace('conv', 'bn') 492 | if original_conv_name + '.weight' not in state_dict_r2d: 493 | logger.warning(f'Module not exist in the state_dict_r2d: {original_conv_name}') 494 | else: 495 | shape_2d = state_dict_r2d[original_conv_name + '.weight'].shape 496 | shape_3d = module.conv.weight.data.shape 497 | if shape_2d != shape_3d[:2] + shape_3d[3:]: 498 | logger.warning(f'Weight shape mismatch for: {original_conv_name}: ' 499 | f'3d weight shape: {shape_3d}; 2d weight shape: {shape_2d}.') 500 | else: 501 | self._inflate_conv_params( 502 | module.conv, state_dict_r2d, original_conv_name, inflated_param_names 503 | ) 504 | 505 | if original_bn_name + '.weight' not in state_dict_r2d: 506 | logger.warning(f'Module not exist in the state_dict_r2d: {original_bn_name}') 507 | else: 508 | self._inflate_bn_params(module.bn, state_dict_r2d, original_bn_name, inflated_param_names) 509 | 510 | # check if any parameters in the 2d checkpoint are not loaded 511 | remaining_names = set(state_dict_r2d.keys()) - set(inflated_param_names) 512 | if remaining_names: 513 | logger.info(f'These parameters in the 2d checkpoint are not loaded: {remaining_names}') 514 | 515 | def inflate_weights(self, logger): 516 | self._inflate_weights(self, logger) 517 | 518 | def _make_stem_layer(self): 519 | """Construct the stem layers consists of a conv+norm+act module and a 520 | pooling layer.""" 521 | self.conv1 = ConvModule( 522 | self.in_channels, 523 | self.base_channels, 524 | kernel_size=self.conv1_kernel, 525 | stride=(self.conv1_stride[0], self.conv1_stride[1], self.conv1_stride[1]), 526 | padding=tuple([(k - 1) // 2 for k in _triple(self.conv1_kernel)]), 527 | bias=False, 528 | conv_cfg=self.conv_cfg, 529 | norm_cfg=self.norm_cfg, 530 | act_cfg=self.act_cfg) 531 | 532 | self.maxpool = nn.MaxPool3d( 533 | 
kernel_size=(1, 3, 3), 534 | stride=(self.pool1_stride[0], self.pool1_stride[1], self.pool1_stride[1]), 535 | padding=(0, 1, 1)) 536 | 537 | def _freeze_stages(self): 538 | """Prevent all the parameters from being optimized before 539 | 'self.frozen_stages'.""" 540 | if self.frozen_stages >= 0: 541 | self.conv1.eval() 542 | for param in self.conv1.parameters(): 543 | param.requires_grad = False 544 | 545 | for i in range(1, self.frozen_stages + 1): 546 | m = getattr(self, f'layer{i}') 547 | m.eval() 548 | for param in m.parameters(): 549 | param.requires_grad = False 550 | 551 | @staticmethod 552 | def _init_weights(self, pretrained=None): 553 | """Initiate the parameters either from existing checkpoint or from 554 | scratch. 555 | Args: 556 | pretrained (str | None): The path of the pretrained weight. Will override the original 'pretrained' if set. 557 | The arg is added to be compatible with mmdet. Default: None. 558 | """ 559 | for m in self.modules(): 560 | if isinstance(m, nn.Conv3d): 561 | kaiming_init(m) 562 | elif isinstance(m, _BatchNorm): 563 | constant_init(m, 1) 564 | 565 | if self.zero_init_residual: 566 | for m in self.modules(): 567 | if isinstance(m, Bottleneck3d): 568 | constant_init(m.conv3.bn, 0) 569 | elif isinstance(m, BasicBlock3d): 570 | constant_init(m.conv2.bn, 0) 571 | 572 | # if pretrained: 573 | # self.pretrained = pretrained 574 | # if isinstance(self.pretrained, str): 575 | # logger = get_root_logger() 576 | # logger.info(f'load model from: {self.pretrained}') 577 | # 578 | # if self.pretrained2d: 579 | # self.inflate_weights(logger) 580 | # else: 581 | # self.pretrained = cache_checkpoint(self.pretrained) 582 | # load_checkpoint(self, self.pretrained, strict=False, logger=logger) 583 | 584 | def init_weights(self, pretrained=None): 585 | self._init_weights(self, pretrained) 586 | 587 | def forward(self, x): 588 | """Defines the computation performed at every call. 589 | Args: 590 | x (torch.Tensor): The input data. 591 | Returns: 592 | torch.Tensor: The feature of the input 593 | samples extracted by the backbone. 594 | """ 595 | x = self.conv1(x) 596 | x = self.maxpool(x) 597 | outs = [] 598 | for i, layer_name in enumerate(self.res_layers): 599 | res_layer = getattr(self, layer_name) 600 | x = res_layer(x) 601 | if i in self.out_indices: 602 | outs.append(x) 603 | if len(outs) == 1: 604 | return outs[0] 605 | 606 | return tuple(outs) 607 | 608 | def train(self, mode=True): 609 | """Set the optimization status when training.""" 610 | super().train(mode) 611 | self._freeze_stages() 612 | if mode and self.norm_eval: 613 | for m in self.modules(): 614 | if isinstance(m, _BatchNorm): 615 | m.eval() --------------------------------------------------------------------------------
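A quick way to sanity-check the ResNet3d backbone above is a smoke test in the style of the if __name__ == '__main__' blocks already used in models/S_Former.py and models/ST_Former.py. The sketch below is illustrative only: it assumes mmcv is installed (for the ConvModule and checkpoint utilities imported at the top of models/resnet3d.py), that it is run from the repository root, and the clip shape (1, 3, 16, 112, 112) is an arbitrary choice, not the shape used to train QAFE-Net.

if __name__ == '__main__':
    import torch
    from models.resnet3d import ResNet3d

    # Default config: depth-50 (Bottleneck3d stages) with out_indices=(3,),
    # so forward() returns a single feature map rather than a tuple.
    backbone = ResNet3d(depth=50, pretrained=None, in_channels=3)
    backbone.init_weights()

    clip = torch.randn(1, 3, 16, 112, 112)  # N, C, T, H, W (illustrative shape)
    feats = backbone(clip)
    # With the default spatial strides (1, 2, 2, 2) and temporal strides (1, 1, 1, 1),
    # this should come out around [1, 2048, 16, 4, 4].
    print(feats.shape)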