├── requirements.txt
├── model
│   ├── compare
│   │   ├── flops.py
│   │   ├── attention_fn.py
│   │   ├── predict.py
│   │   ├── concat.py
│   │   ├── train.py
│   │   └── model.py
│   ├── loss.py
│   ├── model.py
│   ├── utils.py
│   └── AVoiD.py
├── README.md
├── data_processing
│   ├── my_dataset.py
│   ├── preprocess.py
│   └── transfer.py
└── train.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | numpy>=1.20.1
3 | torchaudio>=0.9.0
4 | torchvision>=0.10.0
5 | tqdm>=4.41.1
6 | pytorch_lightning==1.7
7 | 
--------------------------------------------------------------------------------
/model/compare/flops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fvcore.nn import FlopCountAnalysis
3 | 
4 | from model import Attention
5 | 
6 | 
7 | def main():
8 |     # Self-Attention
9 |     a1 = Attention(dim=512, num_heads=1)
10 |     a1.proj = torch.nn.Identity()  # remove Wo
11 | 
12 |     # Multi-Head Attention
13 |     a2 = Attention(dim=512, num_heads=8)
14 | 
15 |     # [batch_size, num_tokens, total_embed_dim]
16 |     t = (torch.rand(32, 1024, 512),)
17 | 
18 |     flops1 = FlopCountAnalysis(a1, t)
19 |     print("Self-Attention FLOPs:", flops1.total())
20 | 
21 |     flops2 = FlopCountAnalysis(a2, t)
22 |     print("Multi-Head Attention FLOPs:", flops2.total())
23 | 
24 | 
25 | if __name__ == '__main__':
26 |     main()
27 | 
28 | 
--------------------------------------------------------------------------------
/model/compare/attention_fn.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def sum_attention(nnet, query, value, mask=None, dropout=None, mode='1D'):
8 |     if mode == '2D':
9 |         batch, dim = query.size(0), query.size(1)
10 |         query = query.permute(0, 2, 3, 1).view(batch, -1, dim)
11 |         value = value.permute(0, 2, 3, 1).view(batch, -1, dim)
12 |         mask = mask.view(batch, 1, -1)
13 | 
14 |     scores = nnet(query).transpose(-2, -1)
15 |     if mask is not None:
16 |         scores.data.masked_fill_(mask.eq(0), -65504.0)
17 | 
18 |     p_attn = F.softmax(scores, dim=-1)
19 |     if dropout is not None:
20 |         p_attn = dropout(p_attn)
21 |     weighted = torch.matmul(p_attn, value)
22 | 
23 |     return weighted, p_attn
24 | 
25 | 
26 | def qkv_attention(query, key, value, mask=None, dropout=None):  # [2,8,197,96]
27 |     d_k = query.size(-1)
28 |     scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(d_k)
29 | 
30 |     p_attn = F.softmax(scores, dim=-1)
31 |     if dropout is not None:
32 |         p_attn = dropout(p_attn)
33 | 
34 |     return torch.matmul(p_attn, value), p_attn
35 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AVoiD-DF
2 | ## This repository contains code and models for our paper
3 | > AVoiD-DF: Audio-Visual Joint Learning for Detecting Deepfake
4 | > > In this paper, we propose an Audio-Visual Joint Learning framework for Detecting Deepfake (AVoiD-DF), which exploits audio-visual inconsistency for multi-modal forgery detection.
5 | 
6 | ### Clone the repo:
7 | ```
8 | git clone https://github.com/SYSU-DISG/AVoiD-DF.git
9 | ```
10 | 
11 | ### Download the datasets from the following link:
12 | ```
13 | https://pan.baidu.com/s/1MckHs-H57jTma5v0o6XYMA
14 | ```
15 | (Sorry, due to protocol restrictions, the datasets and weights are not publicly available now.)
16 | 
17 | We use Python 3.6. Install requirements by running:
18 | ```
19 | pip install -r requirements.txt
20 | ```
21 | 
22 | We provide our own dataloader in 'data_processing/my_dataset.py'; you can modify it to train on your own data.
23 | 
24 | ### Training
25 | 
26 | You can train on your own data with a single command:
27 | ```
28 | python train.py
29 | ```
30 | ### Note
31 | 
32 | Unfortunately, due to protocol restrictions we cannot release the complete source code and models. We have open-sourced some of the modules, and our training code is based on the released code of ViT.
33 | 
34 | ### Acknowledgements
35 | 
36 | Our work is based on the official version of AVoiD-DF, and some of our code refers to ViT (Vision Transformer). Thanks for sharing!
37 | 
38 | ### Citation
39 | 
40 | If you find our repo helpful to your research, please cite our paper. Thanks!
41 | 
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Module
4 | import numpy as np
5 | 
6 | class CroLoss(Module):
7 |     """Frame-wise binary cross-entropy, averaged over the clips in a batch."""
8 |     @staticmethod
9 |     def loss_fn(pred: Tensor, true: Tensor) -> Tensor:
10 |         # binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)], summed over the valid frames
11 |         return -torch.sum(true * torch.log(pred) + (1 - true) * torch.log(1 - pred))
12 |     def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor):
13 |         loss = []
14 |         for i, frame in enumerate(n_frames):
15 |             loss.append(self.loss_fn(pred[i, :, :frame], true[i, :, :frame]))
16 |         return torch.mean(torch.stack(loss))
17 | 
18 | 
19 | class AmmLoss(Module):
20 | 
21 |     def __init__(self, gamma=2, alpha=0.25):
22 |         super(AmmLoss, self).__init__()
23 |         self.gamma = gamma
24 |         self.alpha = alpha
25 |     def forward(self, input, target):
26 |         pt = torch.softmax(input, dim=1)
27 |         p = pt[:, 1]
28 |         loss = -self.alpha * (1 - p) ** self.gamma * (target * torch.log(p)) - (1 - self.alpha) * p ** self.gamma * ((1 - target) * torch.log(1 - p))
29 |         return loss.mean()
30 | 
31 | 
32 | class ConLoss(Module):
33 | 
34 |     def __init__(self, margin: float = 0.99):
35 |         super().__init__()
36 |         self.margin = margin
37 |     def forward(self, pred1: Tensor, pred2: Tensor, labels: Tensor, n_frames: Tensor):
38 |         loss = []
39 |         for i, frame in enumerate(n_frames):
40 |             d = torch.dist(pred1[i, :, :frame], pred2[i, :, :frame], 2)
41 |             if labels[i]:
42 |                 loss.append(d ** 2)
43 |             else:
44 |                 loss.append(torch.clip(self.margin - d, min=0.) 
** 2) 45 | return torch.mean(torch.stack(loss)) 46 | -------------------------------------------------------------------------------- /data_processing/my_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | class MyDataSet(Dataset): 6 | def __init__(self, images_path: list, images_class: list, audio_path: list, audio_class: list, transform=None): 7 | self.images_path = images_path 8 | self.images_class = images_class 9 | self.audio_path = audio_path 10 | self.audio_class = audio_class 11 | self.transform = transform 12 | 13 | def __len__(self): 14 | return len(self.images_path) 15 | 16 | def __getitem__(self, item): 17 | img = Image.open(self.images_path[item]) 18 | audio = Image.open(self.audio_path[item]) 19 | if img.mode != 'RGB': 20 | raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) 21 | if audio.mode != 'RGB': 22 | raise ValueError("image: {} isn't RGB mode.".format(self.audio_path[item])) 23 | 24 | label_images = self.images_class[item] 25 | label_audio = self.audio_class[item] 26 | 27 | if self.transform is not None: 28 | img = self.transform(img) 29 | audio = self.transform(audio) 30 | 31 | return img, label_images, audio, label_audio 32 | 33 | @staticmethod 34 | def collate_fn(batch): 35 | images, labels_images, audio, labels_audio = tuple(zip(*batch)) 36 | 37 | images = torch.stack(images, dim=0) 38 | audio = torch.stack(audio, dim=0) 39 | labels_images = torch.as_tensor(labels_images) 40 | labels_audio = torch.as_tensor(labels_audio) 41 | return images, labels_images, audio, labels_audio -------------------------------------------------------------------------------- /model/compare/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms 7 | import matplotlib.pyplot as plt 8 | 9 | from model import vit_base_patch16_224_in21k as create_model 10 | 11 | 12 | def main(): 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | 15 | data_transform = transforms.Compose( 16 | [transforms.Resize(256), 17 | transforms.CenterCrop(224), 18 | transforms.ToTensor(), 19 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 20 | 21 | # load image 22 | img_path = "" 23 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 24 | img = Image.open(img_path) 25 | plt.imshow(img) 26 | # [N, C, H, W] 27 | img = data_transform(img) 28 | # expand batch dimension 29 | img = torch.unsqueeze(img, dim=0) 30 | 31 | # read class_indict 32 | json_path = '' 33 | assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) 34 | 35 | json_file = open(json_path, "r") 36 | class_indict = json.load(json_file) 37 | 38 | # create model 39 | model = create_model(num_classes=5, has_logits=False).to(device) 40 | # load model weights 41 | model_weight_path = "./weights/model-9.pth" 42 | model.load_state_dict(torch.load(model_weight_path, map_location=device)) 43 | model.eval() 44 | with torch.no_grad(): 45 | # predict class 46 | output = torch.squeeze(model(img.to(device))).cpu() 47 | predict = torch.softmax(output, dim=0) 48 | predict_cla = torch.argmax(predict).numpy() 49 | 50 | print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], 51 | predict[predict_cla].numpy()) 52 | plt.title(print_res) 53 | print(print_res) 54 | 
plt.show()
55 | 
56 | 
57 | if __name__ == '__main__':
58 |     main()
59 | 
--------------------------------------------------------------------------------
/data_processing/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import glob
5 | import re
6 | 
7 | def read_split_data(video_root: str, audio_root: str, val_rate: float = 0.2):
8 |     random.seed(0)
9 | 
10 |     assert os.path.exists(video_root), "dataset root: {} does not exist.".format(video_root)
11 |     assert os.path.exists(audio_root), "dataset root: {} does not exist.".format(audio_root)
12 |     # Get audio and video paths
13 |     video_class = [cla for cla in os.listdir(video_root) if
14 |                    os.path.isdir(os.path.join(video_root, cla))]
15 |     audio_class = [cla for cla in os.listdir(audio_root) if
16 |                    os.path.isdir(os.path.join(audio_root, cla))]
17 |     video_class.sort()
18 |     audio_class.sort()
19 |     class_indices = dict((k, v) for v, k in enumerate(video_class))
20 |     json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
21 | 
22 |     with open('class_indices.json', 'w') as json_file:
23 |         json_file.write(json_str)
24 | 
25 |     train_images_path = []
26 |     train_images_label = []
27 |     train_audio_path = []
28 |     train_audio_label = []
29 |     val_images_path = []
30 |     val_images_label = []
31 |     val_audio_path = []
32 |     val_audio_label = []
33 | 
34 |     supported = [".jpg", ".JPG", ".png", ".PNG"]
35 |     for cla in video_class:
36 |         video_root2 = os.path.join(video_root, cla)
37 |         audio_root2 = os.path.join(audio_root, cla)
38 |         video_cla2_class = [cla2 for cla2 in os.listdir(video_root2) if
39 |                             os.path.isdir(os.path.join(video_root2 + '/', cla2))]
40 |         for cla2 in video_cla2_class:
41 |             video_cla_path = os.path.join(video_root2 + '/', cla2)
42 |             video_images = [os.path.join(video_cla_path + '/', i) for i in os.listdir(video_cla_path)
43 |                             if os.path.splitext(i)[-1] in supported]
44 |             audio_images = glob.glob('{}/{}**'.format(audio_root2, cla2))
45 |             # keep only the spectrogram chunk files; removing items from a list
46 |             # while iterating over it skips elements, so build a new list instead
47 |             audio_images = [i for i in audio_images if 'chunk' in i]
48 |             audio_images = sorted(audio_images, key=lambda i: int(re.findall(r'\d+', i)[-1]))
49 | 
50 |             fl = len(video_images)
51 |             al = len(audio_images)
52 |             if fl > al:
53 |                 video_images = video_images[:al]
54 |             elif fl < al:
55 |                 audio_images = audio_images[:fl]
56 | 
57 |             image_class = class_indices[cla]
58 | 
59 |             val_f_path = random.sample(video_images, k=int(len(video_images) * val_rate))
60 |             val_a_path = []
61 |             for i in val_f_path:
62 |                 index = video_images.index(i)
63 |                 val_a_path.append(audio_images[index])
64 | 
65 |             for img_path in video_images:
66 |                 if img_path in val_f_path:
67 |                     val_images_path.append(img_path)
68 |                     val_images_label.append(image_class)
69 |                 else:
70 |                     train_images_path.append(img_path)
71 |                     train_images_label.append(image_class)
72 | 
73 |             for img_path in audio_images:
74 |                 if img_path in val_a_path:
75 |                     val_audio_path.append(img_path)
76 |                     val_audio_label.append(image_class)
77 |                 else:
78 |                     train_audio_path.append(img_path)
79 |                     train_audio_label.append(image_class)
80 | 
81 |     print("{} images were found in the dataset.".format(
82 |         len(train_images_path) + len(val_images_path) + len(train_audio_path) + len(val_audio_path)))
83 |     print("{} images for training.".format(len(train_images_path) + len(train_audio_path)))
84 |     print("{} images for validation.".format(len(val_images_path) + len(val_audio_path)))
85 |     return train_images_path, train_images_label, train_audio_path, train_audio_label, 
val_images_path, val_images_label, val_audio_path, val_audio_label 86 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from select_backbone import select_resnet 5 | import math 6 | 7 | class arche(nn.Module): 8 | def __init__(self, img_dim, network='resnet50', num_layers_in_fc_layers = 1024, dropout=0.5): 9 | super(arche, self).__init__(); 10 | 11 | self.__nFeatures__ = 24; 12 | self.__nChs__ = 32; 13 | self.__midChs__ = 32; 14 | 15 | self.netcnnaud = nn.Sequential( 16 | nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)), 17 | nn.BatchNorm2d(64), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(kernel_size=(1,1), stride=(1,1)), 20 | 21 | nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)), 22 | nn.BatchNorm2d(192), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)), 25 | 26 | nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)), 27 | nn.BatchNorm2d(384), 28 | nn.ReLU(inplace=True), 29 | 30 | nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)), 31 | nn.BatchNorm2d(256), 32 | nn.ReLU(inplace=True), 33 | 34 | nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)), 35 | nn.BatchNorm2d(256), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), 38 | 39 | nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)), 40 | nn.BatchNorm2d(512), 41 | nn.ReLU(), 42 | ); 43 | 44 | self.netfcaud = nn.Sequential( 45 | nn.Linear(512*21, 4096), 46 | nn.BatchNorm1d(4096), 47 | nn.ReLU(), 48 | nn.Linear(4096, num_layers_in_fc_layers), 49 | ); 50 | 51 | self.netcnnlip, self.param = select_resnet(network, track_running_stats=False); 52 | self.last_duration = int(math.ceil(30 / 4)) 53 | self.last_size = int(math.ceil(img_dim / 32)) 54 | 55 | self.netfclip = nn.Sequential( 56 | nn.Linear(self.param['feature_size']*self.last_size*self.last_size, 4096), 57 | nn.BatchNorm1d(4096), 58 | nn.ReLU(), 59 | nn.Linear(4096, num_layers_in_fc_layers), 60 | ); 61 | 62 | self.final_bn_lip = nn.BatchNorm1d(num_layers_in_fc_layers) 63 | self.final_bn_lip.weight.data.fill_(1) 64 | self.final_bn_lip.bias.data.zero_() 65 | 66 | self.final_fc_lip = nn.Sequential(nn.Dropout(dropout), nn.Linear(num_layers_in_fc_layers, 2)) 67 | self._initialize_weights(self.final_fc_lip) 68 | 69 | self.final_bn_aud = nn.BatchNorm1d(num_layers_in_fc_layers) 70 | self.final_bn_aud.weight.data.fill_(1) 71 | self.final_bn_aud.bias.data.zero_() 72 | 73 | self.final_fc_aud = nn.Sequential(nn.Dropout(dropout), nn.Linear(num_layers_in_fc_layers, 2)) 74 | self._initialize_weights(self.final_fc_aud) 75 | 76 | 77 | self._initialize_weights(self.netcnnaud) 78 | self._initialize_weights(self.netfcaud) 79 | self._initialize_weights(self.netfclip) 80 | 81 | def forward_aud(self, x): 82 | (B, N, N, H, W) = x.shape 83 | x = x.view(B*N, N, H, W) 84 | mid = self.netcnnaud(x); 85 | mid = mid.view((mid.size()[0], -1)); 86 | out = self.netfcaud(mid); 87 | return out; 88 | 89 | def forward_lip(self, x): 90 | (B, N, C, NF, H, W) = x.shape 91 | x = x.view(B*N, C, NF, H, W) 92 | feature = self.netcnnlip(x); 93 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=(1, 1, 1)) 94 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) 95 | feature = feature.view((feature.size()[0], -1)); 96 | out = self.netfclip(feature); 97 | return out; 98 | 
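    # Note on the heads below (descriptive comment, not original to the repo):
    # forward_aud / forward_lip return num_layers_in_fc_layers-dim (default 1024) embeddings;
    # final_classification_aud / final_classification_lip then batch-norm them, apply dropout,
    # and map them to 2 logits per modality (presumably real vs. fake).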
99 | def final_classification_lip(self,feature): 100 | feature = self.final_bn_lip(feature) 101 | output = self.final_fc_lip(feature) 102 | return output 103 | 104 | def final_classification_aud(self,feature): 105 | feature = self.final_bn_aud(feature) 106 | output = self.final_fc_aud(feature) 107 | return output 108 | 109 | def forward_lipfeat(self, x): 110 | mid = self.netcnnlip(x); 111 | out = mid.view((mid.size()[0], -1)); 112 | return out; 113 | 114 | def _initialize_weights(self, module): 115 | for m in module: 116 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.ReLU) or isinstance(m,nn.MaxPool2d) or isinstance(m,nn.Dropout): 120 | pass 121 | else: 122 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 123 | if m.bias is not None: m.bias.data.zero_() -------------------------------------------------------------------------------- /model/compare/concat.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class PatchEmbed(nn.Module): 6 | """ 7 | 2D Image to Patch Embedding 8 | """ 9 | 10 | def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768): 11 | super().__init__() 12 | img_size = (img_size, img_size) 13 | patch_size = (patch_size, patch_size) 14 | self.img_size = img_size 15 | self.patch_size = patch_size 16 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 17 | self.num_patches = self.grid_size[0] * self.grid_size[1] 18 | 19 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 20 | 21 | def forward(self, x, y): 22 | B, C, H, W = x.shape 23 | assert H == self.img_size[0] and W == self.img_size[1], \ 24 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
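        # With the default 224x224 inputs and 16x16 patches, proj yields
        # 14*14 = 196 patch tokens per modality, each embedded into embed_dim=768,
        # so the flattened outputs below have shape [B, 768, 196].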
25 | 26 | x = self.proj(x).flatten(2) # [B,768,196] 27 | y = self.proj(y).flatten(2) # [B,768,196] 28 | return x, y 29 | 30 | class ZSLNet(nn.Module): 31 | def __init__(self, device='cuda:0', norm_layer=None, embed_dim=768): 32 | super(ZSLNet, self).__init__() 33 | self.device = device 34 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 35 | 36 | def forward(self, x, y): 37 | return self.forward_ranking(x, y) 38 | 39 | def forward_ranking(self, x, y): 40 | loss_cos = torch.zeros(1).to(self.device) # align 41 | images_loss_mapping_consistency = torch.zeros(1).to(self.device) # con 42 | audio_loss_mapping_consistency = torch.zeros(1).to(self.device) # con 43 | 44 | fc_v = nn.Sequential( 45 | nn.Linear(196, 512), 46 | nn.ReLU(), 47 | nn.Linear(512, 256), 48 | nn.ReLU(), 49 | nn.Linear(256, 98) 50 | ) 51 | 52 | fc_a = nn.Sequential( 53 | nn.Linear(196, 512), 54 | nn.ReLU(), 55 | nn.Linear(512, 256), 56 | nn.ReLU(), 57 | nn.Linear(256, 98) 58 | ) 59 | 60 | fc_v = fc_v.to(self.device) 61 | fc_a = fc_a.to(self.device) 62 | 63 | visual_feats = fc_v(x) 64 | audio_feats = fc_a(y) 65 | 66 | if True: 67 | images_mapped_sim = self.sim_score(visual_feats, visual_feats.detach()) 68 | images_orig_sim = self.sim_score(x, x) 69 | images_loss_mapping_consistency = torch.abs(images_orig_sim - images_mapped_sim).mean() 70 | 71 | if True: 72 | audio_mapped_sim = self.sim_score(audio_feats, audio_feats.detach()) 73 | audio_orig_sim = self.sim_score(y, y) 74 | audio_loss_mapping_consistency = torch.abs(audio_orig_sim - audio_mapped_sim).mean() 75 | 76 | 77 | loss_cos = self.CosineLoss(visual_feats, audio_feats) 78 | 79 | a = torch.ones(1).to(self.device) 80 | b = torch.ones(1).to(self.device) 81 | c = torch.ones(1).to(self.device) 82 | 83 | feats = torch.cat((visual_feats, audio_feats), dim=2) # [b,768,98]+[b,768,98]=[b,768,196] 84 | feats = feats.transpose(1, 2) # [b,768,196] 85 | feats = self.norm(feats) 86 | # feats=visual_feats+audio_feats 87 | return loss_cos, feats # [b,c,hw] 88 | 89 | 90 | def sim_score(self, a, b): 91 | a_norm = a / (1e-6 + a.norm(dim=-1)).unsqueeze(2) 92 | b_norm = b / (1e-6 + b.norm(dim=-1)).unsqueeze(2) 93 | score = (torch.matmul(a_norm, b_norm.transpose(1, 2))) 94 | return score 95 | 96 | def CosineLoss(self, t_emb, v_emb): 97 | a_norm = v_emb / (1e-6 + v_emb.norm(dim=-1)).unsqueeze(2) 98 | b_norm = t_emb / (1e-6 + t_emb.norm(dim=-1)).unsqueeze(2) 99 | loss = 1 - torch.mean(torch.diagonal(torch.matmul(a_norm, b_norm.transpose(1, 2)), 0)) 100 | return loss 101 | 102 | 103 | class castModel(nn.Module): 104 | def __init__(self, Embed_layer=PatchEmbed, Cast=ZSLNet): 105 | super(castModel, self).__init__() 106 | self.PatchEmbed = Embed_layer() 107 | self.ZSLNet = Cast() 108 | 109 | def forward(self, x, y): 110 | a, b = self.PatchEmbed(x, y) 111 | a, castLoss = self.ZSLNet(a, b) 112 | return a, castLoss 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import torch 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torchvision import transforms 10 | 11 | from data_processing.my_dataset import MyDataSet 12 | from model.AVoiD import AVoiD_mm 13 | from utils.utils import train_one_epoch, evaluate 14 | from data_processing.preprocess import read_split_data 15 | 16 | 17 | def 
main(args): 18 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 19 | random.seed(0) 20 | 21 | if os.path.exists("./weights") is False: 22 | os.makedirs("./weights") 23 | 24 | tb_writer = SummaryWriter() 25 | 26 | train_images_path, train_images_label, train_audio_path, train_audio_label, val_images_path, val_images_label, val_audio_path, val_audio_label = read_split_data( 27 | args.video_data_path, args.audio_data_path) 28 | 29 | # preprocessing 30 | data_transform = { 31 | "train": transforms.Compose([transforms.Resize([224, 224]), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 34 | "val": transforms.Compose([transforms.Resize([224, 224]), 35 | transforms.ToTensor(), 36 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])} 37 | 38 | train_dataset = MyDataSet(images_path=train_images_path, 39 | images_class=train_images_label, 40 | audio_path=train_audio_path, 41 | audio_class=train_audio_label, 42 | transform=data_transform["train"]) 43 | 44 | val_dataset = MyDataSet(images_path=val_images_path, 45 | images_class=val_images_label, 46 | audio_path=val_audio_path, 47 | audio_class=val_audio_label, 48 | transform=data_transform["val"]) 49 | 50 | batch_size = args.batch_size 51 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 52 | print('Using {} dataloader workers every process'.format(nw)) 53 | train_loader = torch.utils.data.DataLoader(train_dataset, 54 | batch_size=batch_size, 55 | shuffle=True, 56 | pin_memory=True, 57 | num_workers=nw, 58 | collate_fn=train_dataset.collate_fn) 59 | 60 | val_loader = torch.utils.data.DataLoader(val_dataset, 61 | batch_size=batch_size, 62 | shuffle=False, 63 | pin_memory=True, 64 | num_workers=nw, 65 | collate_fn=val_dataset.collate_fn) 66 | 67 | model = AVoiD_mm(args, num_classes=45, has_logits=False).to(device) 68 | 69 | if args.weights != "": 70 | assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights) 71 | weights_dict = torch.load(args.weights, map_location=device) 72 | del_keys = ['head.weight', 'head.bias'] if model.has_logits \ 73 | else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias'] 74 | for k in del_keys: 75 | del weights_dict[k] 76 | print(model.load_state_dict(weights_dict, strict=False)) 77 | 78 | if args.freeze_layers: 79 | for name, para in model.named_parameters(): 80 | if "head" not in name and "pre_logits" not in name: 81 | para.requires_grad_(False) 82 | else: 83 | print("training {}".format(name)) 84 | 85 | pg = [p for p in model.parameters() if p.requires_grad] 86 | optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5) 87 | lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine 88 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) 89 | loss_weight = torch.nn.Parameter(torch.ones(1)).to(device) 90 | 91 | for epoch in range(args.epochs): 92 | # train 93 | train_loss, train_acc = train_one_epoch(args,model=model, 94 | optimizer=optimizer, 95 | data_loader=train_loader, 96 | device=device, 97 | epoch=epoch, 98 | loss_weight=loss_weight) 99 | scheduler.step() 100 | val_loss, val_acc = evaluate(args,model=model, 101 | data_loader=val_loader, 102 | device=device, 103 | epoch=epoch, 104 | loss_weight=loss_weight) 105 | 106 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 107 | 108 | tb_writer.add_scalar(tags[0], train_loss, epoch) 109 | tb_writer.add_scalar(tags[1], 
train_acc, epoch) 110 | tb_writer.add_scalar(tags[2], val_loss, epoch) 111 | tb_writer.add_scalar(tags[3], val_acc, epoch) 112 | tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 113 | 114 | tb_writer.close() 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--num_classes', type=int, default=2) 120 | parser.add_argument('--epochs', type=int, default=20000) 121 | parser.add_argument('--batch-size', type=int, default=512) 122 | parser.add_argument('--lr', type=float, default=0.01) 123 | parser.add_argument('--lrf', type=float, default=0.01) 124 | #parser.add_argument('--bce-only', dest='bce_only', help='train with only binary cross entropy loss', 125 | #action='store_true') 126 | # video path 127 | parser.add_argument('--video_data-path', type=str, 128 | default="./video") 129 | # audio path 130 | parser.add_argument('--audio_data-path', type=str, 131 | default="./audio") 132 | parser.add_argument('--model-name', default='', help='create model name') 133 | parser.add_argument('--weights', type=str, default='', 134 | help='initial weights path') 135 | parser.add_argument('--freeze-layers', type=bool, default=True) 136 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 137 | 138 | opt = parser.parse_args() 139 | main(opt) 140 | -------------------------------------------------------------------------------- /model/compare/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | from tensorboardX import SummaryWriter 9 | from torchvision import transforms 10 | 11 | from my_dataset import MyDataSet 12 | from model import vit_base_patch16_224_in21k as create_model 13 | from cast_model import castModel 14 | from utils import read_split_data, train_one_epoch, evaluate 15 | 16 | 17 | def main(args): 18 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 19 | # random.seed(0) 20 | 21 | if os.path.exists("./weights") is False: 22 | os.makedirs("./weights") 23 | 24 | tb_writer = SummaryWriter() 25 | 26 | train_images_path, train_images_label, train_audio_path, train_audio_label, val_images_path, val_images_label, val_audio_path, val_audio_label = read_split_data( 27 | args.faces_data_path, args.audio_data_path) 28 | 29 | data_transform = { 30 | "train": transforms.Compose([transforms.RandomResizedCrop(224), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 34 | "val": transforms.Compose([transforms.Resize(256), 35 | transforms.CenterCrop(224), 36 | transforms.ToTensor(), 37 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])} 38 | 39 | 40 | train_dataset = MyDataSet(images_path=train_images_path, 41 | images_class=train_images_label, 42 | audio_path=train_audio_path, 43 | audio_class=train_audio_label, 44 | transform=data_transform["train"]) 45 | 46 | val_dataset = MyDataSet(images_path=val_images_path, 47 | images_class=val_images_label, 48 | audio_path=val_audio_path, 49 | audio_class=val_audio_label, 50 | transform=data_transform["val"]) 51 | 52 | batch_size = args.batch_size 53 | nw = 0 # min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 54 | print('Using {} dataloader workers every process'.format(nw)) 55 | train_loader = 
torch.utils.data.DataLoader(train_dataset, 56 | batch_size=batch_size, 57 | shuffle=True, 58 | pin_memory=True, 59 | num_workers=nw, 60 | collate_fn=train_dataset.collate_fn) 61 | 62 | val_loader = torch.utils.data.DataLoader(val_dataset, 63 | batch_size=batch_size, 64 | shuffle=False, 65 | pin_memory=True, 66 | num_workers=nw, 67 | collate_fn=val_dataset.collate_fn) 68 | 69 | 70 | castmodel = castModel().to(device) 71 | model = create_model(num_classes=9, has_logits=False).to(device) 72 | 73 | if args.weights != "": 74 | assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights) 75 | weights_dict = torch.load(args.weights, map_location=device) 76 | del_keys = ['head.weight', 'head.bias'] if model.has_logits \ 77 | else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias'] 78 | for k in del_keys: 79 | del weights_dict[k] 80 | print(model.load_state_dict(weights_dict, strict=False)) 81 | 82 | if args.freeze_layers: 83 | for name, para in model.named_parameters(): 84 | if "head" not in name and "pre_logits" not in name: 85 | para.requires_grad_(False) 86 | else: 87 | print("training {}".format(name)) 88 | 89 | pg = [p for p in model.parameters() if p.requires_grad] 90 | cast_pg = [p for p in castmodel.parameters() if p.requires_grad] 91 | optimizer1 = optim.SGD(cast_pg, lr=args.lr, momentum=0.9, weight_decay=5E-5) 92 | optimizer2 = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5) 93 | lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine 94 | scheduler1 = lr_scheduler.LambdaLR(optimizer1, lr_lambda=lf) 95 | scheduler2 = lr_scheduler.LambdaLR(optimizer2, lr_lambda=lf) 96 | 97 | for epoch in range(args.epochs): 98 | # train 99 | train_loss, train_acc = train_one_epoch(castModel=castmodel, 100 | model=model, 101 | optimizer1=optimizer1, 102 | optimizer2=optimizer2, 103 | data_loader=train_loader, 104 | device=device, 105 | epoch=epoch) 106 | 107 | scheduler1.step() 108 | scheduler2.step() 109 | 110 | # validate 111 | val_loss, val_acc = evaluate(castModel=castmodel, 112 | model=model, 113 | data_loader=val_loader, 114 | device=device, 115 | epoch=epoch) 116 | 117 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 118 | tb_writer.add_scalar(tags[0], train_loss, epoch) 119 | tb_writer.add_scalar(tags[1], train_acc, epoch) 120 | tb_writer.add_scalar(tags[2], val_loss, epoch) 121 | tb_writer.add_scalar(tags[3], val_acc, epoch) 122 | tb_writer.add_scalar(tags[4], optimizer1.param_groups[0]["lr"], epoch) 123 | 124 | tb_writer.close() 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--num_classes', type=int, default=9) 130 | parser.add_argument('--epochs', type=int, default=5) 131 | parser.add_argument('--batch-size', type=int, default=4) 132 | parser.add_argument('--lr', type=float, default=0.001) 133 | parser.add_argument('--lrf', type=float, default=0.01) 134 | parser.add_argument('--bce-only', dest='bce_only', help='train with only binary cross entropy loss', 135 | action='store_true') 136 | 137 | parser.add_argument('--faces_data-path', type=str, 138 | default="./faces") 139 | parser.add_argument('--audio_data-path', type=str, 140 | default="./audio") 141 | parser.add_argument('--model-name', default='', help='create model name') 142 | 143 | parser.add_argument('--weights', type=str, default='', 144 | help='initial weights path') 145 | parser.add_argument('--freeze-layers', type=bool, default=True) 146 
| parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 147 | 148 | opt = parser.parse_args() 149 | 150 | main(opt) 151 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import random 6 | 7 | import torch 8 | from tqdm import tqdm 9 | import glob 10 | import re 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | def read_split_data(faces_root: str, audio_root: str, val_rate: float = 0.2): 16 | random.seed(0) 17 | 18 | assert os.path.exists(faces_root), "dataset root: {} does not exist.".format(faces_root) 19 | assert os.path.exists(audio_root), "dataset root: {} does not exist.".format(audio_root) 20 | 21 | 22 | face_class = [cla for cla in os.listdir(faces_root) if 23 | os.path.isdir(os.path.join(faces_root, cla))] # [fcmr0,fcrh0,fdac1,fdms0,fdrd1] 24 | audio_class = [cla for cla in os.listdir(audio_root) if 25 | os.path.isdir(os.path.join(audio_root, cla))] # [fcmr0,fcrh0,fdac1,fdms0,fdrd1] 26 | 27 | face_class.sort() 28 | audio_class.sort() 29 | 30 | class_indices = dict((k, v) for v, k in enumerate(face_class)) 31 | 32 | json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) 33 | 34 | with open('class_indices.json', 'w') as json_file: 35 | json_file.write(json_str) 36 | 37 | train_images_path = [] 38 | train_images_label = [] 39 | train_audio_path = [] 40 | train_audio_label = [] 41 | val_images_path = [] 42 | val_images_label = [] 43 | val_audio_path = [] 44 | val_audio_label = [] 45 | 46 | supported = [".jpg", ".JPG", ".png", ".PNG"] 47 | for cla in face_class: 48 | faces_root2 = os.path.join(faces_root, cla) # faces/fcmr0 49 | audio_root2 = os.path.join(audio_root, cla) # audio/fcrm0 50 | faces_cla2_class = [cla2 for cla2 in os.listdir(faces_root2) if 51 | os.path.isdir(os.path.join(faces_root2 + '/', cla2))] # [sa1,sa2,...] 
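        # Assumed data layout (inferred from the glob pattern and comments above):
        # faces_root/<speaker>/<utterance>/*.jpg|*.png face crops, paired with
        # audio_root/<speaker>/<utterance>*chunk*.png spectrogram chunks; chunks are
        # sorted numerically and matched to frames by index further down.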
52 |         for cla2 in faces_cla2_class:
53 |             faces_cla_path = os.path.join(faces_root2 + '/', cla2)  # faces/fcmr0/sa1
54 |             faces_images = [os.path.join(faces_cla_path + '/', i) for i in os.listdir(faces_cla_path)
55 |                             if os.path.splitext(i)[-1] in supported]
56 |             audio_images = glob.glob('{}/{}**'.format(audio_root2, cla2))
57 |             # keep only the spectrogram chunk files; removing items from a list
58 |             # while iterating over it skips elements, so build a new list instead
59 |             audio_images = [i for i in audio_images if 'chunk' in i]
60 | 
61 |             audio_images = sorted(audio_images, key=lambda i: int(re.findall(r'\d+', i)[-1]))
62 | 
63 |             fl = len(faces_images)
64 |             al = len(audio_images)
65 |             if fl > al:
66 |                 faces_images = faces_images[:al]
67 |             elif fl < al:
68 |                 audio_images = audio_images[:fl]
69 | 
70 |             image_class = class_indices[cla]
71 | 
72 |             val_f_path = random.sample(faces_images, k=int(len(faces_images) * val_rate))
73 |             # val_a_path = random.sample(audio_images, k=int(len(audio_images) * val_rate))
74 |             val_a_path = []
75 |             for i in val_f_path:
76 |                 index = faces_images.index(i)
77 |                 val_a_path.append(audio_images[index])
78 | 
79 |             for img_path in faces_images:
80 |                 if img_path in val_f_path:
81 |                     val_images_path.append(img_path)
82 |                     val_images_label.append(image_class)
83 |                 else:
84 |                     train_images_path.append(img_path)
85 |                     train_images_label.append(image_class)
86 | 
87 |             for img_path in audio_images:
88 |                 if img_path in val_a_path:
89 |                     val_audio_path.append(img_path)
90 |                     val_audio_label.append(image_class)
91 |                 else:
92 |                     train_audio_path.append(img_path)
93 |                     train_audio_label.append(image_class)
94 | 
95 |     print("{} images were found in the dataset.".format(
96 |         len(train_images_path) + len(val_images_path) + len(train_audio_path) + len(val_audio_path)))
97 |     print("{} images for training.".format(len(train_images_path) + len(train_audio_path)))
98 |     print("{} images for validation.".format(len(val_images_path) + len(val_audio_path)))
99 |     return train_images_path, train_images_label, train_audio_path, train_audio_label, val_images_path, val_images_label, val_audio_path, val_audio_label
100 | 
101 | 
102 | def plot_data_loader_image(data_loader):
103 |     batch_size = data_loader.batch_size
104 |     plot_num = min(batch_size, 4)
105 | 
106 |     json_path = ''
107 |     assert os.path.exists(json_path), json_path + " does not exist." 
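    # json_path is intentionally left empty here; point it at the class_indices.json
    # file written by read_split_data above before calling this helper.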
108 | json_file = open(json_path, 'r') 109 | class_indices = json.load(json_file) 110 | 111 | for data in data_loader: 112 | images, labels = data 113 | for i in range(plot_num): 114 | # [C, H, W] -> [H, W, C] 115 | img = images[i].numpy().transpose(1, 2, 0) 116 | img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 117 | label = labels[i].item() 118 | plt.subplot(1, plot_num, i + 1) 119 | plt.xlabel(class_indices[str(label)]) 120 | plt.xticks([]) 121 | plt.yticks([]) 122 | plt.imshow(img.astype('uint8')) 123 | plt.show() 124 | 125 | 126 | def write_pickle(list_info: list, file_name: str): 127 | with open(file_name, 'wb') as f: 128 | pickle.dump(list_info, f) 129 | 130 | 131 | def read_pickle(file_name: str) -> list: 132 | with open(file_name, 'rb') as f: 133 | info_list = pickle.load(f) 134 | return info_list 135 | 136 | 137 | def train_one_epoch(castModel, model, optimizer1, optimizer2, data_loader, device, epoch): 138 | model.train() 139 | castModel.train() 140 | loss_function = torch.nn.CrossEntropyLoss() 141 | accu_loss = torch.zeros(1).to(device) 142 | accu_num = torch.zeros(1).to(device) 143 | 144 | sample_num = 0 145 | data_loader = tqdm(data_loader) 146 | for step, images_data in enumerate(data_loader): 147 | images, images_labels, audio, audio_labels = images_data 148 | 149 | sample_num = sample_num + images.shape[0] 150 | 151 | for i in range(2): 152 | optimizer1.zero_grad() 153 | castLoss, pred = castModel(images.to(device), audio.to(device)) 154 | castLoss.backward(retain_graph=True) 155 | optimizer1.step() 156 | 157 | 158 | pred1 = model(pred.detach()) 159 | pred_classes = torch.max(pred1, dim=1)[1] 160 | accu_num = accu_num + torch.eq(pred_classes, images_labels.to(device)).sum() # ?????? 161 | 162 | optimizer2.zero_grad() 163 | loss = loss_function(pred1, images_labels.to(device)) 164 | loss.backward() 165 | accu_loss = accu_loss + loss.detach() 166 | optimizer2.step() 167 | 168 | data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 169 | accu_loss.item() / (step + 1), 170 | accu_num.item() / sample_num) 171 | 172 | if not torch.isfinite(loss): 173 | print('WARNING: non-finite loss, ending training ', loss) 174 | sys.exit(1) 175 | 176 | 177 | 178 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 179 | 180 | 181 | @torch.no_grad() 182 | def evaluate(castModel, model, data_loader, device, epoch): 183 | loss_function = torch.nn.CrossEntropyLoss() 184 | 185 | model.eval() 186 | castModel.eval() 187 | 188 | accu_num = torch.zeros(1).to(device) 189 | accu_loss = torch.zeros(1).to(device) 190 | 191 | sample_num = 0 192 | data_loader = tqdm(data_loader) 193 | for step, images_data in enumerate(data_loader): 194 | images, images_labels, audio, audio_labels = images_data 195 | sample_num = sample_num + images.shape[0] 196 | 197 | _, pred = castModel(images.to(device), audio.to(device)) 198 | pred1 = model(pred) 199 | pred_classes = torch.max(pred1, dim=1)[1] 200 | accu_num = accu_num + torch.eq(pred_classes, images_labels.to(device)).sum() 201 | 202 | loss = loss_function(pred1, images_labels.to(device)) 203 | accu_loss = accu_loss + loss 204 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 205 | accu_loss.item() / (step + 1), 206 | accu_num.item() / sample_num) 207 | 208 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 209 | -------------------------------------------------------------------------------- /model/AVoiD.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 | from MMD import MMD, PatchEmbed, Encoder_layer 7 | 8 | class AVoiD(nn.Module): 9 | def __init__(self, args, img_size=224, patch_size=16, in_c=3, num_classes=1000, 10 | embed_dim=768, depth=5, num_heads=12, mlp_ratio=4.0, qkv_bias=True, 11 | qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., 12 | attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, 13 | act_layer=None, device='cuda:0'): 14 | 15 | super(AVoiD, self).__init__() 16 | self.dim = embed_dim 17 | self.num_heads = num_heads 18 | self.mlp_ratio = mlp_ratio 19 | self.qkv_bias = qkv_bias 20 | self.qk_scale = qk_scale 21 | self.drop_ratio = drop_ratio 22 | self.attn_drop_ratio = attn_drop_ratio 23 | self.act_layer = act_layer 24 | self.depth = depth 25 | self.num_classes = num_classes 26 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 27 | self.num_tokens = 2 if distilled else 1 28 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 29 | act_layer = act_layer or nn.GELU 30 | 31 | self.patch_embed_video = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, 32 | embed_dim=embed_dim) 33 | self.patch_embed_audio = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, 34 | embed_dim=embed_dim) 35 | num_patches = self.patch_embed_video.num_patches 36 | 37 | self.cls_token_video = nn.Parameter(torch.zeros(1, 1, 768)) 38 | self.cls_token_audio = nn.Parameter(torch.zeros(1, 1, 768)) 39 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 40 | # position 41 | self.pos_embed_video = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 42 | self.pos_embed_audio = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 43 | self.pos_drop_video = nn.Dropout(p=drop_ratio) 44 | self.pos_drop_audio = nn.Dropout(p=drop_ratio) 45 | # time 46 | self.time_embed_video = nn.Parameter(torch.zeros(1, embed_dim)) 47 | self.time_embed_audio = nn.Parameter(torch.zeros(1, embed_dim)) 48 | self.time_drop_video = nn.Dropout(p=drop_ratio) 49 | self.time_drop_audio = nn.Dropout(p=drop_ratio) 50 | 51 | self.block = nn.ModuleList() 52 | for _ in range(depth - 1): 53 | layer = MMD(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 54 | qk_scale=qk_scale, 55 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, act_layer=act_layer) 56 | self.block.append(copy.deepcopy(layer)) 57 | 58 | self.last_block = Encoder_layer(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 59 | qk_scale=qk_scale, 60 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio) 61 | 62 | # self.part_select = Search() 63 | 64 | self.video_encoder = nn.Sequential(*[ 65 | Encoder_layer(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 66 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio) 67 | for _ in range(6) 68 | ]) 69 | 70 | self.audio_encoder = nn.Sequential(*[ 71 | Encoder_layer(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 72 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio) 73 | for _ in range(6) 74 | ]) 75 | self.norm = norm_layer(embed_dim) 76 | self.av_fc = nn.Linear(embed_dim * 2, embed_dim) 77 | self.fc = 
nn.Linear(embed_dim * 3, embed_dim) 78 | # Select 79 | # self.Select = Select(bs=args.batch_size, device=args.device, embed_dim=embed_dim, 80 | # num_heads=num_heads, 81 | # mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 82 | # drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio) 83 | 84 | # if representation_size and not distilled: 85 | # self.has_logits = True 86 | # self.num_features = representation_size 87 | # self.pre_logits = nn.Sequential(OrderedDict([ 88 | # ("fc", nn.Linear(embed_dim, representation_size)), 89 | # ("act", nn.Tanh()) 90 | # ])) 91 | # else: 92 | # self.has_logits = False 93 | # self.pre_logits = nn.Identity() 94 | 95 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 96 | # self.head_dist = None 97 | # if distilled: 98 | # self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 99 | 100 | # Weight init 101 | nn.init.trunc_normal_(self.pos_embed_video, std=0.02) 102 | nn.init.trunc_normal_(self.pos_embed_audio, std=0.02) 103 | if self.dist_token is not None: 104 | nn.init.trunc_normal_(self.dist_token, std=0.02) 105 | 106 | nn.init.trunc_normal_(self.cls_token_video, std=0.02) 107 | nn.init.trunc_normal_(self.cls_token_audio, std=0.02) 108 | self.apply(_init_vit_weights) 109 | self.device = args.device 110 | 111 | self.w1 = torch.nn.Parameter(torch.ones(1)).to(device) 112 | self.w2 = torch.nn.Parameter(torch.ones(1)).to(device) 113 | self.w3 = torch.nn.Parameter(torch.ones(1)).to(device) 114 | 115 | def forward_features(self, video, audio): 116 | x = self.patch_embed_video(video) 117 | y = self.patch_embed_audio(audio) 118 | weight_list_v = [] 119 | weight_list_a = [] 120 | cls_token_video = self.cls_token_video.expand(x.shape[0], -1, -1) 121 | cls_token_audio = self.cls_token_audio.expand(x.shape[0], -1, -1) 122 | if self.dist_token is None: 123 | x = torch.cat((cls_token_video, x), dim=1) 124 | y = torch.cat((cls_token_audio, y), dim=1) 125 | else: 126 | x = torch.cat((cls_token_video, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 127 | y = torch.cat((cls_token_audio, self.dist_token.expand(y.shape[0], -1, -1), y), dim=1) 128 | # position embed 129 | x = self.pos_drop_video(x + self.pos_embed_video) 130 | y = self.pos_drop_audio(y + self.pos_embed_audio) 131 | # time embed 132 | x = self.time_drop_video(x + self.time_embed_video) 133 | y = self.time_drop_audio(y + self.time_embed_video) 134 | 135 | x = self.video_encoder(x) 136 | y = self.audio_encoder(y) 137 | 138 | Encoder_video = x 139 | Encoder_audio = y 140 | 141 | # cls_v,cls_v 142 | cls_v = x[:, 0, :] 143 | cls_a = y[:, 0, :] 144 | 145 | num_heads = self.Select(cls_v, cls_a) 146 | 147 | block = nn.ModuleList() 148 | for _ in range(self.depth - 1): 149 | layer = MMD(dim=self.embed_dim, num_heads=num_heads, mlp_ratio=self.mlp_ratio, 150 | qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, 151 | drop_ratio=self.drop_ratio, attn_drop_ratio=self.attn_drop_ratio) 152 | block.append(copy.deepcopy(layer)) 153 | 154 | block.to(self.device) 155 | 156 | for b in block: 157 | x, y, w_v, w_a = b((x, y, Encoder_video, Encoder_audio)) # w:[bs,num_heads,hidden_size,hidden_size] 158 | weight_list_v.append(w_v) 159 | weight_list_a.append(w_a) 160 | 161 | xy = self.av_fc(torch.cat((x, y), dim=-1)) 162 | 163 | part_num_va, part_inx_va = self.part_select(weight_list_v, weight_list_a) 164 | part_inx_va = part_inx_va + 1 165 | parts_va = [] 166 | B, num = part_inx_va.shape 167 | for i in range(B): 168 | 
parts_va.append(xy[i, part_inx_va[i, :]]) # hidden_states[i, part_inx[i,:]]:[B,num_heads] 169 | parts_va = torch.stack(parts_va).squeeze(1) 170 | concat_va = torch.cat((xy[:, 0].unsqueeze(1), parts_va), dim=1) 171 | x = self.last_block(concat_va) 172 | fusion_cls = x[:, 0] 173 | last_cls = self.fc(torch.cat((self.w1 * cls_v, self.w2 * fusion_cls, self.w3 * cls_a), -1)) 174 | return last_cls, cls_v, cls_v 175 | 176 | def forward(self, x, y): 177 | x, cls_v, cls_v = self.forward_features(x, y) 178 | feats = x 179 | if self.head_dist is not None: 180 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) 181 | if self.training and not torch.jit.is_scripting(): 182 | # during inference, return the average of both classifier predictions 183 | return x, x_dist 184 | else: 185 | return (x + x_dist) / 2 186 | else: 187 | x = self.head(x) 188 | return x, feats, cls_v, cls_v 189 | 190 | 191 | def _init_vit_weights(m): 192 | """ 193 | ViT weight initialization 194 | :param m: module 195 | """ 196 | if isinstance(m, nn.Linear): 197 | nn.init.trunc_normal_(m.weight, std=.01) 198 | if m.bias is not None: 199 | nn.init.zeros_(m.bias) 200 | elif isinstance(m, nn.Conv2d): 201 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 202 | if m.bias is not None: 203 | nn.init.zeros_(m.bias) 204 | elif isinstance(m, nn.LayerNorm): 205 | nn.init.zeros_(m.bias) 206 | nn.init.ones_(m.weight) 207 | 208 | 209 | def AVoiD_mm(args, num_classes: int = 21843, has_logits: bool = True): 210 | model = AVoiD(args=args, 211 | img_size=224, 212 | patch_size=16, 213 | embed_dim=768, 214 | depth=6, 215 | num_heads=12, 216 | representation_size=768 if has_logits else None, 217 | num_classes=num_classes) 218 | return model 219 | -------------------------------------------------------------------------------- /data_processing/transfer.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import numbers 4 | import math 5 | import collections 6 | import numpy as np 7 | from PIL import ImageOps, Image 8 | from joblib import Parallel, delayed 9 | 10 | import torchvision 11 | from torchvision import transforms 12 | import torchvision.transforms.functional as F 13 | 14 | class Padding: 15 | def __init__(self, pad): 16 | self.pad = pad 17 | 18 | def __call__(self, img): 19 | return ImageOps.expand(img, border=self.pad, fill=0) 20 | 21 | class Scale: 22 | def __init__(self, size, interpolation=Image.NEAREST): 23 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 24 | self.size = size 25 | self.interpolation = interpolation 26 | 27 | def __call__(self, imgmap): 28 | # assert len(imgmap) > 1 # list of images 29 | img1 = imgmap[0] 30 | if isinstance(self.size, int): 31 | w, h = img1.size 32 | if (w <= h and w == self.size) or (h <= w and h == self.size): 33 | return imgmap 34 | if w < h: 35 | ow = self.size 36 | oh = int(self.size * h / w) 37 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 38 | else: 39 | oh = self.size 40 | ow = int(self.size * w / h) 41 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 42 | else: 43 | return [i.resize(self.size, self.interpolation) for i in imgmap] 44 | 45 | 46 | class CenterCrop: 47 | def __init__(self, size, consistent=True): 48 | if isinstance(size, numbers.Number): 49 | self.size = (int(size), int(size)) 50 | else: 51 | self.size = size 52 | 53 | def __call__(self, imgmap): 54 | img1 = imgmap[0] 55 | w, h = img1.size 56 | th, tw = self.size 57 | x1 = int(round((w - tw) / 2.)) 
58 | y1 = int(round((h - th) / 2.)) 59 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 60 | 61 | 62 | class RandomCropWithProb: 63 | def __init__(self, size, p=0.8, consistent=True): 64 | if isinstance(size, numbers.Number): 65 | self.size = (int(size), int(size)) 66 | else: 67 | self.size = size 68 | self.consistent = consistent 69 | self.threshold = p 70 | 71 | def __call__(self, imgmap): 72 | img1 = imgmap[0] 73 | w, h = img1.size 74 | if self.size is not None: 75 | th, tw = self.size 76 | if w == tw and h == th: 77 | return imgmap 78 | if self.consistent: 79 | if random.random() < self.threshold: 80 | x1 = random.randint(0, w - tw) 81 | y1 = random.randint(0, h - th) 82 | else: 83 | x1 = int(round((w - tw) / 2.)) 84 | y1 = int(round((h - th) / 2.)) 85 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 86 | else: 87 | result = [] 88 | for i in imgmap: 89 | if random.random() < self.threshold: 90 | x1 = random.randint(0, w - tw) 91 | y1 = random.randint(0, h - th) 92 | else: 93 | x1 = int(round((w - tw) / 2.)) 94 | y1 = int(round((h - th) / 2.)) 95 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 96 | return result 97 | else: 98 | return imgmap 99 | 100 | class RandomCrop: 101 | def __init__(self, size, consistent=True): 102 | if isinstance(size, numbers.Number): 103 | self.size = (int(size), int(size)) 104 | else: 105 | self.size = size 106 | self.consistent = consistent 107 | 108 | def __call__(self, imgmap, flowmap=None): 109 | img1 = imgmap[0] 110 | w, h = img1.size 111 | if self.size is not None: 112 | th, tw = self.size 113 | if w == tw and h == th: 114 | return imgmap 115 | if not flowmap: 116 | if self.consistent: 117 | x1 = random.randint(0, w - tw) 118 | y1 = random.randint(0, h - th) 119 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 120 | else: 121 | result = [] 122 | for i in imgmap: 123 | x1 = random.randint(0, w - tw) 124 | y1 = random.randint(0, h - th) 125 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 126 | return result 127 | elif flowmap is not None: 128 | assert (not self.consistent) 129 | result = [] 130 | for idx, i in enumerate(imgmap): 131 | proposal = [] 132 | for j in range(3): # number of proposal: use the one with largest optical flow 133 | x = random.randint(0, w - tw) 134 | y = random.randint(0, h - th) 135 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) 136 | [x1, y1, _] = max(proposal, key=lambda x: x[-1]) 137 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 138 | return result 139 | else: 140 | raise ValueError('wrong case') 141 | else: 142 | return imgmap 143 | 144 | 145 | class RandomSizedCrop: 146 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): 147 | self.size = size 148 | self.interpolation = interpolation 149 | self.consistent = consistent 150 | self.threshold = p 151 | 152 | def __call__(self, imgmap): 153 | img1 = imgmap[0] 154 | if random.random() < self.threshold: # do RandomSizedCrop 155 | for attempt in range(10): 156 | area = img1.size[0] * img1.size[1] 157 | target_area = random.uniform(0.5, 1) * area 158 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 159 | 160 | w = int(round(math.sqrt(target_area * aspect_ratio))) 161 | h = int(round(math.sqrt(target_area / aspect_ratio))) 162 | 163 | if self.consistent: 164 | if random.random() < 0.5: 165 | w, h = h, w 166 | if w <= img1.size[0] and h <= img1.size[1]: 167 | x1 = random.randint(0, img1.size[0] - w) 168 | y1 = random.randint(0, img1.size[1] - h) 169 | 170 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 171 | for i in imgmap: assert(i.size == (w, h)) 172 | 173 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 174 | else: 175 | result = [] 176 | for i in imgmap: 177 | if random.random() < 0.5: 178 | w, h = h, w 179 | if w <= img1.size[0] and h <= img1.size[1]: 180 | x1 = random.randint(0, img1.size[0] - w) 181 | y1 = random.randint(0, img1.size[1] - h) 182 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 183 | assert(result[-1].size == (w, h)) 184 | else: 185 | result.append(i) 186 | 187 | assert len(result) == len(imgmap) 188 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 189 | 190 | # Fallback 191 | scale = Scale(self.size, interpolation=self.interpolation) 192 | crop = CenterCrop(self.size) 193 | return crop(scale(imgmap)) 194 | else: # don't do RandomSizedCrop, do CenterCrop 195 | crop = CenterCrop(self.size) 196 | return crop(imgmap) 197 | 198 | 199 | class RandomHorizontalFlip: 200 | def __init__(self, consistent=True, command=None): 201 | self.consistent = consistent 202 | if command == 'left': 203 | self.threshold = 0 204 | elif command == 'right': 205 | self.threshold = 1 206 | else: 207 | self.threshold = 0.5 208 | def __call__(self, imgmap): 209 | if self.consistent: 210 | if random.random() < self.threshold: 211 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 212 | else: 213 | return imgmap 214 | else: 215 | result = [] 216 | for i in imgmap: 217 | if random.random() < self.threshold: 218 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 219 | else: 220 | result.append(i) 221 | assert len(result) == len(imgmap) 222 | return result 223 | 224 | 225 | class RandomGray: 226 | '''Actually it is a channel splitting, not strictly grayscale images''' 227 | def __init__(self, consistent=True, p=0.5): 228 | self.consistent = consistent 229 | self.p = p # probability to apply grayscale 230 | def __call__(self, imgmap): 231 | if self.consistent: 232 | if random.random() < self.p: 233 | return [self.grayscale(i) for i in imgmap] 234 | else: 235 | return imgmap 236 | else: 237 | result = [] 238 | for i in imgmap: 239 | if random.random() < self.p: 240 | result.append(self.grayscale(i)) 241 | else: 242 | result.append(i) 243 | assert len(result) == len(imgmap) 244 | return result 245 | 246 | def grayscale(self, img): 247 | channel = np.random.choice(3) 248 | np_img = np.array(img)[:,:,channel] 249 | np_img = np.dstack([np_img, np_img, np_img]) 250 | img = Image.fromarray(np_img, 'RGB') 251 | return img 252 | 253 | 254 | class ColorJitter(object): 255 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): 256 | self.brightness = self._check_input(brightness, 'brightness') 257 | self.contrast = self._check_input(contrast, 'contrast') 258 | self.saturation = self._check_input(saturation, 'saturation') 259 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 260 | clip_first_on_zero=False) 261 | self.consistent = consistent 262 | self.threshold = p 263 | 264 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), 
265 |         if isinstance(value, numbers.Number):
266 |             if value < 0:
267 |                 raise ValueError("If {} is a single number, it must be non-negative.".format(name))
268 |             value = [center - value, center + value]
269 |             if clip_first_on_zero:
270 |                 value[0] = max(value[0], 0)
271 |         elif isinstance(value, (tuple, list)) and len(value) == 2:
272 |             if not bound[0] <= value[0] <= value[1] <= bound[1]:
273 |                 raise ValueError("{} values should be between {}".format(name, bound))
274 |         else:
275 |             raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
276 | 
277 |         # if value is 0 or (1., 1.) for brightness/contrast/saturation, or (0., 0.) for hue, do nothing
278 |         if value[0] == value[1] == center:
279 |             value = None
280 |         return value
281 | 
282 |     @staticmethod
283 |     def get_params(brightness, contrast, saturation, hue):
284 |         transforms = []
285 | 
286 |         if brightness is not None:
287 |             brightness_factor = random.uniform(brightness[0], brightness[1])
288 |             transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
289 | 
290 |         if contrast is not None:
291 |             contrast_factor = random.uniform(contrast[0], contrast[1])
292 |             transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
293 | 
294 |         if saturation is not None:
295 |             saturation_factor = random.uniform(saturation[0], saturation[1])
296 |             transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
297 | 
298 |         if hue is not None:
299 |             hue_factor = random.uniform(hue[0], hue[1])
300 |             transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor)))
301 | 
302 |         random.shuffle(transforms)
303 |         transform = torchvision.transforms.Compose(transforms)
304 | 
305 |         return transform
306 | 
307 |     def __call__(self, imgmap):
308 |         if random.random() < self.threshold:  # do ColorJitter
309 |             if self.consistent:
310 |                 transform = self.get_params(self.brightness, self.contrast,
311 |                                             self.saturation, self.hue)
312 |                 return [transform(i) for i in imgmap]
313 |             else:
314 |                 result = []
315 |                 for img in imgmap:
316 |                     transform = self.get_params(self.brightness, self.contrast,
317 |                                                 self.saturation, self.hue)
318 |                     result.append(transform(img))
319 |                 return result
320 |         else:  # don't do ColorJitter, do nothing
321 |             return imgmap
322 | 
323 |     def __repr__(self):
324 |         format_string = self.__class__.__name__ + '('
325 |         format_string += 'brightness={0}'.format(self.brightness)
326 |         format_string += ', contrast={0}'.format(self.contrast)
327 |         format_string += ', saturation={0}'.format(self.saturation)
328 |         format_string += ', hue={0})'.format(self.hue)
329 |         return format_string
330 | 
331 | 
332 | class RandomRotation:
333 |     def __init__(self, consistent=True, degree=15, p=1.0):
334 |         self.consistent = consistent
335 |         self.degree = degree
336 |         self.threshold = p
337 |     def __call__(self, imgmap):
338 |         if random.random() < self.threshold:  # do RandomRotation
339 |             if self.consistent:
340 |                 deg = np.random.randint(-self.degree, self.degree, 1)[0]
341 |                 return [i.rotate(deg, expand=True) for i in imgmap]
342 |             else:
343 |                 return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap]
344 |         else:
345 |             return imgmap
346 | 
347 | class ToTensor:
348 |     def __call__(self, imgmap):
349 |         totensor = transforms.ToTensor()
350 |         return [totensor(i) for i in imgmap]
351 | 
352 | class Normalize:
353 |     def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
354 |         self.mean = mean
355 |         self.std = std
356 |     def __call__(self, imgmap):
357 |         normalize = transforms.Normalize(mean=self.mean, std=self.std)
358 |         return [normalize(i) for i in imgmap]
359 | 
360 | 
361 | 
--------------------------------------------------------------------------------
/model/compare/model.py:
--------------------------------------------------------------------------------
1 | """
2 | original code from rwightman:
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4 | """
5 | from functools import partial
6 | from collections import OrderedDict
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | 
12 | def drop_path(x, drop_prob: float = 0., training: bool = False):
13 |     """
14 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
15 |     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
16 |     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
17 |     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
18 |     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
19 |     'survival rate' as the argument.
20 |     """
21 |     if drop_prob == 0. or not training:
22 |         return x
23 |     keep_prob = 1 - drop_prob
24 |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
25 |     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
26 |     random_tensor.floor_()  # binarize
27 |     output = x.div(keep_prob) * random_tensor
28 |     return output
29 | 
30 | 
31 | class DropPath(nn.Module):
32 |     """
33 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34 |     """
35 | 
36 |     def __init__(self, drop_prob=None):
37 |         super(DropPath, self).__init__()
38 |         self.drop_prob = drop_prob
39 | 
40 |     def forward(self, x):
41 |         return drop_path(x, self.drop_prob, self.training)
42 | 
43 | 
44 | '''
45 | 
46 | 
47 | '''
48 | class PatchEmbed(nn.Module):
49 |     """
50 |     2D Image to Patch Embedding
51 |     """
52 | 
53 |     def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
54 |         super().__init__()
55 |         img_size = (img_size, img_size)
56 |         patch_size = (patch_size, patch_size)
57 |         self.img_size = img_size
58 |         self.patch_size = patch_size
59 |         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
60 |         self.num_patches = self.grid_size[0] * self.grid_size[1]
61 | 
62 |         self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
63 |         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
64 |         self.fc = nn.Linear(1536, 768)  # each pair of corresponding tokens is concatenated, then projected back to embed_dim
65 | 
66 |     def forward(self, x, y):
67 |         B, C, H, W = x.shape
68 |         assert H == self.img_size[0] and W == self.img_size[1], \
69 |             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
70 | 
71 |         # flatten: [B, C, H, W] -> [B, C, HW]
72 |         '''
73 |         # concatenate along the channel dimension; accuracy was very low
74 |         x = self.proj(x).flatten(2).transpose(1, 2)  # [B,196,768]
75 |         y = self.proj(y).flatten(2).transpose(1, 2)  # [B,196,768]
76 |         x = torch.cat((x, y), dim=2)  # [B,196,1536]
77 |         x = self.fc(x)  # [B,196,768]
78 |         x = self.norm(x)
79 |         '''
80 | 
81 |         '''
82 |         # concatenate the feature maps along the patch axis, then pass through a linear layer
83 |         x = self.proj(x).flatten(2)  # [B,768,196]
84 |         y = self.proj(y).flatten(2)
85 |         x = torch.cat((x, y), dim=2)  # [B,768,392]
86 |         x = self.fc(x)  # [B,768,196]
87 |         x = x.transpose(1, 2)  # [B,196,768]
88 |         x = self.norm(x)
89 |         '''
90 |         # concatenate along the channel dimension; accuracy was very low
91 |         x = self.proj(x).flatten(2).transpose(1, 2)  # [B,196,768]
92 |         y = self.proj(y).flatten(2).transpose(1, 2)  # [B,196,768]
93 |         x = torch.cat((x, y), dim=2)  # [B,196,1536]
94 |         x = self.fc(x)  # [B,196,768]
95 |         x = self.norm(x)
96 | 
97 |         return x
98 | 
99 | 
100 | class Attention(nn.Module):
101 |     def __init__(self,
102 |                  dim,   # dimension of the input tokens
103 |                  num_heads=8,
104 |                  qkv_bias=False,
105 |                  qk_scale=None,
106 |                  attn_drop_ratio=0.,
107 |                  proj_drop_ratio=0.):
108 |         super(Attention, self).__init__()
109 |         self.num_heads = num_heads
110 |         head_dim = dim // num_heads
111 |         self.scale = qk_scale or head_dim ** -0.5
112 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
113 |         self.attn_drop = nn.Dropout(attn_drop_ratio)
114 |         self.proj = nn.Linear(dim, dim)
115 |         self.proj_drop = nn.Dropout(proj_drop_ratio)
116 | 
117 |     def forward(self, x):
118 |         # [batch_size, num_patches + 1, total_embed_dim]
119 |         B, N, C = x.shape
120 | 
121 |         # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
122 |         # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
123 |         # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
124 |         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
125 |         # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
126 |         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
127 | 
128 |         # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
129 |         # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
130 |         attn = (q @ k.transpose(-2, -1)) * self.scale
131 |         attn = attn.softmax(dim=-1)
132 |         attn = self.attn_drop(attn)
133 | 
134 |         # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
135 |         # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
136 |         # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
137 |         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
138 |         x = self.proj(x)
139 |         x = self.proj_drop(x)
140 |         return x
141 | 
142 | 
143 | class Mlp(nn.Module):
144 |     """
145 |     MLP as used in Vision Transformer, MLP-Mixer and related networks
146 |     """
147 | 
148 |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
149 |         super().__init__()
150 |         out_features = out_features or in_features
151 |         hidden_features = hidden_features or in_features
152 |         self.fc1 = nn.Linear(in_features, hidden_features)
153 |         self.act = act_layer()
154 |         self.fc2 = nn.Linear(hidden_features, out_features)
155 |         self.drop = nn.Dropout(drop)
156 | 
157 |     def forward(self, x):
158 |         x = self.fc1(x)
159 |         x = self.act(x)
160 |         x = self.drop(x)
161 |         x = self.fc2(x)
162 |         x = self.drop(x)
163 |         return x
164 | 
165 | 
166 | class Block(nn.Module):
167 |     def __init__(self,
168 |                  dim,
169 |                  num_heads,
170 |                  mlp_ratio=4.,
171 |                  qkv_bias=False,
172 |                  qk_scale=None,
173 |                  drop_ratio=0.,
174 |                  attn_drop_ratio=0.,
175 |                  drop_path_ratio=0.,
176 |                  act_layer=nn.GELU,
177 |                  norm_layer=nn.LayerNorm):
178 |         super(Block, self).__init__()
179 |         self.norm1 = norm_layer(dim)
180 |         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
181 |                               attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
182 |         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
183 |         self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
184 |         self.norm2 = norm_layer(dim)
185 |         mlp_hidden_dim = int(dim * mlp_ratio)
186 |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
187 | 
188 |     def forward(self, x):
189 |         x = x + self.drop_path(self.attn(self.norm1(x)))
190 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
191 |         return x
192 | 
193 | 
194 | class VisionTransformer(nn.Module):
195 |     def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
196 |                  embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
197 |                  qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
198 |                  attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
199 |                  act_layer=None):
200 |         """
201 |         Args:
202 |             img_size (int, tuple): input image size
203 |             patch_size (int, tuple): patch size
204 |             in_c (int): number of input channels
205 |             num_classes (int): number of classes for classification head
206 |             embed_dim (int): embedding dimension
207 |             depth (int): depth of transformer
208 |             num_heads (int): number of attention heads
209 |             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
210 |             qkv_bias (bool): enable bias for qkv if True
211 |             qk_scale (float): override default qk scale of head_dim ** -0.5 if set
212 |             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
213 |             distilled (bool): model includes a distillation token and head as in DeiT models
214 |             drop_ratio (float): dropout rate
215 |             attn_drop_ratio (float): attention dropout rate
216 |             drop_path_ratio (float): stochastic depth rate
217 |             embed_layer (nn.Module): patch embedding layer
218 |             norm_layer (nn.Module): normalization layer
219 |         """
220 |         super(VisionTransformer, self).__init__()
221 |         self.num_classes = num_classes
222 |         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
223 |         self.num_tokens = 2 if distilled else 1
224 |         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
225 |         act_layer = act_layer or nn.GELU
226 | 
227 |         self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
228 |         # learnable weighting of the losses in the mapping function:
229 |         self.a = nn.Parameter(torch.ones(1))
230 |         self.b = nn.Parameter(torch.ones(1))
231 |         self.c = nn.Parameter(torch.ones(1))
232 |         num_patches = self.patch_embed.num_patches
233 | 
234 |         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
235 |         self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
236 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
237 |         self.pos_drop = nn.Dropout(p=drop_ratio)
238 | 
239 |         dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
240 |         self.blocks = nn.Sequential(*[
241 |             Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
242 |                   drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
243 |                   norm_layer=norm_layer, act_layer=act_layer)
244 |             for i in range(depth)
245 |         ])
246 |         self.norm = norm_layer(embed_dim)
247 | 
248 |         # Representation layer
249 |         if representation_size and not distilled:
250 |             self.has_logits = True
251 |             self.num_features = representation_size
252 |             self.pre_logits = nn.Sequential(OrderedDict([
253 |                 ("fc", nn.Linear(embed_dim, representation_size)),
254 |                 ("act", nn.Tanh())
255 |             ]))
256 |         else:
257 |             self.has_logits = False
258 |             self.pre_logits = nn.Identity()
259 | 
260 |         # Classifier head(s)
261 |         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
262 |         self.head_dist = None
263 |         if distilled:
264 |             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
265 | 
266 |         # Weight init
267 |         nn.init.trunc_normal_(self.pos_embed, std=0.02)
268 |         if self.dist_token is not None:
269 |             nn.init.trunc_normal_(self.dist_token, std=0.02)
270 | 
271 |         nn.init.trunc_normal_(self.cls_token, std=0.02)
272 |         self.apply(_init_vit_weights)
273 | 
274 |     def forward_features(self, x):
275 |         # [B, C, H, W] -> [B, num_patches, embed_dim]
276 |         # x = self.patch_embed(x, y)  # [B, 196, 768]; the caller is expected to apply PatchEmbed before calling forward()
277 |         # [1, 1, 768] -> [B, 1, 768]
278 |         cls_token = self.cls_token.expand(x.shape[0], -1, -1)
279 |         if self.dist_token is None:
280 |             x = torch.cat((cls_token, x), dim=1)  # [B, 196+1, 768]
281 |         else:
282 |             x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
283 | 
284 |         x = self.pos_drop(x + self.pos_embed)
285 |         x = self.blocks(x)
286 |         x = self.norm(x)
287 |         if self.dist_token is None:
288 |             return self.pre_logits(x[:, 0])  # mapping loss
289 |         else:
290 |             return x[:, 0], x[:, 1]
291 | 
292 |     def forward(self, x):
293 |         x = self.forward_features(x)
294 |         if self.head_dist is not None:
295 |             x, x_dist = self.head(x[0]), self.head_dist(x[1])
296 |             if self.training and not torch.jit.is_scripting():
297 |                 return x, x_dist
298 |             else:
299 |                 # during inference, return the average of both classifier predictions
300 |                 return (x + x_dist) / 2
301 |         else:
302 |             x = self.head(x)
303 |             return x
304 | 
305 | 
306 | def _init_vit_weights(m):
307 |     """
308 |     ViT weight initialization
309 |     :param m: module
310 |     """
311 |     if isinstance(m, nn.Linear):
312 |         nn.init.trunc_normal_(m.weight, std=.01)
313 |         if m.bias is not None:
314 |             nn.init.zeros_(m.bias)
315 |     elif isinstance(m, nn.Conv2d):
316 |         nn.init.kaiming_normal_(m.weight, mode="fan_out")
317 |         if m.bias is not None:
318 |             nn.init.zeros_(m.bias)
319 |     elif isinstance(m, nn.LayerNorm):
320 |         nn.init.zeros_(m.bias)
321 |         nn.init.ones_(m.weight)
322 | 
323 | 
324 | def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
325 |     """
326 |     ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
327 |     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
328 |     weights ported from official Google JAX impl:
329 |     https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
330 |     """
331 |     model = VisionTransformer(img_size=224,
332 |                               patch_size=16,
333 |                               embed_dim=768,
334 |                               depth=12,
335 |                               num_heads=12,
336 |                               representation_size=768 if has_logits else None,
337 |                               num_classes=num_classes)
338 |     return model
339 | 
340 | 
341 | def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
342 |     """
343 |     ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
344 |     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
345 |     weights ported from official Google JAX impl:
346 |     https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
347 |     """
348 |     model = VisionTransformer(img_size=224,
349 |                               patch_size=32,
350 |                               embed_dim=768,
351 |                               depth=12,
352 |                               num_heads=12,
353 |                               representation_size=768 if has_logits else None,
354 |                               num_classes=num_classes)
355 |     return model
356 | 
357 | 
358 | def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
359 |     """
360 |     ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
361 |     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
362 |     weights ported from official Google JAX impl:
363 |     https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
364 |     """
365 |     model = VisionTransformer(img_size=224,
366 |                               patch_size=16,
367 |                               embed_dim=1024,
368 |                               depth=24,
369 |                               num_heads=16,
370 |                               representation_size=1024 if has_logits else None,
371 |                               num_classes=num_classes)
372 |     return model
373 | 
374 | 
375 | def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
376 |     """
377 |     ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
378 |     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
379 |     weights ported from official Google JAX impl:
380 |     https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
381 |     """
382 |     model = VisionTransformer(img_size=224,
383 |                               patch_size=32,
384 |                               embed_dim=1024,
385 |                               depth=24,
386 |                               num_heads=16,
387 |                               representation_size=1024 if has_logits else None,
388 |                               num_classes=num_classes)
389 |     return model
390 | 
391 | 
392 | def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
393 |     """
394 |     ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
395 |     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
396 |     NOTE: converted weights not currently available, too large for github release hosting.
397 |     """
398 |     model = VisionTransformer(img_size=224,
399 |                               patch_size=14,
400 |                               embed_dim=1280,
401 |                               depth=32,
402 |                               num_heads=16,
403 |                               representation_size=1280 if has_logits else None,
404 |                               num_classes=num_classes)
405 |     return model
406 | 
--------------------------------------------------------------------------------
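For reference, the sketch below shows one way the two-stream `PatchEmbed` and the `VisionTransformer` in `model/compare/model.py` could be wired together. It is not part of the released code: the `from model import ...` line assumes it is run from inside `model/compare/` (as `flops.py` does), the random tensors merely stand in for the two input streams, and `num_classes=2` is an illustrative choice. Because the `patch_embed` call inside `forward_features` is commented out, the tokens are embedded explicitly before being handed to the transformer.

```
import torch
from model import PatchEmbed, VisionTransformer  # hypothetical usage, run from model/compare/

# Two-stream patch embedding: each 224x224 input yields 196 tokens of width 768;
# corresponding tokens are concatenated (196 x 1536) and fused back to 768 by PatchEmbed.fc.
embed = PatchEmbed(img_size=224, patch_size=16, in_c=3, embed_dim=768)

# ViT-B/16 backbone with an illustrative binary (real/fake) head.
vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                        depth=12, num_heads=12, num_classes=2)

x = torch.rand(2, 3, 224, 224)  # first stream, e.g. visual frames
y = torch.rand(2, 3, 224, 224)  # second stream, rendered to the same image shape
tokens = embed(x, y)            # [2, 196, 768]
logits = vit(tokens)            # forward_features() consumes pre-embedded tokens
print(logits.shape)             # torch.Size([2, 2])
```

Note that this fusion keeps the token count at 196, so the positional embedding and the rest of the ViT stack are unchanged; only the per-token width is doubled and then projected back to `embed_dim`.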