├── README.md ├── capsule ├── __init__.py ├── model_big.py ├── test_binary_ffpp.py └── train_capsule.py ├── datasets ├── __init__.py ├── cvfunctional.py ├── cvtransforms.py ├── dataloader_imagenet_dct.py ├── dataset_imagenet.py ├── dataset_imagenet_dct.py ├── dataset_imagenet_dct_cupy.py ├── imagenet2lmdb.py └── vision.py ├── dct ├── __init__.py ├── imagenet │ ├── __init__.py │ ├── gate.py │ ├── gumbel.py │ ├── resnet.py │ ├── resnet_autsubset_inputgate.py │ └── resnet_resized.py └── utils.py ├── demo ├── demo.py └── video │ └── id0_id1_0002.mp4 ├── fwa ├── __init__.py └── classifier.py ├── imgs ├── imbalanced performance.png └── overview.png ├── main.py ├── make_train_test.py ├── meso ├── __init__.py ├── eval_meso.py ├── meso.py └── train_mesonet.py ├── models ├── __init__.py ├── convGRU.py ├── convlstm.py ├── model.py └── resnet.py ├── utils ├── __init__.py ├── auccur.py ├── aucloss.py ├── cam.py ├── config.py ├── dataloader.py ├── drawpics.py ├── eval.py ├── ff.py ├── focalloss.py ├── gradcam.py ├── mmod_human_face_detector.dat ├── test.json ├── tools.py ├── train.json ├── train_cpvr.py ├── val.json └── xcp_reg.py └── xception ├── __init__.py ├── models.py └── xception.py /README.md: -------------------------------------------------------------------------------- 1 | # Learning a Deep Dual-level Network for Robust DeepFake Detection 2 | 3 | [![Python](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/) 4 | 5 | Wenbo Pu, Jing Hu, Xin Wang, Yuezun Li, Shu Hu, Bin Zhu, Rui Song, Qi Song, Xi Wu, Siwei Lyu 6 | _________________ 7 | 8 | This repository is the official implementation of our paper "Learning a Deep Dual-level Network for Robust DeepFake Detection", which has been accepted by **Pattern Recognition**. 9 | 10 | ## Overview 11 | 12 | ![](./imgs/overview.png) 13 | 14 | ## Imbalanced Performance 15 | 16 | ![](./imgs/imbalanced%20performance.png) 17 | 18 | 19 | ## Info 20 | 21 | We provide our method together with Xception [6], FWA [7], MesoNet [8], Capsule [9] and other baselines for training and testing in this repository. Xception and FWA can be trained and tested via `main.py`, while the other methods can be found in their individual folders, e.g. Capsule in `capsule/`. 22 | 23 | In addition to the model proposed in our paper, we also provide several variants of it, including ViT, ResViT and DCTNet [10] as replacements for the ResNet backbone, and a CRNN as a replacement for the RNN. 24 | 25 | We also implemented Face X-ray for data augmentation (it is not used in this paper, but we found that it can improve performance); if you are interested, see `utils/dataloader.py`. 26 | 27 | The implementation of the AUC loss proposed in our paper can be found in `utils/aucloss.py`. 28 | 29 | Our checkpoint can be found [here](https://drive.google.com/file/d/144ol1u4Kz4HwOsG3qvEeVqH8bpqCvaOU/view?usp=sharing). 30 | 31 | 32 | ## Requirements 33 | 34 | - PyTorch 1.4.0 35 | - Ubuntu 16.04 36 | - CUDA 10.0 37 | - Python 3.6 38 | - Dlib 19.0 39 | 40 | ## Usage 41 | 42 | - We provide a demo to show how our model works. See `demo/demo.py`: 43 | ```shell 44 | python demo.py --restore_from checkpoint_path --path video_path 45 | ``` 46 | 47 | - To train and test a model, use 48 | 49 | ```shell 50 | python main.py -i input_path -r restore_from -g gpu_id 51 | ``` 52 | 53 | - More parameters, including the gamma of the AUC loss (sketched below), can be found and adjusted in `main.py`. 54 | 
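For orientation, below is a minimal, hypothetical sketch of a pairwise AUC surrogate with a margin `gamma`. It only illustrates the general idea behind an AUC-style loss; the class name, the default margin value and the exact formulation are assumptions made for this example and are not necessarily what `utils/aucloss.py` or the paper implements.

```python
import torch
from torch import nn


class PairwiseAUCLoss(nn.Module):
    """Illustrative pairwise squared-hinge surrogate for AUC (sketch only)."""

    def __init__(self, gamma: float = 0.15):
        super().__init__()
        self.gamma = gamma  # margin by which positives should outscore negatives

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        pos = scores[labels == 1]  # scores of positive (e.g. fake) samples
        neg = scores[labels == 0]  # scores of negative (e.g. real) samples
        if pos.numel() == 0 or neg.numel() == 0:
            # a single-class batch contributes no ranking pairs
            return scores.new_zeros(())
        # every positive/negative pair; a pair costs nothing once pos > neg + gamma
        diff = pos.unsqueeze(1) - neg.unsqueeze(0)
        return torch.clamp(self.gamma - diff, min=0).pow(2).mean()
```

With `gamma = 0` this reduces to a plain squared hinge on the Wilcoxon–Mann–Whitney statistic; larger values of gamma ask every positive score to exceed every negative score by at least the margin.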
55 | ## Training data preparation 56 | 57 | We provide a script to generate training and test data for this repository: `make_train_test.py`. This script can preprocess the FaceForensics++, Celeb-DF and DFDC datasets using [MTCNN](https://github.com/ipazc/mtcnn) or [Dlib](https://github.com/davisking/dlib/). 58 | 59 | 60 | ## Citation 61 | 62 | Please kindly consider citing our paper in your publications. 63 | 64 | ```bib 65 | @article{PU2022108832, 66 | title = {Learning a deep dual-level network for robust DeepFake detection}, 67 | journal = {Pattern Recognition}, 68 | volume = {130}, 69 | pages = {108832}, 70 | year = {2022}, 71 | issn = {0031-3203}, 72 | doi = {https://doi.org/10.1016/j.patcog.2022.108832}, 73 | url = {https://www.sciencedirect.com/science/article/pii/S0031320322003132}, 74 | author = {Wenbo Pu and Jing Hu and Xin Wang and Yuezun Li and Shu Hu and Bin Zhu and Rui Song and Qi Song and Xi Wu and Siwei Lyu} 75 | } 76 | ``` 77 | _________________ 78 | 79 | ## Notice 80 | 81 | This repository is NOT for commercial use. It is provided "as is" and we are not responsible for any consequences of using this code. 82 | 83 | 84 | ## Thanks 85 | 86 | 6 [FaceForensics++ Learning to Detect Manipulated Facial Images](https://github.com/ondyari/FaceForensics)
87 | 7 [Exposing DeepFake Videos By Detecting Face Warping Artifacts](https://github.com/yuezunli/CVPRW2019_Face_Artifacts)
88 | 8 [MesoNet - a Compact Facial Video Forgery Detection Network](https://github.com/DariusAf/MesoNet)
89 | 9 [USE OF A CAPSULE NETWORK TO DETECT FAKE IMAGES AND VIDEOS](https://github.com/raohashim/DFD)
90 | 10 [Learning in the Frequency Domain](https://github.com/calmevtime/DCTNet) 91 | -------------------------------------------------------------------------------- /capsule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/capsule/__init__.py -------------------------------------------------------------------------------- /capsule/model_big.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018, National Institute of Informatics 3 | All rights reserved. 4 | Author: Huy H. Nguyen 5 | ----------------------------------------------------- 6 | Script for Capsule-Forensics model 7 | """ 8 | 9 | import sys 10 | 11 | sys.setrecursionlimit(15000) 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.autograd import Variable 17 | import torchvision.models as models 18 | 19 | NO_CAPS = 10 20 | 21 | 22 | class StatsNet(nn.Module): 23 | def __init__(self): 24 | super(StatsNet, self).__init__() 25 | 26 | def forward(self, x): 27 | x = x.view(x.data.shape[0], x.data.shape[1], x.data.shape[2] * x.data.shape[3]) 28 | 29 | mean = torch.mean(x, 2) 30 | std = torch.std(x, 2) 31 | 32 | return torch.stack((mean, std), dim=1) 33 | 34 | 35 | class View(nn.Module): 36 | def __init__(self, *shape): 37 | super(View, self).__init__() 38 | self.shape = shape 39 | 40 | def forward(self, input): 41 | return input.view(self.shape) 42 | 43 | 44 | class VggExtractor(nn.Module): 45 | def __init__(self, train=False): 46 | super(VggExtractor, self).__init__() 47 | 48 | self.vgg_1 = self.Vgg(models.vgg19(pretrained=True), 0, 18) 49 | if train: 50 | self.vgg_1.train(mode=True) 51 | self.freeze_gradient() 52 | else: 53 | self.vgg_1.eval() 54 | 55 | def Vgg(self, vgg, begin, end): 56 | features = nn.Sequential(*list(vgg.features.children())[begin:(end + 1)]) 57 | return features 58 | 59 | def freeze_gradient(self, begin=0, end=9): 60 | for i in range(begin, end + 1): 61 | self.vgg_1[i].requires_grad = False 62 | 63 | def forward(self, input): 64 | return self.vgg_1(input) 65 | 66 | 67 | class FeatureExtractor(nn.Module): 68 | def __init__(self): 69 | super(FeatureExtractor, self).__init__() 70 | 71 | self.capsules = nn.ModuleList([ 72 | nn.Sequential( 73 | nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1), 74 | nn.BatchNorm2d(64), 75 | nn.ReLU(), 76 | nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1), 77 | nn.BatchNorm2d(16), 78 | nn.ReLU(), 79 | StatsNet(), 80 | 81 | nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2), 82 | nn.BatchNorm1d(8), 83 | nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1), 84 | nn.BatchNorm1d(1), 85 | View(-1, 8), 86 | ) 87 | for _ in range(NO_CAPS)] 88 | ) 89 | 90 | def squash(self, tensor, dim): 91 | squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True) 92 | scale = squared_norm / (1 + squared_norm) 93 | return scale * tensor / (torch.sqrt(squared_norm)) 94 | 95 | def forward(self, x): 96 | # outputs = [capsule(x.detach()) for capsule in self.capsules] 97 | # outputs = [capsule(x.clone()) for capsule in self.capsules] 98 | outputs = [capsule(x) for capsule in self.capsules] 99 | output = torch.stack(outputs, dim=-1) 100 | 101 | return self.squash(output, dim=-1) 102 | 103 | 104 | class RoutingLayer(nn.Module): 105 | def __init__(self, gpu_id, num_input_capsules, num_output_capsules, data_in, data_out, 
num_iterations): 106 | super(RoutingLayer, self).__init__() 107 | 108 | self.gpu_id = gpu_id 109 | self.num_iterations = num_iterations 110 | self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in)) 111 | 112 | def squash(self, tensor, dim): 113 | squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True) 114 | scale = squared_norm / (1 + squared_norm) 115 | return scale * tensor / (torch.sqrt(squared_norm)) 116 | 117 | def forward(self, x, random, dropout): 118 | # x[b, data, in_caps] 119 | 120 | x = x.transpose(2, 1) 121 | # x[b, in_caps, data] 122 | 123 | if random: 124 | noise = Variable(0.01 * torch.randn(*self.route_weights.size())) 125 | if self.gpu_id >= 0: 126 | noise = noise.cuda(self.gpu_id) 127 | route_weights = self.route_weights + noise 128 | else: 129 | route_weights = self.route_weights 130 | 131 | priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None] 132 | 133 | # route_weights [out_caps , 1 , in_caps , data_out , data_in] 134 | # x [ 1 , b , in_caps , data_in , 1 ] 135 | # priors [out_caps , b , in_caps , data_out, 1 ] 136 | 137 | priors = priors.transpose(1, 0) 138 | # priors[b, out_caps, in_caps, data_out, 1] 139 | 140 | if dropout > 0.0: 141 | drop = Variable(torch.FloatTensor(*priors.size()).bernoulli(1.0 - dropout)) 142 | if self.gpu_id >= 0: 143 | drop = drop.cuda(self.gpu_id) 144 | priors = priors * drop 145 | 146 | logits = Variable(torch.zeros(*priors.size())) 147 | # logits[b, out_caps, in_caps, data_out, 1] 148 | 149 | if self.gpu_id >= 0: 150 | logits = logits.cuda(self.gpu_id) 151 | 152 | num_iterations = self.num_iterations 153 | 154 | for i in range(num_iterations): 155 | probs = F.softmax(logits, dim=2) 156 | outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3) 157 | 158 | if i != self.num_iterations - 1: 159 | delta_logits = priors * outputs 160 | logits = logits + delta_logits 161 | 162 | # outputs[b, out_caps, 1, data_out, 1] 163 | outputs = outputs.squeeze() 164 | 165 | if len(outputs.shape) == 3: 166 | outputs = outputs.transpose(2, 1).contiguous() 167 | else: 168 | outputs = outputs.unsqueeze_(dim=0).transpose(2, 1).contiguous() 169 | # outputs[b, data_out, out_caps] 170 | 171 | return outputs 172 | 173 | 174 | class CapsuleNet(nn.Module): 175 | def __init__(self, num_class, gpu_id): 176 | super(CapsuleNet, self).__init__() 177 | 178 | self.num_class = num_class 179 | self.fea_ext = FeatureExtractor() 180 | self.fea_ext.apply(self.weights_init) 181 | 182 | self.routing_stats = RoutingLayer(gpu_id=gpu_id, num_input_capsules=NO_CAPS, num_output_capsules=num_class, 183 | data_in=8, data_out=4, num_iterations=2) 184 | 185 | def weights_init(self, m): 186 | classname = m.__class__.__name__ 187 | if classname.find('Conv') != -1: 188 | m.weight.data.normal_(0.0, 0.02) 189 | elif classname.find('BatchNorm') != -1: 190 | m.weight.data.normal_(1.0, 0.02) 191 | m.bias.data.fill_(0) 192 | 193 | def forward(self, x, random=False, dropout=0.0): 194 | 195 | z = self.fea_ext(x) 196 | z = self.routing_stats(z, random, dropout=dropout) 197 | # z[b, data, out_caps] 198 | 199 | # classes = F.softmax(z, dim=-1) 200 | 201 | # class_ = classes.detach() 202 | # class_ = class_.mean(dim=1) 203 | 204 | # return classes, class_ 205 | 206 | classes = F.softmax(z, dim=-1) 207 | class_ = classes.detach() 208 | class_ = class_.mean(dim=1) 209 | 210 | return z, class_ 211 | 212 | 213 | class CapsuleLoss(nn.Module): 214 | def __init__(self, gpu_id): 215 | super(CapsuleLoss, self).__init__() 216 | 
self.cross_entropy_loss = nn.CrossEntropyLoss() 217 | 218 | if gpu_id >= 0: 219 | self.cross_entropy_loss.cuda(gpu_id) 220 | 221 | def forward(self, classes, labels): 222 | loss_t = self.cross_entropy_loss(classes[:, 0, :], labels) 223 | 224 | for i in range(classes.size(1) - 1): 225 | loss_t = loss_t + self.cross_entropy_loss(classes[:, i + 1, :], labels) 226 | 227 | return loss_t 228 | -------------------------------------------------------------------------------- /capsule/test_binary_ffpp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019, National Institute of Informatics 3 | All rights reserved. 4 | Author: Huy H. Nguyen 5 | ----------------------------------------------------- 6 | Script for testing Capsule-Forensics-v2 on FaceForensics++ database (Real, DeepFakes, Face2Face, FaceSwap) 7 | """ 8 | 9 | import sys 10 | sys.setrecursionlimit(15000) 11 | import os 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import numpy as np 15 | from torch.autograd import Variable 16 | import torch.utils.data 17 | import torchvision.datasets as dset 18 | from torch.utils.data import DataLoader 19 | import torchvision.transforms as transforms 20 | from tqdm import tqdm 21 | import argparse 22 | from sklearn import metrics 23 | from scipy.optimize import brentq 24 | from scipy.interpolate import interp1d 25 | from sklearn.metrics import roc_curve 26 | import model_big 27 | import pandas 28 | 29 | from utils.dataloader import FrameDataset 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--dataset', default ='databases/faceforensicspp', help='path to dataset') 33 | parser.add_argument('--test_set', default ='test', help='test set') 34 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=0) 35 | parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 36 | parser.add_argument('--imageSize', type=int, default=300, help='the height / width of the input image to network') 37 | parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID') 38 | parser.add_argument('--outf', default='checkpoints/binary_faceforensicspp', help='folder to output model checkpoints') 39 | parser.add_argument('--random', action='store_true', default=False, help='enable randomness for routing matrix') 40 | parser.add_argument('--id', type=int, default=21, help='checkpoint ID') 41 | 42 | opt = parser.parse_args() 43 | print(opt) 44 | 45 | if __name__ == '__main__': 46 | 47 | # text_writer = open(os.path.join(opt.outf, 'test.txt'), 'w') 48 | 49 | transform_fwd = transforms.Compose([ 50 | transforms.Resize((opt.imageSize, opt.imageSize)), 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 53 | ]) 54 | 55 | 56 | # dataset_test = dset.ImageFolder(root=os.path.join(opt.dataset, opt.test_set), transform=transform_fwd) 57 | # assert dataset_test 58 | # dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=opt.batchSize, shuffle=False, num_workers=int(opt.workers)) 59 | # dataloaders = {} 60 | # for name in ['train', 'test']: 61 | # raw_data = pandas.read_csv(os.path.join(opt.dataset, '%s.csv' % name)) 62 | # dataloaders[name] = DataLoader(FrameDataset(raw_data.to_numpy()), **config.dataset_params) 63 | raw_data = pandas.read_csv(os.path.join(opt.dataset, 'test.csv')) 64 | dataloader_test = DataLoader(FrameDataset(raw_data.to_numpy()), 65 | batch_size=opt.batchSize, 66 | shuffle=True, 67 | num_workers=4, 
68 | pin_memory=False) 69 | vgg_ext = model_big.VggExtractor() 70 | capnet = model_big.CapsuleNet(2, opt.gpu_id) 71 | 72 | capnet.load_state_dict(torch.load(os.path.join(opt.outf))) 73 | capnet.eval() 74 | 75 | if opt.gpu_id >= 0: 76 | vgg_ext.cuda(opt.gpu_id) 77 | capnet.cuda(opt.gpu_id) 78 | 79 | 80 | ################################################################################## 81 | 82 | tol_label = np.array([], dtype=np.float) 83 | tol_pred = np.array([], dtype=np.float) 84 | tol_pred_prob = np.array([], dtype=np.float) 85 | 86 | count = 0 87 | loss_test = 0 88 | 89 | for img_data, labels_data in tqdm(dataloader_test): 90 | 91 | labels_data[labels_data > 1] = 1 92 | img_label = labels_data.numpy().astype(np.float) 93 | 94 | if opt.gpu_id >= 0: 95 | img_data = img_data.cuda(opt.gpu_id) 96 | labels_data = labels_data.cuda(opt.gpu_id) 97 | 98 | input_v = Variable(img_data) 99 | 100 | x = vgg_ext(input_v) 101 | classes, class_ = capnet(x, random=opt.random) 102 | 103 | output_dis = class_.data.cpu() 104 | output_pred = np.zeros((output_dis.shape[0]), dtype=np.float) 105 | 106 | for i in range(output_dis.shape[0]): 107 | if output_dis[i,1] >= output_dis[i,0]: 108 | output_pred[i] = 1.0 109 | else: 110 | output_pred[i] = 0.0 111 | 112 | tol_label = np.concatenate((tol_label, img_label)) 113 | tol_pred = np.concatenate((tol_pred, output_pred)) 114 | 115 | pred_prob = torch.softmax(output_dis, dim=1) 116 | tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:, 1].data.numpy())) 117 | 118 | count += 1 119 | 120 | acc_test = metrics.accuracy_score(tol_label, tol_pred) 121 | auc_test = metrics.roc_auc_score(tol_label, tol_pred_prob) 122 | f1_test = metrics.f1_score(tol_label, tol_pred) 123 | recall_test = metrics.recall_score(tol_label, tol_pred) 124 | precision = metrics.precision_score(tol_label, tol_pred) 125 | loss_test /= count 126 | 127 | fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1) 128 | np.save('./m/cap/f_fpr.npy', fpr) 129 | np.save('./m/cap/f_tpr.npy', tpr) 130 | # eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 131 | # 132 | # fnr = 1 - tpr 133 | # hter = (fpr + fnr)/2 134 | # print('[Epoch %d] Train loss: %.4f acc: %.2f | Test loss: %.4f acc: %.2f auc: %.2f' 135 | # % (opt.id, acc_test * 100, loss_test, acc_test * 100, auc_test * 100)) 136 | print('[Epoch %d] Test acc: %.2f AUC: %.2f f1: %.2f recall:%.2f precision:%.2f' 137 | % (opt.id, acc_test*100, auc_test*100, f1_test, recall_test, precision)) 138 | # text_writer.write('%d,%.2f,%.2f\n'% (opt.id, acc_test*100, eer*100)) 139 | # 140 | # text_writer.flush() 141 | # text_writer.close() 142 | -------------------------------------------------------------------------------- /capsule/train_capsule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019, National Institute of Informatics 3 | All rights reserved. 4 | Author: Huy H. 
Nguyen 5 | ----------------------------------------------------- 6 | Script for training Capsule-Forensics-v2 on FaceForensics++ database (Real, DeepFakes, Face2Face, FaceSwap) 7 | """ 8 | 9 | import sys 10 | 11 | sys.setrecursionlimit(15000) 12 | import os 13 | import random 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import numpy as np 17 | from torch.autograd import Variable 18 | from torch.optim import Adam 19 | # import torchvision.transforms as transforms 20 | from torch.utils.data import DataLoader 21 | from tqdm import tqdm 22 | import argparse 23 | from sklearn import metrics 24 | import model_big 25 | import pandas 26 | 27 | from utils.dataloader import FrameDataset 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--dataset', default='', help='path to root dataset') 32 | # parser.add_argument('--train_set', default='train', help='train set') 33 | # parser.add_argument('--val_set', default='validation', help='validation set') 34 | # parser.add_argument('--workers', type=int, help='number of data loading workers', default=0) 35 | parser.add_argument('--batch_size', type=int, default=32, help='batch size') 36 | # parser.add_argument('--imageSize', type=int, default=300, help='the height / width of the input image to network') 37 | parser.add_argument('--niter', type=int, default=20, help='number of epochs to train for') 38 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 39 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam') 40 | parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID') 41 | parser.add_argument('--resume', type=int, default=0, help="choose a epochs to resume from (0 to train from scratch)") 42 | parser.add_argument('--outf', default='capsule/binary_faceforensicspp', help='folder to output model checkpoints') 43 | parser.add_argument('--disable_random', action='store_true', default=False, 44 | help='disable randomness for routing matrix') 45 | parser.add_argument('--dropout', type=float, default=0.05, help='dropout percentage') 46 | parser.add_argument('--manualSeed', type=int, help='manual seed') 47 | 48 | opt = parser.parse_args() 49 | print(opt) 50 | 51 | opt.random = not opt.disable_random 52 | 53 | if __name__ == "__main__": 54 | 55 | if opt.manualSeed is None: 56 | opt.manualSeed = random.randint(1, 10000) 57 | print("Random Seed: ", opt.manualSeed) 58 | random.seed(opt.manualSeed) 59 | torch.manual_seed(opt.manualSeed) 60 | 61 | if opt.gpu_id >= 0: 62 | torch.cuda.manual_seed_all(opt.manualSeed) 63 | cudnn.benchmark = True 64 | 65 | if opt.resume > 0: 66 | text_writer = open('train_capsule.csv', 'a') 67 | else: 68 | text_writer = open('train_capsule.csv', 'w') 69 | 70 | vgg_ext = model_big.VggExtractor() 71 | capnet = model_big.CapsuleNet(2, opt.gpu_id) 72 | capsule_loss = model_big.CapsuleLoss(opt.gpu_id) 73 | 74 | if opt.gpu_id >= 0: 75 | capnet.cuda(opt.gpu_id) 76 | vgg_ext.cuda(opt.gpu_id) 77 | capsule_loss.cuda(opt.gpu_id) 78 | 79 | capnet.load_state_dict(torch.load('/home/asus/Code/pvc/capsule_8.pt')) 80 | 81 | optimizer = Adam(capnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 82 | 83 | if opt.resume > 0: 84 | capnet.load_state_dict(torch.load(os.path.join(opt.outf, 'capsule_' + str(opt.resume) + '.pt'))) 85 | capnet.train(mode=True) 86 | optimizer.load_state_dict(torch.load(os.path.join(opt.outf, 'optim_' + str(opt.resume) + '.pt'))) 87 | 88 | if opt.gpu_id >= 0: 89 | for state in optimizer.state.values(): 90 | for k, v in 
state.items(): 91 | if isinstance(v, torch.Tensor): 92 | state[k] = v.cuda(opt.gpu_id) 93 | 94 | 95 | # 96 | # transform_fwd = transforms.Compose([ 97 | # transforms.Resize((opt.imageSize, opt.imageSize)), 98 | # transforms.ToTensor(), 99 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 100 | # ]) 101 | 102 | # dataset_train = dset.ImageFolder(root=os.path.join(opt.dataset, opt.train_set), transform=transform_fwd) 103 | # assert dataset_train 104 | # dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=opt.batchSize, shuffle=True, 105 | # num_workers=int(opt.workers)) 106 | # 107 | # dataset_val = dset.ImageFolder(root=os.path.join(opt.dataset, opt.val_set), transform=transform_fwd) 108 | # assert dataset_val 109 | # dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=opt.batchSize, shuffle=False, 110 | # num_workers=int(opt.workers)) 111 | dataloaders = {} 112 | for name in ['train', 'test']: 113 | raw_data = pandas.read_csv(os.path.join(opt.dataset, '%s.csv' % name)) 114 | dataloaders[name] = DataLoader(FrameDataset(raw_data.to_numpy()), 115 | batch_size=opt.batch_size, 116 | shuffle=True, 117 | num_workers=4, 118 | pin_memory=False) 119 | 120 | for epoch in range(opt.resume + 1, opt.niter + 1): 121 | count = 0 122 | loss_train = 0 123 | loss_test = 0 124 | 125 | tol_label = np.array([], dtype=np.float) 126 | tol_pred = np.array([], dtype=np.float) 127 | 128 | for img_data, labels_data in tqdm(dataloaders['train']): 129 | 130 | labels_data[labels_data > 1] = 1 131 | img_label = labels_data.numpy().astype(np.float) 132 | optimizer.zero_grad() 133 | 134 | if opt.gpu_id >= 0: 135 | img_data = img_data.cuda(opt.gpu_id) 136 | labels_data = labels_data.cuda(opt.gpu_id) 137 | 138 | input_v = Variable(img_data) 139 | x = vgg_ext(input_v) 140 | classes, class_ = capnet(x, random=opt.random, dropout=opt.dropout) 141 | 142 | loss_dis = capsule_loss(classes, Variable(labels_data, requires_grad=False)) 143 | loss_dis_data = loss_dis.item() 144 | 145 | loss_dis.backward() 146 | optimizer.step() 147 | 148 | output_dis = class_.data.cpu().numpy() 149 | output_pred = np.zeros((output_dis.shape[0]), dtype=np.float) 150 | 151 | for i in range(output_dis.shape[0]): 152 | if output_dis[i, 1] >= output_dis[i, 0]: 153 | output_pred[i] = 1.0 154 | else: 155 | output_pred[i] = 0.0 156 | 157 | tol_label = np.concatenate((tol_label, img_label)) 158 | tol_pred = np.concatenate((tol_pred, output_pred)) 159 | 160 | loss_train += loss_dis_data 161 | count += 1 162 | 163 | acc_train = metrics.accuracy_score(tol_label, tol_pred) 164 | loss_train /= count 165 | 166 | ######################################################################## 167 | 168 | # do checkpointing & validation 169 | torch.save(capnet.state_dict(), os.path.join(opt.outf, 'capsule_%d.pt' % epoch)) 170 | torch.save(optimizer.state_dict(), os.path.join(opt.outf, 'optim_%d.pt' % epoch)) 171 | 172 | capnet.eval() 173 | 174 | tol_label = np.array([], dtype=np.float) 175 | tol_pred = np.array([], dtype=np.float) 176 | 177 | count = 0 178 | 179 | for img_data, labels_data in dataloaders['test']: 180 | 181 | labels_data[labels_data > 1] = 1 182 | img_label = labels_data.numpy().astype(np.float) 183 | 184 | if opt.gpu_id >= 0: 185 | img_data = img_data.cuda(opt.gpu_id) 186 | labels_data = labels_data.cuda(opt.gpu_id) 187 | 188 | input_v = Variable(img_data) 189 | 190 | x = vgg_ext(input_v) 191 | classes, class_ = capnet(x, random=False) 192 | 193 | loss_dis = capsule_loss(classes, 
Variable(labels_data, requires_grad=False)) 194 | loss_dis_data = loss_dis.item() 195 | output_dis = class_.data.cpu().numpy() 196 | 197 | output_pred = np.zeros((output_dis.shape[0]), dtype=np.float) 198 | 199 | for i in range(output_dis.shape[0]): 200 | if output_dis[i, 1] >= output_dis[i, 0]: 201 | output_pred[i] = 1.0 202 | else: 203 | output_pred[i] = 0.0 204 | 205 | tol_label = np.concatenate((tol_label, img_label)) 206 | tol_pred = np.concatenate((tol_pred, output_pred)) 207 | 208 | loss_test += loss_dis_data 209 | count += 1 210 | 211 | acc_test = metrics.accuracy_score(tol_label, tol_pred) 212 | auc_test = metrics.roc_auc_score(tol_label, tol_pred) 213 | f1_test = metrics.f1_score(tol_label, tol_pred) 214 | recall_test = metrics.recall_score(tol_label, tol_pred) 215 | precision = metrics.precision_score(tol_label, tol_pred) 216 | 217 | loss_test /= count 218 | 219 | print('[Epoch %d] Train loss: %.4f acc: %.2f | Test loss: %.4f acc: %.2f auc: %.2f' 220 | % (epoch, loss_train, acc_train * 100, loss_test, acc_test * 100, auc_test * 100)) 221 | 222 | text_writer.write('%d,%.4f,%.2f,%.4f,%.2f,%.2f,%.2f,%.2f,%.2f\n' 223 | % (epoch, loss_train, acc_train * 100, loss_test, acc_test * 100, auc_test * 100, 224 | f1_test * 100, recall_test * 100, precision * 100)) 225 | 226 | text_writer.flush() 227 | capnet.train(mode=True) 228 | 229 | text_writer.close() 230 | -------------------------------------------------------------------------------- /datasets/dataset_imagenet_dct_cupy.py: -------------------------------------------------------------------------------- 1 | # Optimized for DCT 2 | # Upsampling in the compressed domain 3 | import os 4 | import sys 5 | from datasets.vision import VisionDataset 6 | from PIL import Image 7 | import cv2 8 | import os.path 9 | import numpy as np 10 | import torch 11 | from turbojpeg import TurboJPEG 12 | from datasets import train_y_mean_resized, train_y_std_resized, train_cb_mean_resized, train_cb_std_resized, \ 13 | train_cr_mean_resized, train_cr_std_resized 14 | 15 | def has_file_allowed_extension(filename, extensions): 16 | """Checks if a file is an allowed extension. 17 | 18 | Args: 19 | filename (string): path to a file 20 | extensions (tuple of strings): extensions to consider (lowercase) 21 | 22 | Returns: 23 | bool: True if the filename ends with one of given extensions 24 | """ 25 | return filename.lower().endswith(extensions) 26 | 27 | 28 | def is_image_file(filename): 29 | """Checks if a file is an allowed image extension. 
30 | 31 | Args: 32 | filename (string): path to a file 33 | 34 | Returns: 35 | bool: True if the filename ends with a known image extension 36 | """ 37 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 38 | 39 | 40 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): 41 | images = [] 42 | dir = os.path.expanduser(dir) 43 | if not ((extensions is None) ^ (is_valid_file is None)): 44 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 45 | if extensions is not None: 46 | def is_valid_file(x): 47 | return has_file_allowed_extension(x, extensions) 48 | for target in sorted(class_to_idx.keys()): 49 | d = os.path.join(dir, target) 50 | if not os.path.isdir(d): 51 | continue 52 | for root, _, fnames in sorted(os.walk(d)): 53 | for fname in sorted(fnames): 54 | path = os.path.join(root, fname) 55 | if is_valid_file(path): 56 | item = (path, class_to_idx[target]) 57 | images.append(item) 58 | 59 | return images 60 | 61 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 62 | 63 | def pil_loader(path): 64 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 65 | with open(path, 'rb') as f: 66 | img = Image.open(f) 67 | return img.convert('RGB') 68 | 69 | def accimage_loader(path): 70 | import accimage 71 | try: 72 | return accimage.Image(path) 73 | except IOError: 74 | # Potentially a decoding problem, fall back to PIL.Image 75 | return pil_loader(path) 76 | 77 | def opencv_loader(path, colorSpace='YCrCb'): 78 | image = cv2.imread(str(path)) 79 | # cv2.imwrite('/mnt/ssd/kai.x/work/code/iftc/datasets/cvtransforms/test/raw.jpg', image) 80 | if colorSpace == "YCrCb": 81 | image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb) 82 | # cv2.imwrite('/mnt/ssd/kai.x/work/code/iftc/datasets/cvtransforms/test/ycbcr.jpg', image) 83 | elif colorSpace == 'RGB': 84 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 85 | return image 86 | 87 | def default_loader(path, backend='opencv', colorSpace='YCrCb'): 88 | from torchvision import get_image_backend 89 | if backend == 'opencv': 90 | return opencv_loader(path, colorSpace=colorSpace) 91 | elif get_image_backend() == 'accimage' and backend == 'acc': 92 | return accimage_loader(path) 93 | elif backend == 'pil': 94 | return pil_loader(path) 95 | else: 96 | raise NotImplementedError 97 | 98 | class DatasetFolderDCT(VisionDataset): 99 | def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None, subset=0): 100 | super(DatasetFolderDCT, self).__init__(root) 101 | self.transform = transform 102 | self.target_transform = target_transform 103 | classes, class_to_idx = self._find_classes(self.root) 104 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 105 | if len(samples) == 0: 106 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 107 | "Supported extensions are: " + ",".join(extensions))) 108 | 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.classes = classes 113 | self.class_to_idx = class_to_idx 114 | self.samples = samples 115 | self.targets = [s[1] for s in samples] 116 | self.subset = list(map(int, subset.split(','))) if subset else [] 117 | 118 | def _find_classes(self, dir): 119 | if sys.version_info >= (3, 5): 120 | # Faster and available in Python 3.5 and above 121 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 122 | else: 123 | classes = [d for d in 
os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 124 | classes.sort() 125 | class_to_idx = {classes[i]: i for i in range(len(classes))} 126 | return classes, class_to_idx 127 | 128 | def __getitem__(self, index): 129 | path, target = self.samples[index] 130 | # sample = self.loader(path, backend='opencv', colorSpace='YCrCb') 131 | 132 | sample = self.loader(path, backend='opencv', colorSpace='BGR') 133 | 134 | # with open(path, 'rb') as src: 135 | # buffer = src.read() 136 | # dct_y_bak, dct_cb_bak, dct_cr_bak = loads(buffer) 137 | 138 | if self.transform is not None: 139 | dct_y, dct_cb, dct_cr = self.transform(sample) 140 | 141 | # sample_resize = sample.resize((224*2, 224*2), resample=0) 142 | # PIL to numpy 143 | # sample = np.asarray(sample, dtype="uint8") 144 | # RGB to BGR 145 | # sample = sample[:, :, ::-1] 146 | # JPEG Encode 147 | # sample = np.ascontiguousarray(sample, dtype="uint8") 148 | # sample = self.jpeg.encode(sample, quality=100, jpeg_subsample=2) 149 | # dct_y, dct_cb, dct_cr = loads(sample) # 28 150 | 151 | # sample_resize = np.asarray(sample_resize) 152 | # sample_resize = sample_resize[:, :, ::-1] 153 | # sample_resize = np.ascontiguousarray(sample_resize, dtype="uint8") 154 | # sample_resize = self.jpeg.encode(sample_resize, quality=100) 155 | # _, dct_cb_resize, dct_cr_resize = loads(sample_resize) # 28 156 | # dct_cb_resize, dct_cr_resize = torch.from_numpy(dct_cb_resize).permute(2, 0, 1).float(), \ 157 | # torch.from_numpy(dct_cr_resize).permute(2, 0, 1).float() 158 | 159 | # dct_y_unnormalized, dct_cb_unnormalized, dct_cr_unnormalized = loads(sample, normalized=False) # 28 160 | # dct_y_normalized, dct_cb_normalized, dct_cr_normalized = loads(sample, normalized=True) # 28 161 | # total_y = (dct_y-dct_y_bak).sum() 162 | # total_cb = (dct_cb-dct_cb_bak).sum() 163 | # total_cr = (dct_cr-dct_cr_bak).sum() 164 | # print('{}, {}, {}'.format(total_y, total_cb, total_cr)) 165 | # dct_y, dct_cb, dct_cr = torch.from_numpy(dct_y).permute(2, 0, 1).float(), \ 166 | # torch.from_numpy(dct_cb).permute(2, 0, 1).float(), \ 167 | # torch.from_numpy(dct_cr).permute(2, 0, 1).float() 168 | 169 | # transform = transforms.Resize(28, interpolation=2) 170 | # dct_cb_resize2 = [transform(Image.fromarray(dct_c.numpy())) for dct_c in dct_cb] 171 | 172 | if self.subset: 173 | dct_y, dct_cb, dct_cr = dct_y[self.subset[0]:self.subset[1]], dct_cb[self.subset[0]:self.subset[1]], \ 174 | dct_cr[self.subset[0]:self.subset[1]] 175 | 176 | return dct_y, dct_cb, dct_cr, target 177 | 178 | def __len__(self): 179 | return len(self.samples) 180 | 181 | class ImageFolderDCT(DatasetFolderDCT): 182 | def __init__(self, root, transform=None, target_transform=None, 183 | loader=default_loader, is_valid_file=None, subset=None): 184 | super(ImageFolderDCT, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 185 | transform=transform, 186 | target_transform=target_transform, 187 | is_valid_file=is_valid_file, subset=subset) 188 | self.imgs = self.samples 189 | 190 | 191 | if __name__ == '__main__': 192 | dataset = 'imagenet' 193 | 194 | import torch 195 | import datasets.cvtransforms as transforms 196 | import matplotlib.pyplot as plt 197 | from sklearn.preprocessing import minmax_scale 198 | 199 | # jpeg_encoder = TurboJPEG('/home/kai.x/work/local/lib/libturbojpeg.so') 200 | jpeg_encoder = TurboJPEG('/usr/lib/libturbojpeg.so') 201 | if dataset == 'imagenet': 202 | input_normalize = [] 203 | input_normalize_y = transforms.Normalize(mean=train_y_mean_resized, 204 | 
std=train_y_std_resized) 205 | input_normalize_cb = transforms.Normalize(mean=train_cb_mean_resized, 206 | std=train_cb_std_resized) 207 | input_normalize_cr = transforms.Normalize(mean=train_cr_mean_resized, 208 | std=train_cr_std_resized) 209 | input_normalize.append(input_normalize_y) 210 | input_normalize.append(input_normalize_cb) 211 | input_normalize.append(input_normalize_cr) 212 | val_loader = torch.utils.data.DataLoader( 213 | # ImageFolderDCT('/mnt/ssd/kai.x/dataset/ILSVRC2012/val', transforms.Compose([ 214 | ImageFolderDCT('/storage-t1/user/kaixu/datasets/ILSVRC2012/val', transforms.Compose([ 215 | transforms.ToYCrCb(), 216 | transforms.TransformDCT(), 217 | transforms.UpsampleDCT(T=896, debug=False), 218 | transforms.CenterCropDCT(112), 219 | transforms.ToTensorDCT(), 220 | transforms.NormalizeDCT( 221 | train_y_mean_resized, train_y_std_resized, 222 | train_cb_mean_resized, train_cb_std_resized, 223 | train_cr_mean_resized, train_cr_std_resized), 224 | ])), 225 | batch_size=1, shuffle=False, 226 | num_workers=1, pin_memory=False) 227 | 228 | train_dataset = ImageFolderDCT('/storage-t1/user/kaixu/datasets/ILSVRC2012/train', transforms.Compose([ 229 | transforms.RandomResizedCrop(224), 230 | transforms.RandomHorizontalFlip(), 231 | transforms.ToYCrCb(), 232 | transforms.ChromaSubsample(), 233 | transforms.UpsampleDCT(size=224, T=896, cuda=True, debug=False), 234 | transforms.ToTensorDCT(), 235 | transforms.NormalizeDCT( 236 | train_y_mean_resized, train_y_std_resized, 237 | train_cb_mean_resized, train_cb_std_resized, 238 | train_cr_mean_resized, train_cr_std_resized), 239 | ])) 240 | 241 | train_loader = torch.utils.data.DataLoader( 242 | train_dataset, 243 | batch_size=1, shuffle=False, 244 | num_workers=1, pin_memory=False) 245 | 246 | from torchvision.utils import save_image 247 | dct_y_mean_total, dct_y_std_total = [], [] 248 | # for batch_idx, (dct_y, dct_cb, dct_cr, targets) in enumerate(val_loader): 249 | for batch_idx, (dct_y, dct_cb, dct_cr, targets) in enumerate(train_loader): 250 | coef = dct_y.numpy() 251 | dct_y_mean, dct_y_std = [], [] 252 | 253 | for c in coef: 254 | c = c.reshape((64, -1)) 255 | dct_y_mean.append([np.mean(x) for x in c]) 256 | dct_y_std.append([np.std(x) for x in c]) 257 | 258 | dct_y_mean_np = np.asarray(dct_y_mean).mean(axis=0) 259 | dct_y_std_np = np.asarray(dct_y_std).mean(axis=0) 260 | dct_y_mean_total.append(dct_y_mean_np) 261 | dct_y_std_total.append(dct_y_std_np) 262 | # print('The mean of dct_y is: {}'.format(dct_y_mean_np)) 263 | # print('The std of dct_y is: {}'.format(dct_y_std_np)) 264 | 265 | print('The mean of dct_y is: {}'.format(np.asarray(dct_y_mean_total).mean(axis=0))) 266 | print('The std of dct_y is: {}'.format(np.asarray(dct_y_std_total).mean(axis=0))) 267 | 268 | 269 | -------------------------------------------------------------------------------- /datasets/imagenet2lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import os, sys 4 | import os.path as osp 5 | from PIL import Image 6 | import six 7 | import string 8 | 9 | import lmdb 10 | import pickle 11 | import msgpack 12 | import tqdm 13 | import pyarrow as pa 14 | import bz2 15 | 16 | import torch 17 | import torch.utils.data as data 18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 19 | import datasets.cvtransforms as transforms 20 | from datasets.dataset_imagenet_dct import ImageFolderDCT 21 | 22 | class ImageFolderLMDB(data.Dataset): 23 | def 
__init__(self, db_path, transform=None, target_transform=None): 24 | self.db_path = db_path 25 | self.env = lmdb.open(db_path, subdir=osp.isdir(db_path), 26 | readonly=True, lock=False, 27 | readahead=False, meminit=False) 28 | with self.env.begin(write=False) as txn: 29 | # self.length = txn.stat()['entries'] - 1 30 | self.length = txn.get(b'__len__') 31 | self.keys = msgpack.loads(txn.get(b'__keys__')) 32 | 33 | self.transform = transform 34 | self.target_transform = target_transform 35 | 36 | def __getitem__(self, index): 37 | img, target = None, None 38 | env = self.env 39 | with env.begin(write=False) as txn: 40 | byteflow = txn.get(self.keys[index]) 41 | unpacked = msgpack.loads(byteflow) 42 | 43 | # load image 44 | imgbuf = unpacked[0] 45 | buf = six.BytesIO() 46 | buf.write(imgbuf) 47 | buf.seek(0) 48 | img = Image.open(buf).convert('RGB') 49 | 50 | # load label 51 | target = unpacked[1] 52 | 53 | if self.transform is not None: 54 | img = self.transform(img) 55 | 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return self.length 63 | 64 | def __repr__(self): 65 | return self.__class__.__name__ + ' (' + self.db_path + ')' 66 | 67 | 68 | class ImageFolderLMDB_old(data.Dataset): 69 | def __init__(self, db_path, transform=None, target_transform=None): 70 | import lmdb 71 | self.db_path = db_path 72 | self.env = lmdb.open(db_path, subdir=osp.isdir(db_path), 73 | readonly=True, lock=False, 74 | readahead=False, meminit=False) 75 | with self.env.begin(write=False) as txn: 76 | self.length = txn.stat()['entries'] - 1 77 | self.keys = msgpack.loads(txn.get(b'__keys__')) 78 | # cache_file = '_cache_' + db_path.replace('/', '_') 79 | # if os.path.isfile(cache_file): 80 | # self.keys = pickle.load(open(cache_file, "rb")) 81 | # else: 82 | # with self.env.begin(write=False) as txn: 83 | # self.keys = [key for key, _ in txn.cursor()] 84 | # pickle.dump(self.keys, open(cache_file, "wb")) 85 | self.transform = transform 86 | self.target_transform = target_transform 87 | 88 | def __getitem__(self, index): 89 | img, target = None, None 90 | env = self.env 91 | with env.begin(write=False) as txn: 92 | byteflow = txn.get(self.keys[index]) 93 | unpacked = msgpack.loads(byteflow) 94 | imgbuf = unpacked[0][b'data'] 95 | buf = six.BytesIO() 96 | buf.write(imgbuf) 97 | buf.seek(0) 98 | img = Image.open(buf).convert('RGB') 99 | target = unpacked[1] 100 | 101 | if self.transform is not None: 102 | img = self.transform(img) 103 | 104 | if self.target_transform is not None: 105 | target = self.target_transform(target) 106 | 107 | return img, target 108 | 109 | def __len__(self): 110 | return self.length 111 | 112 | def __repr__(self): 113 | return self.__class__.__name__ + ' (' + self.db_path + ')' 114 | 115 | 116 | def raw_reader(path): 117 | with open(path, 'rb') as f: 118 | bin_data = f.read() 119 | return bin_data 120 | 121 | 122 | def dumps_pyarrow(obj): 123 | """ 124 | Serialize an object. 
125 | 126 | Returns: 127 | Implementation-dependent bytes-like object 128 | """ 129 | return pa.serialize(obj).to_buffer() 130 | 131 | def folder2lmdb(dpath, name="train", write_frequency=1): 132 | directory = osp.expanduser(osp.join(dpath, name)) 133 | print("Loading dataset from %s" % directory) 134 | 135 | dataset = ImageFolderDCT('/ILSVRC2012/train', transforms.Compose([ 136 | transforms.DCTFlatten2D(), 137 | transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False), 138 | transforms.ToTensorDCT(), 139 | transforms.SubsetDCT(channels=32), 140 | ]), backend='dct') 141 | 142 | data_loader = torch.utils.data.DataLoader( 143 | dataset, 144 | num_workers=0, 145 | ) 146 | 147 | lmdb_path = osp.join(dpath, "%s.lmdb" % name) 148 | isdir = os.path.isdir(lmdb_path) 149 | 150 | print("Generate LMDB to %s" % lmdb_path) 151 | db = lmdb.open(lmdb_path, subdir=isdir, 152 | map_size=1281167*224*224*32*10, readonly=False, 153 | # map_size=1099511627776 * 2, readonly=False, 154 | meminit=False, map_async=True) 155 | 156 | txn = db.begin(write=True) 157 | for idx, (image, label) in enumerate(data_loader): 158 | image = image.numpy() 159 | label = label.numpy() 160 | txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((bz2.compress(image), label))) 161 | if idx % write_frequency == 0: 162 | print("[%d/%d]" % (idx, len(data_loader))) 163 | txn.commit() 164 | txn = db.begin(write=True) 165 | 166 | # finish iterating through dataset 167 | txn.commit() 168 | keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)] 169 | with db.begin(write=True) as txn: 170 | txn.put(b'__keys__', dumps_pyarrow(keys)) 171 | txn.put(b'__len__', dumps_pyarrow(len(keys))) 172 | 173 | print("Flushing database ...") 174 | db.sync() 175 | db.close() 176 | 177 | 178 | if __name__ == "__main__": 179 | folder2lmdb("/ILSVRC2012") -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, "transforms") and self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def 
_format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 | 66 | def _format_transform_repr(self, transform, head): 67 | lines = transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) 81 | -------------------------------------------------------------------------------- /dct/__init__.py: -------------------------------------------------------------------------------- 1 | subset_channel_index_square = { 2 | 1: 3 | [[0],[],[]], 4 | 5 | 6: 6 | [ 7 | [0,1, 8 | 8,9], 9 | [0], 10 | [0] 11 | ], 12 | 13 | 12: 14 | [ 15 | [0, 1, 2, 16 | 8, 9, 10, 17 | 16, 17], 18 | [0, 1], 19 | [0, 1] 20 | ], 21 | 22 | 24: 23 | [ 24 | [0, 1, 2, 3, 25 | 8, 9, 10, 11, 26 | 16, 17, 18, 19, 27 | 24, 25, 26, 27], 28 | [0, 1, 29 | 8, 9], 30 | [0, 1, 31 | 8, 9] 32 | ], 33 | 34 | 32: 35 | [ 36 | [0, 1, 2, 3, 4, 37 | 8, 9, 10, 11, 12, 38 | 16, 17, 18, 19, 20, 39 | 24, 25, 26, 27, 40 | 32, 33, 34], 41 | [0, 1, 2, 42 | 8, 9], 43 | [0, 1, 2, 44 | 8, 9] 45 | ], 46 | 47 | 48: 48 | [ 49 | [0, 1, 2, 3, 4, 5, 50 | 8, 9, 10, 11, 12, 13, 51 | 16, 17, 18, 19, 20, 21, 52 | 24, 25, 26, 27, 28, 29, 53 | 32, 33, 34, 35, 54 | 40, 41, 42, 43], 55 | [0, 1, 2, 56 | 8, 9, 10, 57 | 16, 17], 58 | [0, 1, 2, 59 | 8, 9, 10, 60 | 16, 17] 61 | ], 62 | 63 | 64: 64 | [ 65 | [0, 1, 2, 3, 4, 5, 6, 66 | 8, 9, 10, 11, 12, 13, 14, 67 | 16, 17, 18, 19, 20, 21, 68 | 24, 25, 26, 27, 28, 29, 69 | 32, 33, 34, 35, 36, 37, 70 | 40, 41, 42, 43, 44, 45, 71 | 48, 49, 50, 51, 52, 53], 72 | [0, 1, 2, 73 | 8, 9, 10, 74 | 16, 17, 75 | 24, 25], 76 | [0, 1, 2, 77 | 8, 9, 10, 78 | 16, 17, 79 | 24, 25], 80 | ] 81 | } 82 | 83 | subset_channel_index_learned = { 84 | 1: 85 | [[0], [], []], 86 | 87 | 88 | 89 | 24: 90 | [ 91 | [0, 1, 2, 3, 4, 5, 92 | 8, 9, 10, 93 | 16, 17, 18, 94 | 24, 95 | 32], 96 | [0, 1, 3, 97 | 8, 98 | 24], 99 | [0, 1, 3, 100 | 8, 101 | 24] 102 | ] 103 | } 104 | 105 | subset_channel_index_triangle = { 106 | 1: 107 | [[0], [], []], 108 | 109 | 12: 110 | [ 111 | [0, 1, 2, 112 | 8, 9, 113 | 16], 114 | [0, 1, 115 | 8], 116 | [0, 1, 117 | 8] 118 | ], 119 | 120 | 24: 121 | [ 122 | [0, 1, 2, 3, 4, 123 | 8, 9, 10, 11, 124 | 16, 17, 125 | 24,], 126 | [0, 1, 2, 127 | 8, 9, 128 | 16], 129 | [0, 1, 2, 130 | 8, 9, 131 | 16] 132 | ], 133 | 134 | 48: 135 | [ 136 | [0, 1, 2, 3, 4, 5, 6, 137 | 8, 9, 10, 11, 12, 13, 138 | 16, 17, 18, 19, 20, 139 | 24, 25, 26, 27, 140 | 32, 33, 34, 141 | 40, 41, 142 | 48], 143 | [0, 1, 2, 3, 144 | 8, 9, 10, 145 | 16, 17, 146 | 24], 147 | [0, 1, 2, 3, 148 | 8, 9, 10, 149 | 16, 17, 150 | 24] 151 | ], 152 | 
153 | 64: 154 | [ 155 | [0, 1, 2, 3, 4, 5, 6, 7, 156 | 8, 9, 10, 11, 12, 13, 14, 157 | 16, 17, 18, 19, 20, 21, 158 | 24, 25, 26, 27, 28, 159 | 32, 33, 34, 35, 160 | 40, 41, 42, 161 | 48, 162 | ], 163 | [0, 1, 2, 3, 4, 164 | 8, 9, 10, 11, 165 | 16, 17, 18, 166 | 24, 25, 167 | 32], 168 | [0, 1, 2, 3, 4, 169 | 8, 9, 10, 11, 170 | 16, 17, 18, 171 | 24, 25, 172 | 32], 173 | ] 174 | } 175 | -------------------------------------------------------------------------------- /dct/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # from .resnext import * 4 | from .resnet import * 5 | # from .resnet_autosubset_inputgate import resnet50_autosubset_inputgate 6 | # from .resnext_attention import * 7 | # from .mobilenetv2_autosubset_alllayer import mobilenetv2dct_autosubset_alllayers 8 | 9 | -------------------------------------------------------------------------------- /dct/imagenet/gate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dct.imagenet.gumbel import GumbleSoftmax 4 | 5 | class GateModule(nn.Module): 6 | def __init__(self, in_ch, kernel_size=28, doubleGate=False, dwLA=False): 7 | super(GateModule, self).__init__() 8 | 9 | self.doubleGate, self.dwLA = doubleGate, dwLA 10 | self.inp_gs = GumbleSoftmax() 11 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 12 | self.in_ch = in_ch 13 | 14 | if dwLA: 15 | if doubleGate: 16 | self.inp_att = nn.Sequential( 17 | nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, stride=1, padding=0, groups=in_ch, 18 | bias=True), 19 | nn.BatchNorm2d(in_ch), 20 | nn.ReLU6(inplace=True), 21 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True), 22 | nn.Sigmoid() 23 | ) 24 | 25 | self.inp_gate = nn.Sequential( 26 | nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, stride=1, padding=0, groups=in_ch, bias=True), 27 | nn.BatchNorm2d(in_ch), 28 | nn.ReLU6(inplace=True), 29 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True), 30 | nn.BatchNorm2d(in_ch), 31 | ) 32 | self.inp_gate_l = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, stride=1, padding=0, groups=in_ch, 33 | bias=True) 34 | else: 35 | if doubleGate: 36 | reduction = 4 37 | self.inp_att = nn.Sequential( 38 | nn.Conv2d(in_ch, in_ch // reduction, kernel_size=1, stride=1, padding=0, bias=True), 39 | nn.ReLU6(inplace=True), 40 | nn.Conv2d(in_ch // reduction, in_ch, kernel_size=1, stride=1, padding=0, bias=True), 41 | nn.Sigmoid() 42 | ) 43 | 44 | self.inp_gate = nn.Sequential( 45 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True), 46 | nn.BatchNorm2d(in_ch), 47 | nn.ReLU6(inplace=True), 48 | ) 49 | self.inp_gate_l = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, stride=1, padding=0, groups=in_ch, bias=True) 50 | 51 | def forward(self, y, cb, cr, temperature=1.): 52 | if self.doubleGate: 53 | if self.dwLA: 54 | hatten_d1 = self.inp_att(x) 55 | hatten_d2 = self.inp_gate(x) 56 | hatten_d2 = self.inp_gate_l(hatten_d2) 57 | else: 58 | hatten_y, hatten_cb, hatten_cr = self.avg_pool(y), self.avg_pool(cb), self.avg_pool(cr) 59 | hatten = torch.cat((hatten_y, hatten_cb, hatten_cr), dim=1) 60 | 61 | hatten_d1 = self.inp_att(hatten) 62 | hatten_d2 = self.inp_gate(hatten) 63 | hatten_d2 = self.inp_gate_l(hatten_d2) 64 | 65 | hatten_d2 = hatten_d2.reshape(hatten_d2.size(0), self.in_ch, 2, 1) 66 | hatten_d2 = self.inp_gs(hatten_d2, temp=temperature, force_hard=True) 67 | else: 68 | if self.dwLA: 69 | 
hatten_d2 = self.inp_gate(x) 70 | hatten_d2 = self.inp_gate_l(hatten_d2) 71 | else: 72 | hatten_y, hatten_cb, hatten_cr = self.avg_pool(y), self.avg_pool(cb), self.avg_pool(cr) 73 | hatten_d2 = torch.cat((hatten_y, hatten_cb, hatten_cr), dim=1) 74 | hatten_d2 = self.inp_gate(hatten_d2) 75 | hatten_d2 = self.inp_gate_l(hatten_d2) 76 | 77 | hatten_d2 = hatten_d2.reshape(hatten_d2.size(0), self.in_ch, 2, 1) 78 | hatten_d2 = self.inp_gs(hatten_d2, temp=temperature, force_hard=True) 79 | 80 | if self.doubleGate: 81 | x = x * hatten_d1 * hatten_d2[:, :, 1].unsqueeze(2) 82 | else: 83 | y = y * hatten_d2[:, :64, 1].unsqueeze(2) 84 | cb = cb * hatten_d2[:, 64:128, 1].unsqueeze(2) 85 | cr = cr * hatten_d2[:, 128:, 1].unsqueeze(2) 86 | 87 | return y, cb, cr, hatten_d2[:, :, 1] 88 | 89 | 90 | 91 | class GateModule192(nn.Module): 92 | def __init__(self, act='relu'): 93 | super(GateModule192, self).__init__() 94 | 95 | self.inp_gs = GumbleSoftmax() 96 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 97 | self.in_ch = in_ch = 192 98 | if act == 'relu': 99 | relu = nn.ReLU 100 | elif act == 'relu6': 101 | relu = nn.ReLU6 102 | else: raise NotImplementedError 103 | 104 | self.inp_gate = nn.Sequential( 105 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True), 106 | nn.BatchNorm2d(in_ch), 107 | relu(inplace=True), 108 | ) 109 | self.inp_gate_l = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, stride=1, padding=0, groups=in_ch, bias=True) 110 | 111 | 112 | def forward(self, x, temperature=1.): 113 | hatten = self.avg_pool(x) 114 | hatten_d = self.inp_gate(hatten) 115 | hatten_d = self.inp_gate_l(hatten_d) 116 | hatten_d = hatten_d.reshape(hatten_d.size(0), self.in_ch, 2, 1) 117 | hatten_d = self.inp_gs(hatten_d, temp=temperature, force_hard=True) 118 | 119 | x = x * hatten_d[:, :, 1].unsqueeze(2) 120 | 121 | return x, hatten_d[:, :, 1] 122 | -------------------------------------------------------------------------------- /dct/imagenet/gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | 5 | class GumbleSoftmax(torch.nn.Module): 6 | def __init__(self, hard=False): 7 | super(GumbleSoftmax, self).__init__() 8 | self.hard = hard 9 | 10 | def sample_gumbel(self, shape, eps=1e-10): 11 | """Sample from Gumbel(0, 1)""" 12 | noise = torch.rand(shape) 13 | noise.add_(eps).log_().neg_() 14 | noise.add_(eps).log_().neg_() 15 | if self.gpu: 16 | return noise.cuda() 17 | else: 18 | return noise 19 | 20 | def sample_gumbel_like(self, template_tensor, eps=1e-10): 21 | uniform_samples_tensor = template_tensor.clone().uniform_() 22 | gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps)) 23 | return gumble_samples_tensor 24 | 25 | def gumbel_softmax_sample(self, logits, temperature): 26 | """ Draw a sample from the Gumbel-Softmax distribution""" 27 | dim = logits.size(2) 28 | gumble_samples_tensor = self.sample_gumbel_like(logits.data) 29 | gumble_trick_log_prob_samples = logits + gumble_samples_tensor 30 | soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, dim) 31 | return soft_samples 32 | 33 | def gumbel_softmax(self, logits, temperature, hard=False): 34 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 35 | Args: 36 | logits: [batch_size, n_class] unnormalized log-probslibaba 37 | temperature: non-negative scalar 38 | hard: if True, take argmax, but differentiate w.r.t. 
soft sample y 39 | Returns: 40 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 41 | If hard=True, then the returned sample will be one-hot, otherwise it will 42 | be a probabilitiy distribution that sums to 1 across classes 43 | """ 44 | y = self.gumbel_softmax_sample(logits, temperature) 45 | if hard: 46 | # block layer 47 | # _, max_value_indexes = y.data.max(1, keepdim=True) 48 | # y_hard = logits.data.clone().zero_().scatter_(1, max_value_indexes, 1) 49 | # block channel 50 | _, max_value_indexes = y.data.max(2, keepdim=True) 51 | y_hard = logits.data.clone().zero_().scatter_(2, max_value_indexes, 1) 52 | y = Variable(y_hard - y.data) + y 53 | return y 54 | 55 | def forward(self, logits, temp=1, force_hard=False): 56 | samplesize = logits.size() 57 | 58 | if self.training and not force_hard: 59 | return self.gumbel_softmax(logits, temperature=1, hard=False) 60 | else: 61 | return self.gumbel_softmax(logits, temperature=1, hard=True) -------------------------------------------------------------------------------- /dct/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | def constant_init(module, val, bias=0): 6 | nn.init.constant_(module.weight, val) 7 | if hasattr(module, 'bias') and module.bias is not None: 8 | nn.init.constant_(module.bias, bias) 9 | 10 | 11 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 12 | assert distribution in ['uniform', 'normal'] 13 | if distribution == 'uniform': 14 | nn.init.xavier_uniform_(module.weight, gain=gain) 15 | else: 16 | nn.init.xavier_normal_(module.weight, gain=gain) 17 | if hasattr(module, 'bias') and module.bias is not None: 18 | nn.init.constant_(module.bias, bias) 19 | 20 | 21 | def normal_init(module, mean=0, std=1, bias=0): 22 | nn.init.normal_(module.weight, mean, std) 23 | if hasattr(module, 'bias') and module.bias is not None: 24 | nn.init.constant_(module.bias, bias) 25 | 26 | 27 | def uniform_init(module, a=0, b=1, bias=0): 28 | nn.init.uniform_(module.weight, a, b) 29 | if hasattr(module, 'bias') and module.bias is not None: 30 | nn.init.constant_(module.bias, bias) 31 | 32 | 33 | def kaiming_init(module, 34 | a=0, 35 | mode='fan_out', 36 | nonlinearity='relu', 37 | bias=0, 38 | distribution='normal'): 39 | assert distribution in ['uniform', 'normal'] 40 | if distribution == 'uniform': 41 | nn.init.kaiming_uniform_( 42 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 43 | else: 44 | nn.init.kaiming_normal_( 45 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 46 | if hasattr(module, 'bias') and module.bias is not None: 47 | nn.init.constant_(module.bias, bias) 48 | 49 | 50 | def caffe2_xavier_init(module, bias=0): 51 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 52 | # Acknowledgment to FAIR's internal code 53 | kaiming_init( 54 | module, 55 | a=1, 56 | mode='fan_in', 57 | nonlinearity='leaky_relu', 58 | distribution='uniform') 59 | 60 | 61 | def get_upsample_filter(size): 62 | """Make a 2D bilinear kernel suitable for upsampling""" 63 | factor = (size + 1) // 2 64 | if size % 2 == 1: 65 | center = factor - 1 66 | else: 67 | center = factor - 0.5 68 | og = np.ogrid[:size, :size] 69 | filter = (1 - abs(og[0] - center) / factor) * \ 70 | (1 - abs(og[1] - center) / factor) 71 | return torch.from_numpy(filter).float() 72 | -------------------------------------------------------------------------------- /demo/demo.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import transforms 4 | import numpy as np 5 | import cv2 6 | from facenet_pytorch.models.mtcnn import MTCNN 7 | from sklearn.metrics import accuracy_score 8 | from models.model import Baseline 9 | 10 | 11 | def load_model(restore_from, device): 12 | model = Baseline(use_gru=True, bi_branch=True) 13 | 14 | model.to(device) 15 | 16 | device_count = torch.cuda.device_count() 17 | # if device_count > 1: 18 | # print('Using {} GPUs'.format(device_count)) 19 | model = nn.DataParallel(model) 20 | 21 | if restore_from is not None: 22 | ckpt = torch.load(restore_from, map_location='cpu') 23 | model.load_state_dict(ckpt['model_state_dict']) 24 | print('Model is loaded from %s' % restore_from) 25 | 26 | model.eval() 27 | 28 | return model 29 | 30 | def _bbox_in_img(img, bbox): 31 | """ 32 | check whether the bbox is inner an image. 33 | :param img: (3-d np.ndarray), image 34 | :param bbox: (list) [x, y, width, height] 35 | :return: (bool), whether bbox in image size. 36 | """ 37 | if not isinstance(img, np.ndarray): 38 | raise ValueError("input image should be ndarray!") 39 | if len(img.shape) != 3: 40 | raise ValueError("input image should be (w,h,c)!") 41 | h = img.shape[0] 42 | w = img.shape[1] 43 | x_in = 0 <= bbox[0] <= w 44 | y_in = 0 <= bbox[1] <= h 45 | x1_in = 0 <= bbox[0] + bbox[2] <= w 46 | y1_in = 0 <= bbox[1] + bbox[3] <= h 47 | return x_in and y_in and x1_in and y1_in 48 | 49 | 50 | def _enlarged_bbox(bbox, expand): 51 | """ 52 | enlarge a bbox by given expand param. 53 | :param bbox: [x, y, width, height] 54 | :param expand: (tuple) (h,w), expanded pixels in height and width. if (int), same value in both side. 55 | :return: enlarged bbox 56 | """ 57 | if isinstance(expand, int): 58 | expand = (expand, expand) 59 | s_0, s_1 = bbox[1], bbox[0] 60 | e_0, e_1 = bbox[1] + bbox[3], bbox[0] + bbox[2] 61 | x = s_1 - expand[1] 62 | y = s_0 - expand[0] 63 | x1 = e_1 + expand[1] 64 | y1 = e_0 + expand[0] 65 | width = x1 - x 66 | height = y1 - y 67 | return x, y, width, height 68 | 69 | 70 | def _box_mode_cvt(bbox): 71 | """ 72 | convert box from FCOS([xyxy], float) output to [x, y, width, height](int). 73 | :param bbox: (dict), an output from FCOS([x, y, x1, y1], float). 74 | :return: (list[int]), a box with [x, y, width, height] format. 75 | """ 76 | if bbox is None: 77 | raise ValueError("There is no box in the dict!") 78 | # FCOS box format is [x, y, x1, y1] 79 | w = bbox[2] - bbox[0] 80 | h = bbox[3] - bbox[1] 81 | cvt_box = [int(bbox[0]), int(bbox[1]), max(int(w), 0), max(int(h), 0)] 82 | return cvt_box 83 | 84 | 85 | def crop_bbox(img, bbox): 86 | """ 87 | crop an image by giving exact bbox. 
88 | :param img: 89 | :param bbox: [x, y, width, height] 90 | :return: cropped image 91 | """ 92 | if not _bbox_in_img(img, bbox): 93 | raise ValueError("bbox is out of image size!img size: {0}, bbox size: {1}".format(img.shape, bbox)) 94 | s_0 = bbox[1] 95 | s_1 = bbox[0] 96 | e_0 = bbox[1] + bbox[3] 97 | e_1 = bbox[0] + bbox[2] 98 | cropped_img = img[s_0:e_0, s_1:e_1, :] 99 | return cropped_img 100 | 101 | def face_boxes_post_process(img, box, expand_ratio): 102 | """ 103 | enlarge and crop the face patch from image 104 | :param img: ndarray, 1 frame from video 105 | :param box: output of MTCNN 106 | :param expand_ratio: default: 1.3 107 | :return: 108 | """ 109 | box = [max(b, 0) for b in box] 110 | box_xywh = _box_mode_cvt(box) 111 | expand_w = int((box_xywh[2] * (expand_ratio - 1)) / 2) 112 | expand_h = int((box_xywh[3] * (expand_ratio - 1)) / 2) 113 | enlarged_box = _enlarged_bbox(box_xywh, (expand_h, expand_w)) 114 | try: 115 | res = crop_bbox(img, enlarged_box) 116 | except ValueError: 117 | try: 118 | res = crop_bbox(img, box_xywh) 119 | except ValueError: 120 | return img 121 | return res 122 | 123 | def detect_face(frame, face_detector): 124 | boxes, _ = face_detector.detect(frame) 125 | if boxes is not None: 126 | best_box = boxes[0, :] 127 | best_face = face_boxes_post_process(frame, best_box, expand_ratio=1.33) 128 | return best_face 129 | else: 130 | return None 131 | 132 | 133 | def load_data(path, device): 134 | transform = transforms.Compose([ 135 | transforms.ToTensor(), 136 | ]) 137 | face_detector = MTCNN(margin=0, keep_all=False, select_largest=False, thresholds=[0.6, 0.7, 0.7], 138 | min_face_size=60, factor=0.8, device=device).eval() 139 | video_fd = cv2.VideoCapture(path) 140 | if not video_fd.isOpened(): 141 | print('problem of reading video') 142 | return 143 | 144 | frame_index = 0 145 | faces = [] 146 | success, frame = video_fd.read() 147 | while success: 148 | cropped_face = detect_face(frame, face_detector) 149 | cropped_face = cv2.resize(cropped_face, (64, 64)) 150 | if cropped_face is not None: 151 | cropped_face = transform(cropped_face) 152 | faces.append(cropped_face) 153 | frame_index += 1 154 | success, frame = video_fd.read() 155 | video_fd.release() 156 | print('video frame length:', frame_index) 157 | faces = torch.stack(faces, dim=0) 158 | faces = torch.unsqueeze(faces, 0) 159 | y = torch.ones(frame_index).type(torch.IntTensor) 160 | return faces, y 161 | 162 | 163 | def main(args): 164 | frame_y_gd = [] 165 | y_pred = [] 166 | frame_y_pred = [] 167 | use_cuda = torch.cuda.is_available() 168 | device = torch.device('cuda' if use_cuda else 'cpu') 169 | model = load_model(args.restore_from, device) 170 | data, y = load_data(args.path, device) 171 | X = data.to(device) 172 | y_, cnn_y = model(X) 173 | y_ = torch.sigmoid(y_) 174 | frame_y_ = torch.sigmoid(cnn_y) 175 | frame_y_gd += y.detach().numpy().tolist() 176 | frame_y_pred += frame_y_.detach().numpy().tolist() 177 | frame_y_pred = torch.tensor(frame_y_pred) 178 | frame_y_pred = [0 if i < 0.5 else 1 for i in frame_y_pred] 179 | test_frame_acc = accuracy_score(frame_y_gd, frame_y_pred) 180 | print('video is fake:', (y_ >= 0.5).item()) 181 | print('frame level acc:', test_frame_acc) 182 | 183 | 184 | if __name__ == '__main__': 185 | import argparse 186 | 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('--restore_from', type=str, default='./bi-model_type-baseline_gru_auc_0.150000_ep-10.pth') 189 | parser.add_argument('--path', type=str, default='./video/id0_id1_0002.mp4') 190 | 
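# Note: the sample clip ships with the repository (demo/video/), but no checkpoint is included;
# download one from the link given in the README and point --restore_from at it.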
args = parser.parse_args() 191 | main(args) 192 | -------------------------------------------------------------------------------- /demo/video/id0_id1_0002.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/demo/video/id0_id1_0002.mp4 -------------------------------------------------------------------------------- /fwa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/fwa/__init__.py -------------------------------------------------------------------------------- /fwa/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | # import torch.nn.functional as F 5 | import math 6 | 7 | 8 | class ResNet(nn.Module): 9 | def __init__(self, layers=18, num_class=2, pretrained=True): 10 | super(ResNet, self).__init__() 11 | if layers == 18: 12 | self.resnet = models.resnet18(pretrained=pretrained) 13 | elif layers == 34: 14 | self.resnet = models.resnet34(pretrained=pretrained) 15 | elif layers == 50: 16 | self.resnet = models.resnet50(pretrained=pretrained) 17 | elif layers == 101: 18 | self.resnet = models.resnet101(pretrained=pretrained) 19 | elif layers == 152: 20 | self.resnet = models.resnet152(pretrained=pretrained) 21 | else: 22 | raise ValueError('layers should be 18, 34, 50, 101.') 23 | self.num_class = num_class 24 | if layers in [18, 34]: 25 | self.fc = nn.Linear(512, num_class) 26 | if layers in [50, 101, 152]: 27 | self.fc = nn.Linear(512 * 4, num_class) 28 | 29 | def conv_base(self, x): 30 | x = self.resnet.conv1(x) 31 | x = self.resnet.bn1(x) 32 | x = self.resnet.relu(x) 33 | x = self.resnet.maxpool(x) 34 | 35 | layer1 = self.resnet.layer1(x) 36 | layer2 = self.resnet.layer2(layer1) 37 | layer3 = self.resnet.layer3(layer2) 38 | layer4 = self.resnet.layer4(layer3) 39 | return layer1, layer2, layer3, layer4 40 | 41 | def forward(self, x): 42 | layer1, layer2, layer3, layer4 = self.conv_base(x) 43 | x = self.resnet.avgpool(layer4) 44 | x = x.view(x.size(0), -1) 45 | x = self.fc(x) 46 | return x 47 | 48 | 49 | class SPPNet(nn.Module): 50 | def __init__(self, backbone=101, num_class=2, pool_size=(1, 2, 6), pretrained=True): 51 | # Only resnet is supported in this version 52 | super(SPPNet, self).__init__() 53 | if backbone in [18, 34, 50, 101, 152]: 54 | self.resnet = ResNet(backbone, num_class, pretrained) 55 | else: 56 | raise ValueError('Resnet{} is not supported yet.'.format(backbone)) 57 | 58 | if backbone in [18, 34]: 59 | self.c = 512 60 | if backbone in [50, 101, 152]: 61 | self.c = 2048 62 | 63 | self.spp = SpatialPyramidPool2D(out_side=pool_size) 64 | num_features = self.c * (pool_size[0] ** 2 + pool_size[1] ** 2 + pool_size[2] ** 2) 65 | self.classifier = nn.Linear(num_features, num_class) 66 | 67 | def forward(self, x): 68 | _, _, _, x = self.resnet.conv_base(x) 69 | x = self.spp(x) 70 | x = self.classifier(x) 71 | return x 72 | 73 | 74 | class SpatialPyramidPool2D(nn.Module): 75 | """ 76 | Args: 77 | out_side (tuple): Length of side in the pooling results of each pyramid layer. 
78 | 79 | Inputs: 80 | - `input`: the input Tensor to invert ([batch, channel, width, height]) 81 | """ 82 | 83 | def __init__(self, out_side): 84 | super(SpatialPyramidPool2D, self).__init__() 85 | self.out_side = out_side 86 | 87 | def forward(self, x): 88 | # batch_size, c, h, w = x.size() 89 | out = None 90 | for n in self.out_side: 91 | w_r, h_r = map(lambda s: math.ceil(s / n), x.size()[2:]) # Receptive Field Size 92 | s_w, s_h = map(lambda s: math.floor(s / n), x.size()[2:]) # Stride 93 | max_pool = nn.MaxPool2d(kernel_size=(w_r, h_r), stride=(s_w, s_h)) 94 | y = max_pool(x) 95 | if out is None: 96 | out = y.view(y.size()[0], -1) 97 | else: 98 | out = torch.cat((out, y.view(y.size()[0], -1)), 1) 99 | return out 100 | -------------------------------------------------------------------------------- /imgs/imbalanced performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/imgs/imbalanced performance.png -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/imgs/overview.png -------------------------------------------------------------------------------- /meso/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/meso/__init__.py -------------------------------------------------------------------------------- /meso/eval_meso.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 4 | 5 | import numpy as np 6 | from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, roc_curve 7 | 8 | from tqdm import tqdm 9 | 10 | from meso.meso import * 11 | 12 | validation_data_path = '' 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"] = '2, 3' 15 | 16 | img_width, img_height = 64, 64 17 | batch_size = 2000 18 | epochs = 20 19 | 20 | frame_y_gd = [] 21 | frame_y_pred = [] 22 | 23 | model = MesoInception4() 24 | # model = Meso4() 25 | model.load('') 26 | 27 | test_datagen = ImageDataGenerator(rescale=1. 
/ 255) 28 | 29 | validation_generator = test_datagen.flow_from_directory( 30 | validation_data_path, 31 | target_size=(img_height, img_width), 32 | batch_size=batch_size, 33 | class_mode='binary') 34 | 35 | i = 0 36 | for X, y in tqdm(validation_generator, desc='Validating'): 37 | y_ = model.predict(X) 38 | frame_y_pred += y_.tolist() 39 | frame_y_gd += y.tolist() 40 | i += 1 41 | if i >= 37: 42 | break 43 | 44 | gd = np.array(frame_y_gd) 45 | pred = np.array(frame_y_pred) 46 | pred_pro = pred 47 | 48 | pred = np.rint(pred) 49 | f_fpr, f_tpr, _ = roc_curve(gd, pred_pro) 50 | test_frame_acc = accuracy_score(gd, pred) 51 | test_frame_auc = roc_auc_score(gd, pred_pro) 52 | test_frame_f1 = f1_score(gd, pred) 53 | test_frame_pre = precision_score(gd, pred) 54 | test_frame_recall = recall_score(gd, pred) 55 | 56 | np.save('', f_fpr) 57 | np.save('', f_tpr) 58 | 59 | print('acc:, auc:, f1_score, precision_score, recall_score') 60 | print(test_frame_acc, test_frame_auc, test_frame_f1, test_frame_pre, test_frame_recall) 61 | -------------------------------------------------------------------------------- /meso/meso.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import tensorflow as tf 3 | from tensorflow.keras import backend as K 4 | from tensorflow.keras.models import Model as KerasModel 5 | from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Reshape, \ 6 | Concatenate, LeakyReLU 7 | from tensorflow.keras.optimizers import Adam 8 | 9 | from sklearn.metrics import roc_auc_score 10 | 11 | IMGWIDTH = 64 12 | 13 | 14 | def getPrecision(y_true, y_pred): 15 | TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) # TP 16 | N = (-1) * K.sum(K.round(K.clip(y_true - K.ones_like(y_true), -1, 0))) # N 17 | TN = K.sum(K.round(K.clip((y_true - K.ones_like(y_true)) * (y_pred - K.ones_like(y_pred)), 0, 1))) # TN 18 | FP = N - TN 19 | precision = TP / (TP + FP + K.epsilon()) # TT/P 20 | return precision 21 | 22 | 23 | def auroc(y_true, y_pred): 24 | return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double) 25 | 26 | 27 | def f1(y_true, y_pred): 28 | def recall(y_true, y_pred): 29 | """Recall metric. 30 | 31 | Only computes a batch-wise average of recall. 32 | 33 | Computes the recall, a metric for multi-label classification of 34 | how many relevant items are selected. 35 | """ 36 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 37 | possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) 38 | recall = true_positives / (possible_positives + K.epsilon()) 39 | return recall 40 | 41 | def precision(y_true, y_pred): 42 | """Precision metric. 43 | 44 | Only computes a batch-wise average of precision. 45 | 46 | Computes the precision, a metric for multi-label classification of 47 | how many selected items are relevant. 
48 | """ 49 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 50 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) 51 | precision = true_positives / (predicted_positives + K.epsilon()) 52 | return precision 53 | 54 | precision = precision(y_true, y_pred) 55 | recall = recall(y_true, y_pred) 56 | return 2 * ((precision * recall) / (precision + recall + K.epsilon())) 57 | 58 | 59 | class Classifier: 60 | def __init__(self): 61 | self.model = 0 62 | 63 | def predict(self, x): 64 | return self.model.predict(x) 65 | 66 | def fit(self, x, y): 67 | return self.model.train_on_batch(x, y) 68 | 69 | def get_accuracy(self, x, y): 70 | return self.model.test_on_batch(x, y) 71 | 72 | def get_auc(self, x, y): 73 | return auroc(x, y) 74 | 75 | def load(self, path): 76 | self.model.load_weights(path) 77 | 78 | 79 | class Meso1(Classifier): 80 | """ 81 | Feature extraction + Classification 82 | """ 83 | 84 | def __init__(self, learning_rate=1e-4, dl_rate=1): 85 | self.model = self.init_model(dl_rate) 86 | optimizer = Adam(lr=learning_rate) 87 | self.model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy', auroc]) 88 | 89 | def init_model(self, dl_rate): 90 | x = Input(shape=(IMGWIDTH, IMGWIDTH, 3)) 91 | 92 | x1 = Conv2D(16, (3, 3), dilation_rate=dl_rate, strides=1, padding='same', activation='relu')(x) 93 | x1 = Conv2D(4, (1, 1), padding='same', activation='relu')(x1) 94 | x1 = BatchNormalization()(x1) 95 | x1 = MaxPooling2D(pool_size=(8, 8), padding='same')(x1) 96 | 97 | y = Flatten()(x1) 98 | y = Dropout(0.5)(y) 99 | y = Dense(1, activation='sigmoid')(y) 100 | return KerasModel(inputs=x, outputs=y) 101 | 102 | 103 | class Meso4(Classifier): 104 | def __init__(self, learning_rate=1e-5): 105 | self.model = self.init_model() 106 | optimizer = Adam(lr=learning_rate) 107 | self.model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy', 'AUC', f1, 108 | 'Recall']) 109 | 110 | def init_model(self): 111 | x = Input(shape=(IMGWIDTH, IMGWIDTH, 3)) 112 | 113 | x1 = Conv2D(8, (3, 3), padding='same', activation='relu')(x) 114 | x1 = BatchNormalization()(x1) 115 | x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1) 116 | 117 | x2 = Conv2D(8, (5, 5), padding='same', activation='relu')(x1) 118 | x2 = BatchNormalization()(x2) 119 | x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2) 120 | 121 | x3 = Conv2D(16, (5, 5), padding='same', activation='relu')(x2) 122 | x3 = BatchNormalization()(x3) 123 | x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3) 124 | 125 | x4 = Conv2D(16, (5, 5), padding='same', activation='relu')(x3) 126 | x4 = BatchNormalization()(x4) 127 | x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4) 128 | 129 | y = Flatten()(x4) 130 | y = Dropout(0.5)(y) 131 | y = Dense(16)(y) 132 | y = LeakyReLU(alpha=0.1)(y) 133 | y = Dropout(0.5)(y) 134 | y = Dense(1, activation='sigmoid')(y) 135 | 136 | return KerasModel(inputs=x, outputs=y) 137 | 138 | 139 | class MesoInception4(Classifier): 140 | def __init__(self, learning_rate=0.001): 141 | self.model = self.init_model() 142 | optimizer = Adam(lr=learning_rate) 143 | self.model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['accuracy', 'AUC', f1, 144 | 'Recall']) 145 | 146 | def InceptionLayer(self, a, b, c, d): 147 | def func(x): 148 | x1 = Conv2D(a, (1, 1), padding='same', activation='relu')(x) 149 | 150 | x2 = Conv2D(b, (1, 1), padding='same', activation='relu')(x) 151 | x2 = Conv2D(b, (3, 3), padding='same', activation='relu')(x2) 152 | 153 | x3 = 
Conv2D(c, (1, 1), padding='same', activation='relu')(x) 154 | x3 = Conv2D(c, (3, 3), dilation_rate=2, strides=1, padding='same', activation='relu')(x3) 155 | 156 | x4 = Conv2D(d, (1, 1), padding='same', activation='relu')(x) 157 | x4 = Conv2D(d, (3, 3), dilation_rate=3, strides=1, padding='same', activation='relu')(x4) 158 | 159 | y = Concatenate(axis=-1)([x1, x2, x3, x4]) 160 | 161 | return y 162 | 163 | return func 164 | 165 | def init_model(self): 166 | x = Input(shape=(IMGWIDTH, IMGWIDTH, 3)) 167 | 168 | x1 = self.InceptionLayer(1, 4, 4, 2)(x) 169 | x1 = BatchNormalization()(x1) 170 | x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1) 171 | 172 | x2 = self.InceptionLayer(2, 4, 4, 2)(x1) 173 | x2 = BatchNormalization()(x2) 174 | x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2) 175 | 176 | x3 = Conv2D(16, (5, 5), padding='same', activation='relu')(x2) 177 | x3 = BatchNormalization()(x3) 178 | x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3) 179 | 180 | x4 = Conv2D(16, (5, 5), padding='same', activation='relu')(x3) 181 | x4 = BatchNormalization()(x4) 182 | x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4) 183 | 184 | y = Flatten()(x4) 185 | y = Dropout(0.5)(y) 186 | y = Dense(16)(y) 187 | y = LeakyReLU(alpha=0.1)(y) 188 | y = Dropout(0.5)(y) 189 | y = Dense(1, activation='sigmoid')(y) 190 | 191 | return KerasModel(inputs=x, outputs=y) 192 | -------------------------------------------------------------------------------- /meso/train_mesonet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 4 | from tensorflow.keras import callbacks 5 | 6 | import time 7 | 8 | from meso.meso import * 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3' 11 | start = time.time() 12 | 13 | img_width, img_height = 64, 64 14 | batch_size = 2000 15 | epochs = 20 16 | 17 | train_data_path = '' 18 | validation_data_path = '' 19 | 20 | # model = Meso4().model 21 | model = MesoInception4().model 22 | 23 | train_datagen = ImageDataGenerator(rescale=1. / 255) 24 | 25 | test_datagen = ImageDataGenerator(rescale=1. 
/ 255) 26 | 27 | train_generator = train_datagen.flow_from_directory( 28 | train_data_path, 29 | target_size=(img_height, img_width), 30 | batch_size=batch_size, 31 | class_mode='binary') 32 | 33 | validation_generator = test_datagen.flow_from_directory( 34 | validation_data_path, 35 | target_size=(img_height, img_width), 36 | batch_size=batch_size, 37 | class_mode='binary') 38 | 39 | log_dir = './tf-log/' 40 | tb_cb = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0) 41 | cbks = [tb_cb] 42 | 43 | model.fit_generator( 44 | train_generator, 45 | epochs=epochs, 46 | validation_data=validation_generator, 47 | callbacks=cbks, 48 | shuffle=True) 49 | 50 | target_dir = './meso/' 51 | if not os.path.exists(target_dir): 52 | os.mkdir(target_dir) 53 | model.save('./meso/model.h5') 54 | model.save_weights('./meso/weights.h5') 55 | 56 | # Calculate execution time 57 | end = time.time() 58 | dur = end - start 59 | 60 | if dur < 60: 61 | print("Execution Time:", dur, "seconds") 62 | elif dur > 60 and dur < 3600: 63 | dur = dur / 60 64 | print("Execution Time:", dur, "minutes") 65 | else: 66 | dur = dur / (60 * 60) 67 | print("Execution Time:", dur, "hours") 68 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/models/__init__.py -------------------------------------------------------------------------------- /models/convGRU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class ConvGRUCell(nn.Module): 8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias, dtype): 9 | """ 10 | Initialize the ConvLSTM cell 11 | :param input_size: (int, int) 12 | Height and width of input tensor as (height, width). 13 | :param input_dim: int 14 | Number of channels of input tensor. 15 | :param hidden_dim: int 16 | Number of channels of hidden state. 17 | :param kernel_size: (int, int) 18 | Size of the convolutional kernel. 19 | :param bias: bool 20 | Whether or not to add the bias. 21 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor 22 | Whether or not to use cuda. 
23 | """ 24 | super(ConvGRUCell, self).__init__() 25 | self.height, self.width = input_size 26 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 27 | self.hidden_dim = hidden_dim 28 | self.bias = bias 29 | self.dtype = dtype 30 | 31 | self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim, 32 | out_channels=2*self.hidden_dim, # for update_gate,reset_gate respectively 33 | kernel_size=kernel_size, 34 | padding=self.padding, 35 | bias=self.bias) 36 | 37 | self.conv_can = nn.Conv2d(in_channels=input_dim+hidden_dim, 38 | out_channels=self.hidden_dim, # for candidate neural memory 39 | kernel_size=kernel_size, 40 | padding=self.padding, 41 | bias=self.bias) 42 | 43 | def init_hidden(self, batch_size): 44 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).type(self.dtype)) 45 | 46 | def forward(self, input_tensor, h_cur): 47 | """ 48 | 49 | :param self: 50 | :param input_tensor: (b, c, h, w) 51 | input is actually the target_model 52 | :param h_cur: (b, c_hidden, h, w) 53 | current hidden and cell states respectively 54 | :return: h_next, 55 | next hidden state 56 | """ 57 | combined = torch.cat([input_tensor, h_cur], dim=1) 58 | combined_conv = self.conv_gates(combined) 59 | 60 | gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1) 61 | reset_gate = torch.sigmoid(gamma) 62 | update_gate = torch.sigmoid(beta) 63 | 64 | combined = torch.cat([input_tensor, reset_gate*h_cur], dim=1) 65 | cc_cnm = self.conv_can(combined) 66 | cnm = torch.tanh(cc_cnm) 67 | 68 | h_next = (1 - update_gate) * h_cur + update_gate * cnm 69 | return h_next 70 | 71 | 72 | class ConvGRU(nn.Module): 73 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 74 | dtype=torch.cuda.FloatTensor, batch_first=False, bias=True, return_all_layers=False): 75 | """ 76 | 77 | :param input_size: (int, int) 78 | Height and width of input tensor as (height, width). 79 | :param input_dim: int e.g. 256 80 | Number of channels of input tensor. 81 | :param hidden_dim: int e.g. 1024 82 | Number of channels of hidden state. 83 | :param kernel_size: (int, int) 84 | Size of the convolutional kernel. 85 | :param num_layers: int 86 | Number of ConvLSTM layers 87 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor 88 | Whether or not to use cuda. 89 | :param alexnet_path: str 90 | pretrained alexnet parameters 91 | :param batch_first: bool 92 | if the first position of array is batch or not 93 | :param bias: bool 94 | Whether or not to add the bias. 
95 | :param return_all_layers: bool 96 | if return hidden and cell states for all layers 97 | """ 98 | super(ConvGRU, self).__init__() 99 | 100 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 101 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 102 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 103 | if not len(kernel_size) == len(hidden_dim) == num_layers: 104 | raise ValueError('Inconsistent list length.') 105 | 106 | self.height, self.width = input_size 107 | self.input_dim = input_dim 108 | self.hidden_dim = hidden_dim 109 | self.kernel_size = kernel_size 110 | self.dtype = dtype 111 | self.num_layers = num_layers 112 | self.batch_first = batch_first 113 | self.bias = bias 114 | self.return_all_layers = return_all_layers 115 | 116 | cell_list = [] 117 | for i in range(0, self.num_layers): 118 | cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1] 119 | cell_list.append(ConvGRUCell(input_size=(self.height, self.width), 120 | input_dim=cur_input_dim, 121 | hidden_dim=self.hidden_dim[i], 122 | kernel_size=self.kernel_size[i], 123 | bias=self.bias, 124 | dtype=self.dtype)) 125 | 126 | # convert python list to pytorch module 127 | self.cell_list = nn.ModuleList(cell_list) 128 | 129 | def forward(self, input_tensor, hidden_state=None): 130 | """ 131 | 132 | :param input_tensor: (b, t, c, h, w) or (t,b,c,h,w) depends on if batch first or not 133 | extracted features from alexnet 134 | :param hidden_state: 135 | :return: layer_output_list, last_state_list 136 | """ 137 | if not self.batch_first: 138 | # (t, b, c, h, w) -> (b, t, c, h, w) 139 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 140 | 141 | # Implement stateful ConvLSTM 142 | if hidden_state is not None: 143 | raise NotImplementedError() 144 | else: 145 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0)) 146 | 147 | layer_output_list = [] 148 | last_state_list = [] 149 | 150 | seq_len = input_tensor.size(1) 151 | cur_layer_input = input_tensor 152 | 153 | for layer_idx in range(self.num_layers): 154 | h = hidden_state[layer_idx] 155 | output_inner = [] 156 | for t in range(seq_len): 157 | # input current hidden and cell state then compute the next hidden and cell state through ConvLSTMCell forward function 158 | h = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], # (b,t,c,h,w) 159 | h_cur=h) 160 | output_inner.append(h) 161 | 162 | layer_output = torch.stack(output_inner, dim=1) 163 | cur_layer_input = layer_output 164 | 165 | layer_output_list.append(layer_output) 166 | last_state_list.append([h]) 167 | 168 | if not self.return_all_layers: 169 | layer_output_list = layer_output_list[-1:] 170 | last_state_list = last_state_list[-1:] 171 | 172 | return layer_output_list, last_state_list 173 | 174 | def _init_hidden(self, batch_size): 175 | init_states = [] 176 | for i in range(self.num_layers): 177 | init_states.append(self.cell_list[i].init_hidden(batch_size)) 178 | return init_states 179 | 180 | @staticmethod 181 | def _check_kernel_size_consistency(kernel_size): 182 | if not (isinstance(kernel_size, tuple) or 183 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 184 | raise ValueError('`kernel_size` must be tuple or list of tuples') 185 | 186 | @staticmethod 187 | def _extend_for_multilayer(param, num_layers): 188 | if not isinstance(param, list): 189 | param = [param] * num_layers 190 | return param 191 | 192 | 193 | if __name__ == '__main__': 194 | # set 
CUDA device 195 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 196 | 197 | # detect if CUDA is available or not 198 | use_gpu = torch.cuda.is_available() 199 | if use_gpu: 200 | dtype = torch.cuda.FloatTensor # computation in GPU 201 | else: 202 | dtype = torch.FloatTensor 203 | 204 | height = width = 6 205 | channels = 256 206 | hidden_dim = [32, 64] 207 | kernel_size = (3,3) # kernel size for two stacked hidden layer 208 | num_layers = 2 # number of stacked hidden layer 209 | model = ConvGRU(input_size=(height, width), 210 | input_dim=channels, 211 | hidden_dim=hidden_dim, 212 | kernel_size=kernel_size, 213 | num_layers=num_layers, 214 | dtype=dtype, 215 | batch_first=True, 216 | bias = True, 217 | return_all_layers = False) 218 | 219 | batch_size = 1 220 | time_steps = 1 221 | input_tensor = torch.rand(batch_size, time_steps, channels, height, width) # (b,t,c,h,w) 222 | layer_output_list, last_state_list = model(input_tensor) -------------------------------------------------------------------------------- /models/convlstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ConvLSTMCell(nn.Module): 6 | 7 | def __init__(self, input_dim, hidden_dim, kernel_size, bias): 8 | """ 9 | Initialize ConvLSTM cell. 10 | 11 | Parameters 12 | ---------- 13 | input_dim: int 14 | Number of channels of input tensor. 15 | hidden_dim: int 16 | Number of channels of hidden state. 17 | kernel_size: (int, int) 18 | Size of the convolutional kernel. 19 | bias: bool 20 | Whether or not to add the bias. 21 | """ 22 | 23 | super(ConvLSTMCell, self).__init__() 24 | 25 | self.input_dim = input_dim 26 | self.hidden_dim = hidden_dim 27 | 28 | self.kernel_size = kernel_size 29 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 30 | self.bias = bias 31 | 32 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 33 | out_channels=4 * self.hidden_dim, 34 | kernel_size=self.kernel_size, 35 | padding=self.padding, 36 | bias=self.bias) 37 | 38 | def forward(self, input_tensor, cur_state): 39 | h_cur, c_cur = cur_state 40 | 41 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 42 | 43 | combined_conv = self.conv(combined) 44 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 45 | i = torch.sigmoid(cc_i) 46 | f = torch.sigmoid(cc_f) 47 | o = torch.sigmoid(cc_o) 48 | g = torch.tanh(cc_g) 49 | 50 | c_next = f * c_cur + i * g 51 | h_next = o * torch.tanh(c_next) 52 | 53 | return h_next, c_next 54 | 55 | def init_hidden(self, batch_size, image_size): 56 | height, width = image_size 57 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 58 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)) 59 | 60 | 61 | class ConvLSTM(nn.Module): 62 | 63 | """ 64 | 65 | Parameters: 66 | input_dim: Number of channels in input 67 | hidden_dim: Number of hidden channels 68 | kernel_size: Size of kernel in convolutions 69 | num_layers: Number of LSTM layers stacked on each other 70 | batch_first: Whether or not dimension 0 is the batch or not 71 | bias: Bias or no bias in Convolution 72 | return_all_layers: Return the list of computations for all layers 73 | Note: Will do same padding. 74 | 75 | Input: 76 | A tensor of size B, T, C, H, W or T, B, C, H, W 77 | Output: 78 | A tuple of two lists of length num_layers (or length 1 if return_all_layers is False). 
79 | 0 - layer_output_list is the list of lists of length T of each output 80 | 1 - last_state_list is the list of last states 81 | each element of the list is a tuple (h, c) for hidden state and memory 82 | Example: 83 | >> x = torch.rand((32, 10, 64, 128, 128)) 84 | >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False) 85 | >> _, last_states = convlstm(x) 86 | >> h = last_states[0][0] # 0 for layer index, 0 for h index 87 | """ 88 | 89 | def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, 90 | batch_first=False, bias=True, return_all_layers=False): 91 | super(ConvLSTM, self).__init__() 92 | 93 | self._check_kernel_size_consistency(kernel_size) 94 | 95 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 96 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 97 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 98 | if not len(kernel_size) == len(hidden_dim) == num_layers: 99 | raise ValueError('Inconsistent list length.') 100 | 101 | self.input_dim = input_dim 102 | self.hidden_dim = hidden_dim 103 | self.kernel_size = kernel_size 104 | self.num_layers = num_layers 105 | self.batch_first = batch_first 106 | self.bias = bias 107 | self.return_all_layers = return_all_layers 108 | 109 | cell_list = [] 110 | for i in range(0, self.num_layers): 111 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] 112 | 113 | cell_list.append(ConvLSTMCell(input_dim=cur_input_dim, 114 | hidden_dim=self.hidden_dim[i], 115 | kernel_size=self.kernel_size[i], 116 | bias=self.bias)) 117 | 118 | self.cell_list = nn.ModuleList(cell_list) 119 | 120 | def forward(self, input_tensor, hidden_state=None): 121 | """ 122 | 123 | Parameters 124 | ---------- 125 | input_tensor: todo 126 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 127 | hidden_state: todo 128 | None. todo implement stateful 129 | 130 | Returns 131 | ------- 132 | last_state_list, layer_output 133 | """ 134 | if not self.batch_first: 135 | # (t, b, c, h, w) -> (b, t, c, h, w) 136 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 137 | 138 | b, _, _, h, w = input_tensor.size() 139 | 140 | # Implement stateful ConvLSTM 141 | if hidden_state is not None: 142 | raise NotImplementedError() 143 | else: 144 | # Since the init is done in forward. 
Can send image size here 145 | hidden_state = self._init_hidden(batch_size=b, 146 | image_size=(h, w)) 147 | 148 | layer_output_list = [] 149 | last_state_list = [] 150 | 151 | seq_len = input_tensor.size(1) 152 | cur_layer_input = input_tensor 153 | 154 | for layer_idx in range(self.num_layers): 155 | 156 | h, c = hidden_state[layer_idx] 157 | output_inner = [] 158 | for t in range(seq_len): 159 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 160 | cur_state=[h, c]) 161 | output_inner.append(h) 162 | 163 | layer_output = torch.stack(output_inner, dim=1) 164 | cur_layer_input = layer_output 165 | 166 | layer_output_list.append(layer_output) 167 | last_state_list.append([h, c]) 168 | 169 | if not self.return_all_layers: 170 | layer_output_list = layer_output_list[-1:] 171 | last_state_list = last_state_list[-1:] 172 | 173 | return layer_output_list, last_state_list 174 | 175 | def _init_hidden(self, batch_size, image_size): 176 | init_states = [] 177 | for i in range(self.num_layers): 178 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size)) 179 | return init_states 180 | 181 | @staticmethod 182 | def _check_kernel_size_consistency(kernel_size): 183 | if not (isinstance(kernel_size, tuple) or 184 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 185 | raise ValueError('`kernel_size` must be tuple or list of tuples') 186 | 187 | @staticmethod 188 | def _extend_for_multilayer(param, num_layers): 189 | if not isinstance(param, list): 190 | param = [param] * num_layers 191 | return param 192 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torchvision import models as Models 5 | 6 | from dct.imagenet.gate import GateModule192 7 | from dct.utils import kaiming_init, constant_init 8 | from models.convlstm import ConvLSTM 9 | from models.convGRU import ConvGRU 10 | from models import resnet 11 | 12 | from dct.imagenet.resnet import * 13 | 14 | 15 | class Baseline(nn.Module): 16 | 17 | def __init__(self, use_gru=False, bi_branch=False, rnn_hidden_layers=3, rnn_hidden_nodes=256, 18 | num_classes=1, bidirectional=False, dct=False, inputgate=False): 19 | 20 | super(Baseline, self).__init__() 21 | 22 | self.rnn_hidden_layers = rnn_hidden_layers 23 | self.rnn_hidden_nodes = rnn_hidden_nodes 24 | self.num_classes = num_classes 25 | self.bi_branch = bi_branch 26 | self.inputgate = inputgate 27 | 28 | if not dct: 29 | pretrained_cnn = Models.resnet50(pretrained=True) 30 | cnn_layers = list(pretrained_cnn.children())[:-1] 31 | else: 32 | pretrained_cnn = ResNetDCT_Upscaled_Static(channels=192, pretrained=True) 33 | cnn_layers = list(pretrained_cnn.children())[:-2] 34 | 35 | self.cnn = nn.Sequential(*cnn_layers) 36 | rnn_params = { 37 | 'input_size': pretrained_cnn.fc.in_features, 38 | 'hidden_size': self.rnn_hidden_nodes, 39 | 'num_layers': self.rnn_hidden_layers, 40 | 'batch_first': True, 41 | 'bidirectional': bidirectional 42 | } 43 | 44 | if bidirectional: 45 | fc_in = 2 * rnn_hidden_nodes 46 | else: 47 | fc_in = rnn_hidden_nodes 48 | 49 | self.rnn = (nn.GRU if use_gru else nn.LSTM)(**rnn_params) 50 | 51 | self.fc_cnn = nn.Linear(fc_in, num_classes) 52 | 53 | self.global_pool = nn.AdaptiveAvgPool2d(16) 54 | 55 | self.fc_rnn = nn.Linear(256, self.num_classes) 56 | 57 | if inputgate: 58 | self.inp_GM = GateModule192() 59 | 
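# Note: _initialize_weights() below biases inp_gate_l so that, for each of the 192 DCT channels,
# the "keep" logit starts at 2.0 and the "drop" logit at 0.1, i.e. the Gumbel gates begin with
# nearly all frequency channels switched on.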
self._initialize_weights() 60 | 61 | def forward(self, x_3d): 62 | 63 | cnn_embedding_out = [] 64 | cnn_pred = [] 65 | frame_num = x_3d.size(1) 66 | gates = [] 67 | 68 | for t in range(frame_num): 69 | if self.inputgate: 70 | x, gate_activations = self.inp_GM(x_3d[:, t, :, :, :]) 71 | gates.append(gate_activations) 72 | x = self.cnn(x_3d[:, t, :, :, :]) 73 | x = torch.flatten(x, start_dim=1) 74 | cnn_embedding_out.append(x) 75 | 76 | cnn_embedding_out = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1) 77 | 78 | self.rnn.flatten_parameters() 79 | rnn_out, _ = self.rnn(cnn_embedding_out, None) 80 | 81 | if self.bi_branch: 82 | for t in range(rnn_out.size(1)): 83 | x = rnn_out[:, t, :] 84 | x = self.fc_cnn(x) 85 | cnn_pred.append(x) 86 | cnn_pred = torch.stack(cnn_pred, dim=0).transpose(0, 1) 87 | 88 | x = self.global_pool(rnn_out) 89 | x = torch.flatten(x, start_dim=1) 90 | x = self.fc_rnn(x) 91 | 92 | if self.inputgate: 93 | if self.bi_branch: 94 | return x, cnn_pred.reshape(-1, self.num_classes), torch.stack(gates, dim=0).view(-1, 192, 1) 95 | else: 96 | return x, gates 97 | else: 98 | if self.bi_branch: 99 | return x, cnn_pred.reshape(-1, self.num_classes) 100 | else: 101 | return x 102 | 103 | def _initialize_weights(self): 104 | for name, m in self.named_modules(): 105 | if 'inp_gate_l' in str(name): 106 | m.weight.data.normal_(0, 0.001) 107 | m.bias.data[::2].fill_(0.1) 108 | m.bias.data[1::2].fill_(2) 109 | elif 'inp_gate' in str(name): 110 | if isinstance(m, nn.Conv2d): 111 | kaiming_init(m) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | constant_init(m, 1) 114 | 115 | 116 | class CNN(nn.Module): 117 | def __init__(self, bi_branch=False, num_classes=2): 118 | super(CNN, self).__init__() 119 | 120 | self.num_classes = num_classes 121 | 122 | # 使用resnet预训练模型来提取特征,去掉最后一层分类器 123 | pretrained_cnn = Models.resnet50(pretrained=True) 124 | cnn_layers = list(pretrained_cnn.children())[:-1] 125 | 126 | # 把resnet的最后一层fc层去掉,用来提取特征 127 | self.cnn = nn.Sequential(*cnn_layers) 128 | 129 | self.global_pool = nn.AdaptiveAvgPool1d(1) 130 | 131 | self.cnn_out = nn.Sequential( 132 | nn.Linear(2048, 2) 133 | ) 134 | 135 | def forward(self, x_3d): 136 | """ 137 | 输入的是T帧图像,shape = (batch_size, t, h, w, 3) 138 | """ 139 | cnn_embedding_out = [] 140 | for t in range(x_3d.size(1)): 141 | # 使用cnn提取特征 142 | x = self.cnn(x_3d[:, t, :, :, :]) 143 | x = torch.flatten(x, start_dim=1) 144 | x = self.cnn_out(x) 145 | cnn_embedding_out.append(x) 146 | cnn_embedding_out = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1) 147 | 148 | x = self.global_pool(cnn_embedding_out) 149 | x = torch.flatten(x, start_dim=1) 150 | 151 | return x 152 | 153 | 154 | class cRNN(nn.Module): 155 | def __init__(self, use_gru=False, bi_branch=False, num_classes=2): 156 | super(cRNN, self).__init__() 157 | 158 | self.num_classes = num_classes 159 | self.use_gru = use_gru 160 | 161 | # 使用resnet预训练模型来提取特征,去掉最后一层分类器 162 | pretrained_cnn = Models.resnet50(pretrained=True) 163 | cnn_layers = list(pretrained_cnn.children())[:-2] 164 | 165 | # 把resnet的最后一层fc层去掉,用来提取特征 166 | self.cnn = nn.Sequential(*cnn_layers) 167 | 168 | cRNN_params = { 169 | 'input_dim': 2048, 170 | 'hidden_dim': [256, 256, 512], 171 | 'kernel_size': (1, 1), 172 | 'num_layers': 3, 173 | 'batch_first': True 174 | } if not use_gru else { 175 | 'input_size': (2, 2), 176 | 'input_dim': 2048, 177 | 'hidden_dim': [256, 256, 512], 178 | 'kernel_size': (1, 1), 179 | 'num_layers': 3, 180 | 'batch_first': True 181 | } 182 | 183 | self.cRNN = (ConvGRU if use_gru else 
ConvLSTM)(**cRNN_params) 184 | 185 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 186 | 187 | self.fc = nn.Sequential( 188 | nn.Linear(512, self.num_classes) 189 | ) 190 | 191 | def forward(self, x_3d): 192 | cnn_embedding_out = [] 193 | for t in range(x_3d.size(1)): 194 | # 使用cnn提取特征 195 | x = self.cnn(x_3d[:, t, :, :, :]) 196 | cnn_embedding_out.append(x) 197 | 198 | x = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1) 199 | 200 | _, outputs = self.cRNN(x) 201 | x = outputs[0][0] if self.use_gru else outputs[0][1] 202 | 203 | x = self.global_pool(x) 204 | x = torch.flatten(x, 1) 205 | x = self.fc(x) 206 | 207 | return x 208 | 209 | 210 | def get_resnet_3d(num_classes=2, model_depth=10, shortcut_type='B', sample_size=112, sample_duration=16): 211 | assert model_depth in [10, 18, 34, 50, 101, 152, 200] 212 | 213 | if model_depth == 10: 214 | model = resnet.resnet10( 215 | num_classes=num_classes, 216 | shortcut_type=shortcut_type, 217 | sample_size=sample_size, 218 | sample_duration=sample_duration) 219 | elif model_depth == 18: 220 | model = resnet.resnet18( 221 | num_classes=num_classes, 222 | shortcut_type=shortcut_type, 223 | sample_size=sample_size, 224 | sample_duration=sample_duration) 225 | elif model_depth == 34: 226 | model = resnet.resnet34( 227 | num_classes=num_classes, 228 | shortcut_type=shortcut_type, 229 | sample_size=sample_size, 230 | sample_duration=sample_duration) 231 | elif model_depth == 50: 232 | model = resnet.resnet50( 233 | num_classes=num_classes, 234 | shortcut_type=shortcut_type, 235 | sample_size=sample_size, 236 | sample_duration=sample_duration) 237 | elif model_depth == 101: 238 | model = resnet.resnet101( 239 | num_classes=num_classes, 240 | shortcut_type=shortcut_type, 241 | sample_size=sample_size, 242 | sample_duration=sample_duration) 243 | elif model_depth == 152: 244 | model = resnet.resnet152( 245 | num_classes=num_classes, 246 | shortcut_type=shortcut_type, 247 | sample_size=sample_size, 248 | sample_duration=sample_duration) 249 | else: 250 | model = resnet.resnet200( 251 | num_classes=num_classes, 252 | shortcut_type=shortcut_type, 253 | sample_size=sample_size, 254 | sample_duration=sample_duration) 255 | 256 | return model 257 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=1, 22 | bias=False) 23 | 24 | 25 | def downsample_basic_block(x, planes, stride): 26 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 27 | zero_pads = torch.Tensor( 28 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 29 | out.size(4)).zero_() 30 | if isinstance(out.data, torch.cuda.FloatTensor): 31 | zero_pads = zero_pads.cuda() 32 | 33 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 34 | 35 | return out 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = 
conv3x3x3(inplanes, planes, stride) 44 | self.bn1 = nn.BatchNorm3d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3x3(planes, planes) 47 | self.bn2 = nn.BatchNorm3d(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | residual = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 76 | self.bn1 = nn.BatchNorm3d(planes) 77 | self.conv2 = nn.Conv3d( 78 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 79 | self.bn2 = nn.BatchNorm3d(planes) 80 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 81 | self.bn3 = nn.BatchNorm3d(planes * 4) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__(self, 112 | block, 113 | layers, 114 | sample_size, 115 | sample_duration, 116 | shortcut_type='B', 117 | num_classes=400): 118 | self.inplanes = 64 119 | super(ResNet, self).__init__() 120 | self.conv1 = nn.Conv3d( 121 | 3, 122 | 64, 123 | kernel_size=7, 124 | stride=(1, 2, 2), 125 | padding=(3, 3, 3), 126 | bias=False) 127 | self.bn1 = nn.BatchNorm3d(64) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 130 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 131 | self.layer2 = self._make_layer( 132 | block, 128, layers[1], shortcut_type, stride=2) 133 | self.layer3 = self._make_layer( 134 | block, 256, layers[2], shortcut_type, stride=2) 135 | self.layer4 = self._make_layer( 136 | block, 512, layers[3], shortcut_type, stride=2) 137 | last_duration = int(math.ceil(sample_duration / 16)) 138 | last_size = int(math.ceil(sample_size / 32)) 139 | self.avgpool = nn.AvgPool3d( 140 | (last_duration, last_size, last_size), stride=1) 141 | self.fc = nn.Linear(512 * block.expansion, num_classes) 142 | 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv3d): 145 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 146 | elif isinstance(m, nn.BatchNorm3d): 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | 150 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 151 | downsample = None 152 | if stride != 1 or self.inplanes != planes * block.expansion: 153 | if shortcut_type == 'A': 154 | downsample = partial( 155 | downsample_basic_block, 156 | planes=planes * block.expansion, 157 | stride=stride) 158 | else: 159 | downsample = nn.Sequential( 160 | nn.Conv3d( 161 | self.inplanes, 162 | planes * block.expansion, 163 | kernel_size=1, 164 | stride=stride, 165 | 
bias=False), nn.BatchNorm3d(planes * block.expansion)) 166 | 167 | layers = [] 168 | layers.append(block(self.inplanes, planes, stride, downsample)) 169 | self.inplanes = planes * block.expansion 170 | for i in range(1, blocks): 171 | layers.append(block(self.inplanes, planes)) 172 | 173 | return nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | 186 | x = self.avgpool(x) 187 | 188 | x = x.view(x.size(0), -1) 189 | x = self.fc(x) 190 | 191 | return x 192 | 193 | 194 | def get_fine_tuning_parameters(model, ft_begin_index): 195 | if ft_begin_index == 0: 196 | return model.parameters() 197 | 198 | ft_module_names = [] 199 | for i in range(ft_begin_index, 5): 200 | ft_module_names.append('layer{}'.format(i)) 201 | ft_module_names.append('fc') 202 | 203 | parameters = [] 204 | for k, v in model.named_parameters(): 205 | for ft_module in ft_module_names: 206 | if ft_module in k: 207 | parameters.append({'params': v}) 208 | break 209 | else: 210 | parameters.append({'params': v, 'lr': 0.0}) 211 | 212 | return parameters 213 | 214 | 215 | def resnet10(**kwargs): 216 | """Constructs a ResNet-18 model. 217 | """ 218 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 219 | return model 220 | 221 | 222 | def resnet18(**kwargs): 223 | """Constructs a ResNet-18 model. 224 | """ 225 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 226 | return model 227 | 228 | 229 | def resnet34(**kwargs): 230 | """Constructs a ResNet-34 model. 231 | """ 232 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 233 | return model 234 | 235 | 236 | def resnet50(**kwargs): 237 | """Constructs a ResNet-50 model. 238 | """ 239 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 240 | return model 241 | 242 | 243 | def resnet101(**kwargs): 244 | """Constructs a ResNet-101 model. 245 | """ 246 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 247 | return model 248 | 249 | 250 | def resnet152(**kwargs): 251 | """Constructs a ResNet-101 model. 252 | """ 253 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 254 | return model 255 | 256 | 257 | def resnet200(**kwargs): 258 | """Constructs a ResNet-101 model. 
259 | """ 260 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 261 | return model 262 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/utils/__init__.py -------------------------------------------------------------------------------- /utils/auccur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | import pandas 7 | import os 8 | import argparse 9 | 10 | from models.model import cRNN, get_resnet_3d, CNN, Baseline 11 | from dataloader import FrameDataset, Dataset 12 | import config 13 | from sklearn.metrics import roc_curve, auc 14 | from sklearn.metrics import accuracy_score, roc_auc_score 15 | 16 | 17 | def test(model: nn.Sequential, test_loader: torch.utils.data.DataLoader, model_type, device): 18 | model.eval() 19 | 20 | print('Size of Test Set: ', len(test_loader.dataset)) 21 | 22 | # 准备在测试集上验证模型性能 23 | test_loss = 0 24 | y_gd = [] 25 | frame_y_gd = [] 26 | y_pred = [] 27 | frame_y_pred = [] 28 | 29 | with torch.no_grad(): 30 | if config.net_params.get('our'): 31 | for X, y in tqdm(test_loader, desc='Validating plus frame level'): 32 | X, y = X.to(device), y.to(device) 33 | frame_y = y.view(-1, 1) 34 | frame_y = frame_y.repeat(1, 300) 35 | frame_y = frame_y.flatten() 36 | y_, cnn_y = model(X) 37 | 38 | y_ = y_.argmax(dim=1) 39 | frame_y_ = cnn_y.argmax(dim=1) 40 | 41 | y_gd += y.cpu().numpy().tolist() 42 | y_pred += y_.cpu().numpy().tolist() 43 | frame_y_gd += frame_y.cpu().numpy().tolist() 44 | frame_y_pred += frame_y_.cpu().numpy().tolist() 45 | 46 | test_video_acc = accuracy_score(y_gd, y_pred) 47 | test_video_auc = roc_auc_score(y_gd, y_pred) 48 | test_frame_acc = accuracy_score(frame_y_gd, frame_y_pred) 49 | test_frame_auc = roc_auc_score(frame_y_gd, frame_y_pred) 50 | print('Test video avg loss: %0.4f, acc: %0.2f, auc: %0.2f\n' % ( 51 | test_loss, test_video_acc, test_video_auc)) 52 | print('Test frame avg loss: %0.4f, acc: %0.2f, auc: %0.2f\n' % ( 53 | test_loss, test_frame_acc, test_frame_auc)) 54 | 55 | 56 | else: 57 | for X, y in tqdm(test_loader, desc='Validating plus frame level'): 58 | X, y = X.to(device), y.to(device) 59 | cnn_y = model(X) 60 | frame_y_ = cnn_y.argmax(dim=1) 61 | frame_y_gd += y.cpu().numpy().tolist() 62 | frame_y_pred += frame_y_.cpu().numpy().tolist() 63 | test_frame_acc = accuracy_score(frame_y_gd, frame_y_pred) 64 | test_frame_auc = roc_auc_score(frame_y_gd, frame_y_pred) 65 | print('Test frame avg loss: %0.4f, acc: %0.2f, auc: %0.2f\n' % (test_loss, test_frame_acc, test_frame_auc)) 66 | 67 | return frame_y_gd, frame_y_pred 68 | 69 | 70 | def parse_args(): 71 | parser = argparse.ArgumentParser(usage='python3 main.py -i path/to/data -r path/to/checkpoint') 72 | parser.add_argument('-i', '--data_path', help='path to your datasets', default='/data2/guesthome/wenbop/ffdf_c40') 73 | # parser.add_argument('-i', '--data_path', help='path to your datasets', default='/Users/pu/Desktop/dataset_dlib') 74 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint', 75 | default='/data2/guesthome/wenbop/modules/ff/bi-model_type-baseline_gru_ep-19.pth') 76 | # parser.add_argument('-g', '--gpu', help='visible gpu ids', 
default='4,5,7') 77 | parser.add_argument('-g', '--gpu', help='visible gpu ids', default='0,1,2,3') 78 | args = parser.parse_args() 79 | return args 80 | 81 | 82 | def draw_auc(): 83 | fpr = dict() 84 | tpr = dict() 85 | roc_auc = dict() 86 | 87 | args = parse_args() 88 | data_path = args.data_path 89 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 90 | raw_data = pandas.read_csv(os.path.join(data_path, '%s.csv' % 'test')) 91 | dataloader = DataLoader(Dataset(raw_data.to_numpy()), **config.dataset_params) 92 | use_cuda = torch.cuda.is_available() 93 | device = torch.device('cuda' if use_cuda else 'cpu') 94 | model = Baseline() 95 | device_count = torch.cuda.device_count() 96 | if device_count > 1: 97 | print('使用{}个GPU训练'.format(device_count)) 98 | model = nn.DataParallel(model) 99 | model.to(device) 100 | ckpt = {} 101 | # 从断点继续训练 102 | if args.restore_from is not None: 103 | ckpt = torch.load(args.restore_from) 104 | # model.load_state_dict(ckpt['net']) 105 | model.load_state_dict(ckpt['model_state_dict']) 106 | print('Model is loaded from %s' % (args.restore_from)) 107 | 108 | y_test, y_score = test(model, dataloader, 'baseline', device) 109 | 110 | for i in range(2): 111 | fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) 112 | roc_auc[i] = auc(fpr[i], tpr[i]) 113 | plt.figure() 114 | lw = 2 115 | plt.plot(fpr[0], tpr[0], color='darkorange', 116 | lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[0]) 117 | plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') 118 | plt.xlim([0.0, 1.0]) 119 | plt.ylim([0.0, 1.05]) 120 | plt.xlabel('False Positive Rate') 121 | plt.ylabel('True Positive Rate') 122 | plt.title('Receiver operating characteristic example') 123 | plt.legend(loc="lower right") 124 | plt.show() 125 | 126 | 127 | draw_auc() 128 | -------------------------------------------------------------------------------- /utils/aucloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class AUCLoss(torch.nn.Module): 6 | def __init__(self, device, gamma=0.15, alpha=0.6, p=2): 7 | super().__init__() 8 | self.gamma = gamma 9 | self.alpha = alpha 10 | self.p = p 11 | self.device = device 12 | 13 | def forward(self, y_pred, y_true): 14 | pred = torch.sigmoid(y_pred) 15 | pos = pred[torch.where(y_true == 0)] 16 | neg = pred[torch.where(y_true == 1)] 17 | pos = torch.unsqueeze(pos, 0) 18 | neg = torch.unsqueeze(neg, 1) 19 | diff = torch.zeros_like(pos * neg, device=self.device) + pos - neg - self.gamma 20 | masked = diff[torch.where(diff < 0.0)] 21 | auc = torch.mean(torch.pow(-masked, self.p)) 22 | bce = F.binary_cross_entropy_with_logits(y_pred, y_true) 23 | if masked.shape[0] == 0: 24 | loss = bce 25 | else: 26 | loss = self.alpha * bce + (1 - self.alpha) * auc 27 | return loss 28 | -------------------------------------------------------------------------------- /utils/cam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from torch.autograd import Function 4 | from torchvision import models 5 | from torchvision import utils 6 | import cv2 7 | import sys 8 | from collections import OrderedDict 9 | import numpy as np 10 | import argparse 11 | import os 12 | import torch.nn as nn 13 | from models.model import Baseline 14 | import config 15 | 16 | 17 | i=0##testing in what 18 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 19 | use_cuda = torch.cuda.is_available() 20 | device = 
torch.device('cuda' if use_cuda else 'cpu') 21 | resnet = Baseline(**config.net_params) 22 | device_count = torch.cuda.device_count() 23 | print('使用4个GPU') 24 | model = nn.DataParallel(resnet) 25 | resnet.to(device) 26 | ckpt = torch.load('/data2/guesthome/wenbop/modules/ff/bi-model_type-baseline_gru_ep-19.pth') 27 | #model.load_state_dict(ckpt['net']) 28 | resnet.load_state_dict(ckpt['model_state_dict']) 29 | 30 | # resnet = models.resnet50(pretrained=True)#这里单独加载一个包含全连接层的resnet50模型 31 | image = [] 32 | class FeatureExtractor(): 33 | """ Class for extracting activations and 34 | registering gradients from targetted intermediate layers """ 35 | def __init__(self, model, target_layers): 36 | self.model = model 37 | self.target_layers = target_layers 38 | self.gradients = [] 39 | 40 | def save_gradient(self, grad): 41 | self.gradients.append(grad) 42 | 43 | def __call__(self, x): 44 | outputs = [] 45 | self.gradients = [] 46 | for name, module in self.model._modules.items():##resnet50没有.feature这个特征,直接删除用就可以。 47 | x = module(x) 48 | #print('name=',name) 49 | #print('x.size()=',x.size()) 50 | if name in self.target_layers: 51 | x.register_hook(self.save_gradient) 52 | outputs += [x] 53 | #print('outputs.size()=',x.size()) 54 | #print('len(outputs)',len(outputs)) 55 | return outputs, x 56 | 57 | class ModelOutputs(): 58 | """ Class for making a forward pass, and getting: 59 | 1. The network output. 60 | 2. Activations from intermeddiate targetted layers. 61 | 3. Gradients from intermeddiate targetted layers. """ 62 | def __init__(self, model, target_layers,use_cuda): 63 | self.model = model 64 | self.feature_extractor = FeatureExtractor(self.model, target_layers) 65 | self.cuda = use_cuda 66 | def get_gradients(self): 67 | return self.feature_extractor.gradients 68 | 69 | def __call__(self, x): 70 | target_activations, output = self.feature_extractor(x) 71 | output = output.view(output.size(0), -1) 72 | #print('classfier=',output.size()) 73 | if self.cuda: 74 | output = output.cpu() 75 | cnn = [] 76 | cnn.append(output) 77 | cnn = torch.stack(cnn, dim=0).transpose(0, 1) 78 | rnn_out, _ = resnet.rnn(cnn) 79 | output = resnet.fc_cnn(rnn_out[:,0,:]).cuda()##这里就是为什么我们多加载一个resnet模型进来的原因,因为后面我们命名的model不包含fc层,但是这里又偏偏要使用。# 80 | else: 81 | cnn = [] 82 | cnn.append(output) 83 | cnn = torch.stack(cnn, dim=0).transpose(0, 1) 84 | rnn_out, _ = resnet.rnn(cnn) 85 | output = resnet.fc_cnn(rnn_out[:,0,:])##这里对应use-cuda上更正一些bug,不然用use-cuda的时候会导致类型对不上,这样保证既可以在cpu上运行,gpu上运行也不会出问题. 
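# Both branches above rebuild the classification head by hand: the backbone features are wrapped
# into a length-1 sequence, passed through the loaded Baseline's GRU (resnet.rnn), and scored by its
# frame-level classifier (resnet.fc_cnn) -- this is why the full checkpointed model is kept around as
# `resnet`. The two branches differ only in whether tensors are moved back onto the GPU before that final layer.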
86 | return target_activations, output 87 | 88 | def preprocess_image(img): 89 | means=[0.485, 0.456, 0.406] 90 | stds=[0.229, 0.224, 0.225] 91 | 92 | preprocessed_img = img.copy()[: , :, ::-1] 93 | for i in range(3): 94 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i] 95 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i] 96 | preprocessed_img = \ 97 | np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1))) 98 | preprocessed_img = torch.from_numpy(preprocessed_img) 99 | preprocessed_img.unsqueeze_(0) 100 | input = preprocessed_img 101 | input.requires_grad = True 102 | return input 103 | 104 | def show_cam_on_image(img, mask,name): 105 | heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET) 106 | heatmap = np.float32(heatmap) / 255 107 | cam = heatmap + np.float32(img) 108 | cam = cam / np.max(cam) 109 | cv2.imwrite("cam/cam_{}.jpg".format(name), np.uint8(255 * cam)) 110 | class GradCam: 111 | def __init__(self, model, target_layer_names, use_cuda): 112 | self.model = model 113 | self.model.eval() 114 | self.cuda = use_cuda 115 | if self.cuda: 116 | self.model = model.cuda() 117 | 118 | self.extractor = ModelOutputs(self.model, target_layer_names, use_cuda) 119 | 120 | def forward(self, input): 121 | return self.model(input) 122 | 123 | def __call__(self, input, index = None): 124 | if self.cuda: 125 | features, output = self.extractor(input.cuda()) 126 | else: 127 | features, output = self.extractor(input) 128 | 129 | if index == None: 130 | index = np.argmax(output.cpu().data.numpy()) 131 | 132 | one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32) 133 | one_hot[0][index] = 1 134 | one_hot = torch.Tensor(torch.from_numpy(one_hot)) 135 | one_hot.requires_grad = True 136 | if self.cuda: 137 | one_hot = torch.sum(one_hot.cuda() * output) 138 | else: 139 | one_hot = torch.sum(one_hot * output) 140 | 141 | self.model.zero_grad()##features和classifier不包含,可以重新加回去试一试,会报错不包含这个对象。 142 | #self.model.zero_grad() 143 | one_hot.backward(retain_graph=True)##这里适配我们的torch0.4及以上,我用的1.0也可以完美兼容。(variable改成graph即可) 144 | 145 | grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy() 146 | #print('grads_val',grads_val.shape) 147 | target = features[-1] 148 | target = target.cpu().data.numpy()[0, :] 149 | 150 | weights = np.mean(grads_val, axis = (2, 3))[0, :] 151 | #print('weights',weights.shape) 152 | cam = np.zeros(target.shape[1 : ], dtype = np.float32) 153 | #print('cam',cam.shape) 154 | #print('features',features[-1].shape) 155 | #print('target',target.shape) 156 | for i, w in enumerate(weights): 157 | cam += w * target[i, :, :] 158 | 159 | cam = np.maximum(cam, 0) 160 | cam = cv2.resize(cam, (224, 224)) 161 | cam = cam - np.min(cam) 162 | cam = cam / np.max(cam) 163 | return cam 164 | class GuidedBackpropReLUModel: 165 | def __init__(self, model, use_cuda): 166 | self.model = model#这里同理,要的是一个完整的网络,不然最后维度会不匹配。 167 | self.model.eval() 168 | self.cuda = use_cuda 169 | if self.cuda: 170 | self.model = model.cuda() 171 | for module in self.model.named_modules(): 172 | module[1].register_backward_hook(self.bp_relu) 173 | 174 | def bp_relu(self, module, grad_in, grad_out): 175 | if isinstance(module, nn.ReLU): 176 | return (torch.clamp(grad_in[0], min=0.0),) 177 | def forward(self, input): 178 | return self.model(input) 179 | 180 | def __call__(self, input, index = None): 181 | if self.cuda: 182 | output = self.forward(input.cuda()) 183 | else: 184 | output = self.forward(input) 185 | if index == None: 186 | index = 
np.argmax(output.cpu().data.numpy()) 187 | #print(input.grad) 188 | one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32) 189 | one_hot[0][index] = 1 190 | one_hot = torch.from_numpy(one_hot) 191 | one_hot.requires_grad = True 192 | if self.cuda: 193 | one_hot = torch.sum(one_hot.cuda() * output) 194 | else: 195 | one_hot = torch.sum(one_hot * output) 196 | #self.model.classifier.zero_grad() 197 | one_hot.backward(retain_graph=True) 198 | output = input.grad.cpu().data.numpy() 199 | output = output[0,:,:,:] 200 | 201 | return output 202 | 203 | def get_args(): 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument('--use-cuda', action='store_true', default=False, 206 | help='Use NVIDIA GPU acceleration') 207 | parser.add_argument('--image-path', type=str, default='/data2/guesthome/wenbop/ffdf/test/0/', 208 | help='Input image path') 209 | args = parser.parse_args() 210 | args.use_cuda = args.use_cuda and torch.cuda.is_available() 211 | if args.use_cuda: 212 | print("Using GPU for acceleration") 213 | else: 214 | print("Using CPU for computation") 215 | 216 | return args 217 | 218 | if __name__ == '__main__': 219 | """ python grad_cam.py 220 | 1. Loads an image with opencv. 221 | 2. Preprocesses it for VGG19 and converts to a pytorch variable. 222 | 3. Makes a forward pass to find the category index with the highest score, 223 | and computes intermediate activations. 224 | Makes the visualization. """ 225 | 226 | args = get_args() 227 | 228 | model = resnet.cnn 229 | grad_cam = GradCam(model , \ 230 | target_layer_names = ["layer4"], use_cuda=args.use_cuda)##这里改成layer4也很简单,我把每层name和size都打印出来了,想看哪层自己直接嵌套就可以了。(最后你会在终端看得到name的) 231 | x=os.walk(args.image_path) 232 | for root, dirs, filename in x: 233 | #print(type(grad_cam)) 234 | print(filename) 235 | for s in filename: 236 | image.append(cv2.imread(args.image_path+s,1)) 237 | #img = cv2.imread(filename, 1) 238 | for img in image: 239 | img = np.float32(cv2.resize(img, (224, 224))) / 255 240 | input = preprocess_image(img) 241 | input.required_grad = True 242 | print('input.size()=',input.size()) 243 | # If None, returns the map for the highest scoring category. 244 | # Otherwise, targets the requested index. 
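# Note: `input.required_grad = True` above assigns a misspelled attribute and has no effect on autograd;
# preprocess_image() already returns the tensor with requires_grad=True, so Grad-CAM still receives gradients.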
245 | target_index = None 246 | 247 | mask = grad_cam(input, target_index) 248 | i=i+1 249 | show_cam_on_image(img, mask,i) 250 | 251 | gb_model = GuidedBackpropReLUModel(model = resnet, use_cuda=args.use_cuda) 252 | gb = gb_model(input, index=target_index) 253 | if not os.path.exists('gb'): 254 | os.mkdir('gb') 255 | if not os.path.exists('camgb'): 256 | os.mkdir('camgb') 257 | utils.save_image(torch.from_numpy(gb), 'gb/gb_{}.jpg'.format(i)) 258 | cam_mask = np.zeros(gb.shape) 259 | for j in range(0, gb.shape[0]): 260 | cam_mask[j, :, :] = mask 261 | cam_gb = np.multiply(cam_mask, gb) 262 | utils.save_image(torch.from_numpy(cam_gb), 'camgb/cam_gb_{}.jpg'.format(i)) -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | dataset_params = { 2 | 'shuffle': True, 3 | 'num_workers': 4, 4 | 'pin_memory': False 5 | } 6 | 7 | net_params = { 8 | 'use_gru': True, 9 | 'bi_branch': True, 10 | 'dct': False, 11 | 'inputgate': False 12 | } 13 | img_h = 224 # assumed input frame size (cam.py and gradcam.py resize face crops to 224x224); 'sample_size' below relies on it 14 | resnet_3d_params = { 15 | 'num_classes': 2, 16 | 'model_depth': 50, 17 | 'shortcut_type': 'B', 18 | 'sample_size': img_h, 19 | 'sample_duration': 30 20 | } 21 | 22 | models = { 23 | 1: 'baseline', 24 | 2: 'cRNN', 25 | 3: 'end2end', 26 | 4: 'xception', 27 | 5: 'fwa', 28 | 6: 'cnn', 29 | 7: 'res50', 30 | 8: 'res101', 31 | 9: 'res152' 32 | } 33 | 34 | losses = { 35 | 0: 'CE', 36 | 1: 'AUC', 37 | 2: 'focal' 38 | } 39 | 40 | 41 | gamma = 0.15 42 | 43 | model_type = models.get(1) 44 | loss_type = losses.get(1) 45 | learning_rate = 1e-4 46 | epoches = 20 47 | log_interval = 2 # logging interval: print once every 2 batches by default 48 | save_interval = 1 # checkpoint interval: save the model once per epoch by default 49 | -------------------------------------------------------------------------------- /utils/drawpics.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from sklearn.metrics import auc 3 | from matplotlib.pyplot import MultipleLocator 4 | 5 | from utils.dataloader import * 6 | 7 | 8 | def draw_compare(): 9 | plt.figure() 10 | # plt.title('', fontsize=20) 11 | plt.xlabel('positive to negative') 12 | plt.ylabel('Frame level AUC') 13 | 14 | plt.plot(['1:10', '1:20', '1:30'], [0.72, 0.75, 0.57], label='Meso4', marker='o') 15 | plt.plot(['1:10', '1:20', '1:30'], [0.60, 0.64, 0.50], label='Xception', marker='s') 16 | plt.plot(['1:10', '1:20', '1:30'], [0.82, 0.79, 0.57], label='DSP-FWA', marker='^') 17 | plt.plot(['1:10', '1:20', '1:30'], [0.63, 0.67, 0.56], label='Capsule', marker='*') 18 | plt.plot(['1:10', '1:20', '1:30'], [0.91, 0.92, 0.78], label='Ours', marker='D') 19 | 20 | plt.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", 21 | mode="expand", borderaxespad=0, ncol=5) 22 | plt.savefig('compare.pdf') 23 | 24 | 25 | def draw_AUC(): 26 | f_fpr = np.load('/home/asus/Code/pvc/m/bs/f_fpr.npy') 27 | f_tpr = np.load('/home/asus/Code/pvc/m/bs/f_tpr.npy') 28 | f_roc_auc = auc(f_fpr, f_tpr) 29 | xcp_f_fpr = np.load('/home/asus/Code/pvc/m/xcp/f_fpr.npy') 30 | xcp_f_tpr = np.load('/home/asus/Code/pvc/m/xcp/f_tpr.npy') 31 | xcp_roc_auc = auc(xcp_f_fpr, xcp_f_tpr) 32 | cap_f_fpr = np.load('/home/asus/Code/pvc/m/cap/f_fpr.npy') 33 | cap_f_tpr = np.load('/home/asus/Code/pvc/m/cap/f_tpr.npy') 34 | cap_roc_auc = auc(cap_f_fpr, cap_f_tpr) 35 | ms4_f_fpr = np.load('/home/asus/Code/pvc/m/ms4/f_fpr.npy') 36 | ms4_f_tpr = np.load('/home/asus/Code/pvc/m/ms4/f_tpr.npy') 37 | ms4_roc_auc = auc(ms4_f_fpr, ms4_f_tpr) 38 | msi_f_fpr = 
np.load('/home/asus/Code/pvc/m/msi/f_fpr.npy') 39 | msi_f_tpr = np.load('/home/asus/Code/pvc/m/msi/f_tpr.npy') 40 | msi_roc_auc = auc(msi_f_fpr, msi_f_tpr) 41 | plt.figure() 42 | lw = 2 43 | plt.plot(f_fpr, f_tpr, 44 | lw=lw, label='Ours ROC curve (area = %0.2f)' % f_roc_auc) 45 | plt.plot(xcp_f_fpr, xcp_f_tpr, 46 | lw=lw, label='Xception ROC curve (area = %0.2f)' % xcp_roc_auc) 47 | plt.plot(cap_f_fpr, cap_f_tpr, 48 | lw=lw, label='Capsule ROC curve (area = %0.2f)' % cap_roc_auc) 49 | plt.plot(ms4_f_fpr, ms4_f_tpr, 50 | lw=lw, label='Meso4 ROC curve (area = %0.2f)' % ms4_roc_auc) 51 | plt.plot(msi_f_fpr, msi_f_tpr, 52 | lw=lw, label='MesoInception4 ROC curve (area = %0.2f)' % msi_roc_auc) 53 | plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') 54 | plt.xlim([0.0, 1.0]) 55 | plt.ylim([0.0, 1.05]) 56 | plt.xlabel('False Positive Rate') 57 | plt.ylabel('True Positive Rate') 58 | plt.legend(loc="lower right") 59 | plt.savefig('df_frame.pdf') 60 | 61 | 62 | def draw_WMW(): 63 | # 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 64 | WMW_auc_frame = [0.9032, 0.8772, 0.9383, 0.9293, 0.826, 0.8079] 65 | WMW_acc_frame = [0.957, 0.957, 0.957, 0.957, 0.957, 0.957] 66 | WMW_f1_frame = [0.978, 0.978, 0.978, 0.978, 0.978, 0.978] 67 | WMW_recall_frame = [1, 1, 1, 1, 1, 1] 68 | WMW_auc_video = [0.908, 0.894, 0.9473, 0.9544, 0.7472, 0.8181] 69 | WMW_acc_video = [0.957, 0.957, 0.957, 0.957, 0.957, 0.957] 70 | WMW_f1_video = [0.978, 0.978, 0.978, 0.978, 0.978, 0.978] 71 | WMW_recall_video = [1, 1, 1, 1, 1, 1] 72 | 73 | x = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] 74 | 75 | plt.figure() 76 | plt.plot(np.array(x), np.array(WMW_auc_frame), label='AUC') 77 | # plt.plot(np.array(x), np.array(WMW_auc_video), label='Video level AUC') 78 | plt.plot(np.array(x), np.array(WMW_acc_frame), label='ACC') 79 | plt.plot(np.array(x), np.array(WMW_f1_frame), label='F1') 80 | plt.plot(np.array(x), np.array(WMW_recall_frame), label='recall') 81 | plt.xlim(0.0, 1) 82 | plt.legend(loc="lower center") 83 | plt.xlabel('Margin parameter (gamma) of WMW Loss') 84 | plt.ylabel('Metrics score') 85 | plt.savefig('frame_score.pdf') 86 | plt.show() 87 | 88 | plt.figure() 89 | plt.plot(np.array(x), np.array(WMW_auc_video), label='AUC') 90 | # plt.plot(np.array(x), np.array(WMW_auc_video), label='Video level AUC') 91 | plt.plot(np.array(x), np.array(WMW_acc_video), label='ACC') 92 | plt.plot(np.array(x), np.array(WMW_f1_video), label='F1') 93 | plt.plot(np.array(x), np.array(WMW_recall_video), label='recall') 94 | plt.xlim(0.0, 1) 95 | plt.legend(loc="lower center") 96 | plt.xlabel('Margin parameter (gamma) of WMW Loss') 97 | plt.ylabel('Metrics score') 98 | plt.savefig('video_score.pdf') 99 | plt.show() 100 | 101 | 102 | def draw_auc_compare(): 103 | name = ['Celeb-30', 'Celeb-20', 'Celeb-10'] 104 | our_list = [0.74, 0.95, 0.94] 105 | w_focal = [0.72, 0.95, 0.91] 106 | wo_auc = [0.70, 0.90, 0.87] 107 | 108 | x = np.arange(len(name)) 109 | width = 0.25 110 | 111 | plt.bar(x, our_list, width=width, label='Ours') 112 | plt.bar(x + width, w_focal, width=width, label='Ours with FL', tick_label=name) 113 | plt.bar(x + 2 * width, wo_auc, width=width, label='Ours with BCE') 114 | 115 | # x_major_locator = MultipleLocator(1) 116 | # 把x轴的刻度间隔设置为1,并存在变量里 117 | y_major_locator = MultipleLocator(0.1) 118 | ax = plt.gca() 119 | # ax为两条坐标轴的实例 120 | # ax.xaxis.set_major_locator(x_major_locator) 121 | # 把x轴的主刻度设置为1的倍数 122 | ax.yaxis.set_major_locator(y_major_locator) 123 | # 显示在图形上的值 124 | # for a, b in zip(x, our_list): 125 | # plt.text(a, b + 0.1, b, ha='center', 
va='bottom') 126 | # for a, b in zip(x, w_focal): 127 | # plt.text(a + width, b + 0.1, b, ha='center', va='bottom') 128 | # for a, b in zip(x, wo_auc): 129 | # plt.text(a + 2 * width, b + 0.1, b, ha='center', va='bottom') 130 | 131 | plt.xticks() 132 | plt.ylim([0.5, 1.0]) 133 | plt.legend(loc="upper left") # 防止label和图像重合显示不出来 134 | # plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 135 | plt.ylabel('AUC score') 136 | # plt.xlabel('line') 137 | # plt.rcParams['savefig.dpi'] = 300 # 图片像素 138 | # plt.rcParams['figure.dpi'] = 300 # 分辨率 139 | # plt.rcParams['figure.figsize'] = (15.0, 8.0) # 尺寸 140 | # plt.title("title") 141 | plt.savefig('w_wo_auc.pdf') 142 | plt.show() 143 | 144 | # x = list(range(len(our_list))) 145 | # total_width, n = 0.8, 3 146 | # width = total_width / n 147 | # 148 | # plt.bar(x, our_list, width=width, label='Our') 149 | # for i in range(len(x)): 150 | # x[i] = x[i] + width 151 | # plt.bar(x, w_focal, width=width, label='Our w Focal loss') 152 | # plt.bar(x, wo_auc, width=width, label='Our w/o AUC loss') 153 | # plt.xticks(np.array(x) - width / 3, name_list) 154 | # plt.legend() 155 | # plt.savefig('w_wo_auc.pdf') 156 | # plt.show() 157 | 158 | # x = 3 159 | # total_width, n = 0.8, 3 # 有多少个类型,只需更改n即可 160 | # width = total_width / n 161 | # x = x - (total_width - width) / 2 162 | # 163 | # plt.bar(x, our_list, width=width, label='Ours') 164 | # plt.bar(x + width, w_focal, width=width, label='Ours with Focal loss') 165 | # plt.bar(x + 2 * width, wo_auc, width=width, label='Ours with BCE ') 166 | # 167 | # plt.xticks() 168 | # plt.legend(loc="upper left") # 防止label和图像重合显示不出来 169 | # # plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 170 | # plt.ylabel('AUC score') 171 | # # plt.xlabel('line') 172 | # # plt.rcParams['savefig.dpi'] = 300 # 图片像素 173 | # # plt.rcParams['figure.dpi'] = 300 # 分辨率 174 | # # plt.rcParams['figure.figsize'] = (15.0, 8.0) # 尺寸 175 | # # plt.title("title") 176 | # plt.savefig('w_wo_auc.pdf') 177 | # plt.show() 178 | 179 | 180 | # draw_WMW() 181 | # draw_auc_compare() 182 | # draw_AUC() 183 | draw_compare() 184 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from sklearn.metrics import accuracy_score 4 | from tqdm import tqdm 5 | from PIL import Image 6 | import pandas 7 | import os 8 | import argparse 9 | import cv2 10 | 11 | from dataloader import Dataset 12 | from models.model import CNNEncoder, RNNDecoder 13 | import config 14 | 15 | def load_imgs_from_video(path: str)->list: 16 | """Extract images from video. 17 | 18 | Args: 19 | path(str): The path of video. 20 | 21 | Returns: 22 | A list of PIL Image. 23 | """ 24 | video_fd = cv2.VideoCapture(path) 25 | video_fd.set(16, True) 26 | # flag 16: 'CV_CAP_PROP_CONVERT_RGB' 27 | # indicating the images should be converted to RGB. 28 | 29 | if not video_fd.isOpened(): 30 | raise ValueError('Invalid path! which is: {}'.format(path)) 31 | 32 | images = [] # type: list[Image] 33 | 34 | success, frame = video_fd.read() 35 | while success: 36 | images.append(Image.fromarray(frame)) 37 | success, frame = video_fd.read() 38 | 39 | return images 40 | 41 | def _eval(checkpoint: str, video_path: str, labels=[])->list: 42 | """Inference the model and return the labels. 43 | 44 | Args: 45 | checkpoint(str): The checkpoint where the model restore from. 46 | path(str): The path of videos. 47 | labels(list): Labels of videos. 
48 | 49 | Returns: 50 | A list of labels of the videos. 51 | """ 52 | if not os.path.exists(video_path): 53 | raise ValueError('Invalid path! which is: {}'.format(video_path)) 54 | 55 | print('Loading model from {}'.format(checkpoint)) 56 | use_cuda = torch.cuda.is_available() 57 | device = torch.device('cuda' if use_cuda else 'cpu') 58 | 59 | # Build model 60 | model = nn.Sequential( 61 | CNNEncoder(**config.cnn_encoder_params), 62 | RNNDecoder(**config.rnn_decoder_params) 63 | ) 64 | model.to(device) 65 | model.eval() 66 | 67 | # Load model 68 | ckpt = torch.load(checkpoint) 69 | model.load_state_dict(ckpt['model_state_dict']) 70 | print('Model has been loaded from {}'.format(checkpoint)) 71 | 72 | label_map = [-1] * config.rnn_decoder_params['num_classes'] 73 | # load label map 74 | if 'label_map' in ckpt: 75 | label_map = ckpt['label_map'] 76 | 77 | # Do inference 78 | pred_labels = [] 79 | video_names = os.listdir(video_path) 80 | with torch.no_grad(): 81 | for video in tqdm(video_names, desc='Inferencing'): 82 | # read images from video 83 | images = load_imgs_from_video(os.path.join(video_path, video)) 84 | # apply transform 85 | images = [Dataset.transform(None, img) for img in images] 86 | # stack to tensor, batch size = 1 87 | images = torch.stack(images, dim=0).unsqueeze(0) 88 | # do inference 89 | images = images.to(device) 90 | pred_y = model(images) # type: torch.Tensor 91 | pred_y = pred_y.argmax(dim=1).cpu().numpy().tolist() 92 | pred_labels.append([video, pred_y[0], label_map[pred_y[0]]]) 93 | print(pred_labels[-1]) 94 | 95 | if len(labels) > 0: 96 | acc = accuracy_score(pred_labels, labels) 97 | print('Accuracy: %0.2f' % acc) 98 | 99 | # Save results 100 | pandas.DataFrame(pred_labels).to_csv('result.csv', index=False) 101 | print('Results has been saved to {}'.format('result.csv')) 102 | 103 | return pred_labels 104 | 105 | def parse_args(): 106 | parser = argparse.ArgumentParser(usage='python3 eval.py -i path/to/videos -r path/to/checkpoint') 107 | parser.add_argument('-i', '--video_path', help='path to videos') 108 | parser.add_argument('-r', '--checkpoint', help='path to the checkpoint') 109 | args = parser.parse_args() 110 | return args 111 | 112 | if __name__ == "__main__": 113 | args = parse_args() 114 | _eval(args.checkpoint, args.video_path) 115 | -------------------------------------------------------------------------------- /utils/ff.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ Downloads FaceForensics++ and Deep Fake Detection public data release 3 | Example usage: 4 | see -h or https://github.com/ondyari/FaceForensics 5 | """ 6 | # -*- coding: utf-8 -*- 7 | import argparse 8 | import os 9 | import urllib 10 | import urllib.request 11 | import tempfile 12 | import time 13 | import sys 14 | import json 15 | import random 16 | from tqdm import tqdm 17 | from os.path import join 18 | 19 | 20 | # URLs and filenames 21 | FILELIST_URL = 'misc/filelist.json' 22 | DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json' 23 | DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',] 24 | 25 | # Parameters 26 | DATASETS = { 27 | 'original_youtube_videos': 'misc/downloaded_youtube_videos.zip', 28 | 'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip', 29 | 'original': 'original_sequences/youtube', 30 | 'DeepFakeDetection_original': 'original_sequences/actors', 31 | 'Deepfakes': 'manipulated_sequences/Deepfakes', 32 | 'DeepFakeDetection': 
'manipulated_sequences/DeepFakeDetection', 33 | 'Face2Face': 'manipulated_sequences/Face2Face', 34 | 'FaceShifter': 'manipulated_sequences/FaceShifter', 35 | 'FaceSwap': 'manipulated_sequences/FaceSwap', 36 | 'NeuralTextures': 'manipulated_sequences/NeuralTextures' 37 | } 38 | ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes', 39 | 'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap', 40 | 'NeuralTextures'] 41 | COMPRESSION = ['raw', 'c23', 'c40'] 42 | TYPE = ['videos', 'masks', 'models'] 43 | SERVERS = ['EU', 'EU2', 'CA'] 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser( 48 | description='Downloads FaceForensics v2 public data release.', 49 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 50 | ) 51 | parser.add_argument('-o', '--output_path', type=str, default='./', help='Output directory.') 52 | parser.add_argument('-d', '--dataset', type=str, default='Deepfakes', 53 | help='Which dataset to download, either pristine or ' 54 | 'manipulated data or the downloaded youtube ' 55 | 'videos.', 56 | choices=list(DATASETS.keys()) + ['all'] 57 | ) 58 | parser.add_argument('-c', '--compression', type=str, default='c23', 59 | help='Which compression degree. All videos ' 60 | 'have been generated with h264 with a varying ' 61 | 'codec. Raw (c0) videos are lossless compressed.', 62 | choices=COMPRESSION 63 | ) 64 | parser.add_argument('-t', '--type', type=str, default='videos', 65 | help='Which file type, i.e. videos, masks, for our ' 66 | 'manipulation methods, models, for Deepfakes.', 67 | choices=TYPE 68 | ) 69 | parser.add_argument('-n', '--num_videos', type=int, default=None, 70 | help='Select a number of videos number to ' 71 | "download if you don't want to download the full" 72 | ' dataset.') 73 | parser.add_argument('--server', type=str, default='EU', 74 | help='Server to download the data from. If you ' 75 | 'encounter a slow download speed, consider ' 76 | 'changing the server.', 77 | choices=SERVERS 78 | ) 79 | args = parser.parse_args() 80 | 81 | # URLs 82 | server = args.server 83 | if server == 'EU': 84 | server_url = 'http://canis.vc.in.tum.de:8100/' 85 | elif server == 'EU2': 86 | server_url = 'http://kaldir.vc.in.tum.de/faceforensics/' 87 | elif server == 'CA': 88 | server_url = 'http://falas.cmpt.sfu.ca:8100/' 89 | else: 90 | raise Exception('Wrong server name. 
Choices: {}'.format(str(SERVERS))) 91 | args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf' 92 | args.base_url = server_url + 'v3/' 93 | args.deepfakes_model_url = server_url + 'v3/manipulated_sequences/' + \ 94 | 'Deepfakes/models/' 95 | 96 | return args 97 | 98 | 99 | def download_files(filenames, base_url, output_path, report_progress=True): 100 | os.makedirs(output_path, exist_ok=True) 101 | if report_progress: 102 | filenames = tqdm(filenames) 103 | for filename in filenames: 104 | download_file(base_url + filename, join(output_path, filename)) 105 | 106 | 107 | def reporthook(count, block_size, total_size): 108 | global start_time 109 | if count == 0: 110 | start_time = time.time() 111 | return 112 | duration = time.time() - start_time 113 | progress_size = int(count * block_size) 114 | speed = int(progress_size / (1024 * duration)) 115 | percent = int(count * block_size * 100 / total_size) 116 | sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" % 117 | (percent, progress_size / (1024 * 1024), speed, duration)) 118 | sys.stdout.flush() 119 | 120 | 121 | def download_file(url, out_file, report_progress=False): 122 | out_dir = os.path.dirname(out_file) 123 | if not os.path.isfile(out_file): 124 | fh, out_file_tmp = tempfile.mkstemp(dir=out_dir) 125 | f = os.fdopen(fh, 'w') 126 | f.close() 127 | if report_progress: 128 | urllib.request.urlretrieve(url, out_file_tmp, 129 | reporthook=reporthook) 130 | else: 131 | urllib.request.urlretrieve(url, out_file_tmp) 132 | os.rename(out_file_tmp, out_file) 133 | else: 134 | tqdm.write('WARNING: skipping download of existing file ' + out_file) 135 | 136 | 137 | def main(args): 138 | # TOS 139 | print('By pressing any key to continue you confirm that you have agreed '\ 140 | 'to the FaceForensics terms of use as described at:') 141 | print(args.tos_url) 142 | print('***') 143 | print('Press any key to continue, or CTRL-C to exit.') 144 | _ = input('') 145 | 146 | # Extract arguments 147 | c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS 148 | c_type = args.type 149 | c_compression = args.compression 150 | num_videos = args.num_videos 151 | output_path = args.output_path 152 | os.makedirs(output_path, exist_ok=True) 153 | 154 | # Check for special dataset cases 155 | for dataset in c_datasets: 156 | dataset_path = DATASETS[dataset] 157 | # Special cases 158 | if 'original_youtube_videos' in dataset: 159 | # Here we download the original youtube videos zip file 160 | print('Downloading original youtube videos.') 161 | if not 'info' in dataset_path: 162 | print('Please be patient, this may take a while (~40gb)') 163 | suffix = '' 164 | else: 165 | suffix = 'info' 166 | download_file(args.base_url + '/' + dataset_path, 167 | out_file=join(output_path, 168 | 'downloaded_videos{}.zip'.format( 169 | suffix)), 170 | report_progress=True) 171 | return 172 | 173 | # Else: regular datasets 174 | print('Downloading {} of dataset "{}"'.format( 175 | c_type, dataset_path 176 | )) 177 | 178 | # Get filelists and video lenghts list from server 179 | if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path: 180 | filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' + 181 | DEEPFEAKES_DETECTION_URL).read().decode("utf-8")) 182 | if 'actors' in dataset_path: 183 | filelist = filepaths['actors'] 184 | else: 185 | filelist = filepaths['DeepFakesDetection'] 186 | elif 'original' in dataset_path: 187 | # Load filelist from server 188 | file_pairs = json.loads(urllib.request.urlopen(args.base_url + 
'/' + 189 | FILELIST_URL).read().decode("utf-8")) 190 | filelist = [] 191 | for pair in file_pairs: 192 | filelist += pair 193 | else: 194 | # Load filelist from server 195 | file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' + 196 | FILELIST_URL).read().decode("utf-8")) 197 | # Get filelist 198 | filelist = [] 199 | for pair in file_pairs: 200 | filelist.append('_'.join(pair)) 201 | if c_type != 'models': 202 | filelist.append('_'.join(pair[::-1])) 203 | # Maybe limit number of videos for download 204 | if num_videos is not None and num_videos > 0: 205 | print('Downloading the first {} videos'.format(num_videos)) 206 | filelist = filelist[:num_videos] 207 | 208 | # Server and local paths 209 | dataset_videos_url = args.base_url + '{}/{}/{}/'.format( 210 | dataset_path, c_compression, c_type) 211 | dataset_mask_url = args.base_url + '{}/{}/videos/'.format( 212 | dataset_path, 'masks', c_type) 213 | 214 | if c_type == 'videos': 215 | dataset_output_path = join(output_path, dataset_path, c_compression, 216 | c_type) 217 | print('Output path: {}'.format(dataset_output_path)) 218 | filelist = [filename + '.mp4' for filename in filelist] 219 | download_files(filelist, dataset_videos_url, dataset_output_path) 220 | elif c_type == 'masks': 221 | dataset_output_path = join(output_path, dataset_path, c_type, 222 | 'videos') 223 | print('Output path: {}'.format(dataset_output_path)) 224 | if 'original' in dataset: 225 | if args.dataset != 'all': 226 | print('Only videos available for original data. Aborting.') 227 | return 228 | else: 229 | print('Only videos available for original data. ' 230 | 'Skipping original.\n') 231 | continue 232 | if 'FaceShifter' in dataset: 233 | print('Masks not available for FaceShifter. Aborting.') 234 | return 235 | filelist = [filename + '.mp4' for filename in filelist] 236 | download_files(filelist, dataset_mask_url, dataset_output_path) 237 | 238 | # Else: models for deepfakes 239 | else: 240 | if dataset != 'Deepfakes' and c_type == 'models': 241 | print('Models only available for Deepfakes. 
Aborting') 242 | return 243 | dataset_output_path = join(output_path, dataset_path, c_type) 244 | print('Output path: {}'.format(dataset_output_path)) 245 | 246 | # Get Deepfakes models 247 | for folder in tqdm(filelist): 248 | folder_filelist = DEEPFAKES_MODEL_NAMES 249 | 250 | # Folder paths 251 | folder_base_url = args.deepfakes_model_url + folder + '/' 252 | folder_dataset_output_path = join(dataset_output_path, 253 | folder) 254 | download_files(folder_filelist, folder_base_url, 255 | folder_dataset_output_path, 256 | report_progress=False) # already done 257 | 258 | 259 | if __name__ == "__main__": 260 | args = parse_args() 261 | main(args) 262 | -------------------------------------------------------------------------------- /utils/focalloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class BCEFocalLoss(torch.nn.Module): 8 | 9 | def __init__(self, gamma=2, alpha=0.6, reduction='elementwise_mean'): 10 | super().__init__() 11 | self.gamma = gamma 12 | self.alpha = alpha 13 | self.reduction = reduction 14 | 15 | def forward(self, _input, target): 16 | pt = torch.sigmoid(_input) 17 | # pt = _input 18 | alpha = self.alpha 19 | loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \ 20 | (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt) 21 | if self.reduction == 'elementwise_mean': 22 | loss = torch.mean(loss) 23 | elif self.reduction == 'sum': 24 | loss = torch.sum(loss) 25 | return loss 26 | 27 | 28 | class FocalLoss(nn.Module): 29 | def __init__(self, gamma=0, alpha=None, size_average=True): 30 | super(FocalLoss, self).__init__() 31 | self.gamma = gamma 32 | self.alpha = alpha 33 | if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha]) 34 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha) 35 | self.size_average = size_average 36 | 37 | def forward(self, input, target): 38 | if input.dim() > 2: 39 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 40 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 41 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 42 | target = target.view(-1, 1) 43 | 44 | logpt = F.log_softmax(input, dim=1) 45 | logpt = logpt.gather(1, target) 46 | logpt = logpt.view(-1) 47 | pt = Variable(logpt.data.exp()) 48 | 49 | if self.alpha is not None: 50 | if self.alpha.type() != input.data.type(): 51 | self.alpha = self.alpha.type_as(input.data) 52 | at = self.alpha.gather(0, target.data.view(-1)) 53 | logpt = logpt * Variable(at) 54 | 55 | loss = -1 * (1 - pt) ** self.gamma * logpt 56 | if self.size_average: 57 | return loss.mean() 58 | else: 59 | return loss.sum() 60 | -------------------------------------------------------------------------------- /utils/gradcam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.utils.data import DataLoader 5 | from torch.autograd import Function 6 | 7 | import numpy as np 8 | import cv2 9 | 10 | from models.model import Baseline 11 | 12 | import os 13 | import argparse 14 | import config 15 | 16 | 17 | class FeatureExtractor(): 18 | """ Class for extracting activations and 19 | registering gradients from targeted intermediate layers """ 20 | 21 | def __init__(self, model, target_layers): 22 | self.model = model 23 | 
self.target_layers = target_layers 24 | self.gradients = [] 25 | 26 | def save_gradient(self, grad): 27 | self.gradients.append(grad) 28 | 29 | def __call__(self, x): 30 | outputs = [] 31 | self.gradients = [] 32 | for name, module in self.model._modules.items(): 33 | x = module(x) 34 | if name in self.target_layers: 35 | x.register_hook(self.save_gradient) 36 | outputs += [x] 37 | return outputs, x 38 | 39 | 40 | class ModelOutputs(): 41 | """ Class for making a forward pass, and getting: 42 | 1. The network output. 43 | 2. Activations from intermeddiate targetted layers. 44 | 3. Gradients from intermeddiate targetted layers. """ 45 | 46 | def __init__(self, model, feature_module, target_layers): 47 | self.model = model 48 | self.feature_module = feature_module 49 | self.feature_extractor = FeatureExtractor(self.feature_module, target_layers) 50 | 51 | def get_gradients(self): 52 | return self.feature_extractor.gradients 53 | 54 | def __call__(self, x): 55 | target_activations = [] 56 | for name, module in self.model._modules.items(): 57 | if module == self.feature_module: 58 | target_activations, x = self.feature_extractor(x) 59 | elif "avgpool" in name.lower(): 60 | x = module(x) 61 | x = x.view(x.size(0), -1) 62 | else: 63 | x = module(x) 64 | 65 | return target_activations, x 66 | 67 | 68 | def preprocess_image(img): 69 | means = [0.485, 0.456, 0.406] 70 | stds = [0.229, 0.224, 0.225] 71 | 72 | preprocessed_img = img.copy()[:, :, ::-1] 73 | for i in range(3): 74 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i] 75 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i] 76 | preprocessed_img = \ 77 | np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1))) 78 | preprocessed_img = torch.from_numpy(preprocessed_img) 79 | preprocessed_img.unsqueeze_(0) 80 | input = preprocessed_img.requires_grad_(True) 81 | return input 82 | 83 | 84 | def show_cam_on_image(img, mask, path): 85 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 86 | heatmap = np.float32(heatmap) / 255 87 | cam = heatmap + np.float32(img) 88 | cam = cam / np.max(cam) 89 | cv2.imwrite(path + "_cam.jpg", np.uint8(255 * cam)) 90 | 91 | 92 | class GradCam: 93 | def __init__(self, model, feature_module, target_layer_names, use_cuda): 94 | self.model = model 95 | self.feature_module = feature_module 96 | self.model.eval() 97 | self.cuda = use_cuda 98 | if self.cuda: 99 | self.model = model.cuda() 100 | 101 | self.extractor = ModelOutputs(self.model, self.feature_module, target_layer_names) 102 | 103 | def forward(self, input): 104 | return self.model(input) 105 | 106 | def __call__(self, input, index=None): 107 | if self.cuda: 108 | features, output = self.extractor(input.cuda()) 109 | else: 110 | features, output = self.extractor(input) 111 | 112 | if index == None: 113 | index = np.argmax(output.cpu().data.numpy()) 114 | 115 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 116 | one_hot[0][index] = 1 117 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 118 | if self.cuda: 119 | one_hot = torch.sum(one_hot.cuda() * output) 120 | else: 121 | one_hot = torch.sum(one_hot * output) 122 | 123 | self.feature_module.zero_grad() 124 | self.model.zero_grad() 125 | one_hot.backward(retain_graph=True) 126 | 127 | grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy() 128 | 129 | target = features[-1] 130 | target = target.cpu().data.numpy()[0, :] 131 | 132 | weights = np.mean(grads_val, axis=(2, 3))[0, :] 133 | cam = np.zeros(target.shape[1:], 
dtype=np.float32) 134 | 135 | for i, w in enumerate(weights): 136 | cam += w * target[i, :, :] 137 | 138 | cam = np.maximum(cam, 0) 139 | cam = cv2.resize(cam, input.shape[2:]) 140 | cam = cam - np.min(cam) 141 | cam = cam / np.max(cam) 142 | return cam 143 | 144 | 145 | class GuidedBackpropReLU(Function): 146 | 147 | @staticmethod 148 | def forward(self, input): 149 | positive_mask = (input > 0).type_as(input) 150 | output = torch.addcmul(torch.zeros(input.size()).type_as(input), input, positive_mask) 151 | self.save_for_backward(input, output) 152 | return output 153 | 154 | @staticmethod 155 | def backward(self, grad_output): 156 | input, output = self.saved_tensors 157 | grad_input = None 158 | 159 | positive_mask_1 = (input > 0).type_as(grad_output) 160 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 161 | grad_input = torch.addcmul(torch.zeros(input.size()).type_as(input), 162 | torch.addcmul(torch.zeros(input.size()).type_as(input), grad_output, 163 | positive_mask_1), positive_mask_2) 164 | 165 | return grad_input 166 | 167 | 168 | class GuidedBackpropReLUModel: 169 | def __init__(self, model, use_cuda): 170 | self.model = model 171 | self.model.eval() 172 | self.cuda = use_cuda 173 | if self.cuda: 174 | self.model = model.cuda() 175 | 176 | def recursive_relu_apply(module_top): 177 | for idx, module in module_top._modules.items(): 178 | recursive_relu_apply(module) 179 | if module.__class__.__name__ == 'ReLU': 180 | module_top._modules[idx] = GuidedBackpropReLU.apply 181 | 182 | # replace ReLU with GuidedBackpropReLU 183 | recursive_relu_apply(self.model) 184 | 185 | def forward(self, input): 186 | return self.model(input) 187 | 188 | def __call__(self, input, index=None): 189 | if self.cuda: 190 | output = self.forward(input.cuda()) 191 | else: 192 | output = self.forward(input) 193 | 194 | if index == None: 195 | index = np.argmax(output.cpu().data.numpy()) 196 | 197 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 198 | one_hot[0][index] = 1 199 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 200 | if self.cuda: 201 | one_hot = torch.sum(one_hot.cuda() * output) 202 | else: 203 | one_hot = torch.sum(one_hot * output) 204 | 205 | # self.model.features.zero_grad() 206 | # self.model.classifier.zero_grad() 207 | one_hot.backward(retain_graph=True) 208 | 209 | output = input.grad.cpu().data.numpy() 210 | output = output[0, :, :, :] 211 | 212 | return output 213 | 214 | 215 | def deprocess_image(img): 216 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 217 | img = img - np.mean(img) 218 | img = img / (np.std(img) + 1e-5) 219 | img = img * 0.1 220 | img = img + 0.5 221 | img = np.clip(img, 0, 1) 222 | return np.uint8(img * 255) 223 | 224 | 225 | def parse_args(): 226 | parser = argparse.ArgumentParser(usage='python3 main.py -i path/to/data -r path/to/checkpoint') 227 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint', 228 | default='/data2/guesthome/wenbop/modules/new-bi-model_type-baseline_gru_ep-17.pth') 229 | # parser.add_argument('-g', '--gpu', help='visible gpu ids', default='4,5,7') 230 | parser.add_argument('-i', '--image-path', type=str, default='/data2/guesthome/wenbop/ffdf/test/0/', 231 | help='Input image path') 232 | parser.add_argument('-g', '--gpu', help='visible gpu ids', default='0,1,2,3') 233 | args = parser.parse_args() 234 | return args 235 | 236 | 237 | if __name__ == "__main__": 238 | args = parse_args() 239 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 240 | 
use_cuda = torch.cuda.is_available() 241 | device = torch.device('cuda' if use_cuda else 'cpu') 242 | model = Baseline(**config.net_params) 243 | device_count = torch.cuda.device_count() 244 | if device_count > 1: 245 | print('使用{}个GPU训练'.format(device_count)) 246 | model = nn.DataParallel(model) 247 | model.to(device) 248 | ckpt = {} 249 | # 从断点继续训练 250 | if args.restore_from is not None: 251 | ckpt = torch.load(args.restore_from) 252 | # model.load_state_dict(ckpt['net']) 253 | model.load_state_dict(ckpt['model_state_dict']) 254 | print('Model is loaded from %s' % (args.restore_from)) 255 | 256 | model = model.module.cnn 257 | grad_cam = GradCam(model=model, feature_module=model[7], target_layer_names=["2"], use_cuda=True) 258 | 259 | img = cv2.imread(args.image_path, 1) 260 | img = np.float32(cv2.resize(img, (224, 224))) / 255 261 | input = preprocess_image(img) 262 | 263 | # If None, returns the map for the highest scoring category. 264 | # Otherwise, targets the requested index. 265 | target_index = 0 266 | mask = grad_cam(input, target_index) 267 | 268 | show_cam_on_image(img, mask, args.image_path.split('/')[-1]) 269 | 270 | gb_model = GuidedBackpropReLUModel(model=model, use_cuda=True) 271 | 272 | gb = gb_model(input, index=target_index) 273 | gb = gb.transpose((1, 2, 0)) 274 | cam_mask = cv2.merge([mask, mask, mask]) 275 | cam_gb = deprocess_image(cam_mask * gb) 276 | gb = deprocess_image(gb) 277 | 278 | cv2.imwrite('./gram/' + args.image_path.split('/')[-1] + '_gb.jpg', gb) 279 | cv2.imwrite('./gram/' + args.image_path.split('/')[-1] + '_cam_gb.jpg', cam_gb) 280 | -------------------------------------------------------------------------------- /utils/mmod_human_face_detector.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/utils/mmod_human_face_detector.dat -------------------------------------------------------------------------------- /utils/test.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | "953", 4 | "974" 5 | ], 6 | [ 7 | "012", 8 | "026" 9 | ], 10 | [ 11 | "078", 12 | "955" 13 | ], 14 | [ 15 | "623", 16 | "630" 17 | ], 18 | [ 19 | "919", 20 | "015" 21 | ], 22 | [ 23 | "367", 24 | "371" 25 | ], 26 | [ 27 | "847", 28 | "906" 29 | ], 30 | [ 31 | "529", 32 | "633" 33 | ], 34 | [ 35 | "418", 36 | "507" 37 | ], 38 | [ 39 | "227", 40 | "169" 41 | ], 42 | [ 43 | "389", 44 | "480" 45 | ], 46 | [ 47 | "821", 48 | "812" 49 | ], 50 | [ 51 | "670", 52 | "661" 53 | ], 54 | [ 55 | "158", 56 | "379" 57 | ], 58 | [ 59 | "423", 60 | "421" 61 | ], 62 | [ 63 | "352", 64 | "319" 65 | ], 66 | [ 67 | "579", 68 | "701" 69 | ], 70 | [ 71 | "488", 72 | "399" 73 | ], 74 | [ 75 | "695", 76 | "422" 77 | ], 78 | [ 79 | "288", 80 | "321" 81 | ], 82 | [ 83 | "705", 84 | "707" 85 | ], 86 | [ 87 | "306", 88 | "278" 89 | ], 90 | [ 91 | "865", 92 | "739" 93 | ], 94 | [ 95 | "995", 96 | "233" 97 | ], 98 | [ 99 | "755", 100 | "759" 101 | ], 102 | [ 103 | "467", 104 | "462" 105 | ], 106 | [ 107 | "314", 108 | "347" 109 | ], 110 | [ 111 | "741", 112 | "731" 113 | ], 114 | [ 115 | "970", 116 | "973" 117 | ], 118 | [ 119 | "634", 120 | "660" 121 | ], 122 | [ 123 | "494", 124 | "445" 125 | ], 126 | [ 127 | "706", 128 | "479" 129 | ], 130 | [ 131 | "186", 132 | "170" 133 | ], 134 | [ 135 | "176", 136 | "190" 137 | ], 138 | [ 139 | "380", 140 | "358" 141 | ], 142 | [ 143 | "214", 144 | "255" 145 | ], 146 | [ 147 | "454", 148 | 
"527" 149 | ], 150 | [ 151 | "425", 152 | "485" 153 | ], 154 | [ 155 | "388", 156 | "308" 157 | ], 158 | [ 159 | "384", 160 | "932" 161 | ], 162 | [ 163 | "035", 164 | "036" 165 | ], 166 | [ 167 | "257", 168 | "420" 169 | ], 170 | [ 171 | "924", 172 | "917" 173 | ], 174 | [ 175 | "114", 176 | "102" 177 | ], 178 | [ 179 | "732", 180 | "691" 181 | ], 182 | [ 183 | "550", 184 | "452" 185 | ], 186 | [ 187 | "280", 188 | "249" 189 | ], 190 | [ 191 | "842", 192 | "714" 193 | ], 194 | [ 195 | "625", 196 | "650" 197 | ], 198 | [ 199 | "024", 200 | "073" 201 | ], 202 | [ 203 | "044", 204 | "945" 205 | ], 206 | [ 207 | "896", 208 | "128" 209 | ], 210 | [ 211 | "862", 212 | "047" 213 | ], 214 | [ 215 | "607", 216 | "683" 217 | ], 218 | [ 219 | "517", 220 | "521" 221 | ], 222 | [ 223 | "682", 224 | "669" 225 | ], 226 | [ 227 | "138", 228 | "142" 229 | ], 230 | [ 231 | "552", 232 | "851" 233 | ], 234 | [ 235 | "376", 236 | "381" 237 | ], 238 | [ 239 | "000", 240 | "003" 241 | ], 242 | [ 243 | "048", 244 | "029" 245 | ], 246 | [ 247 | "724", 248 | "725" 249 | ], 250 | [ 251 | "608", 252 | "675" 253 | ], 254 | [ 255 | "386", 256 | "154" 257 | ], 258 | [ 259 | "220", 260 | "219" 261 | ], 262 | [ 263 | "801", 264 | "855" 265 | ], 266 | [ 267 | "161", 268 | "141" 269 | ], 270 | [ 271 | "949", 272 | "868" 273 | ], 274 | [ 275 | "880", 276 | "135" 277 | ], 278 | [ 279 | "429", 280 | "404" 281 | ] 282 | ] -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas 3 | from PIL import Image 4 | 5 | from utils.dataloader import Dataset 6 | from make_train_test import * 7 | from meso.meso import * 8 | 9 | 10 | def load_datas(src_path, files=[]): 11 | datas = [] 12 | for file in files: 13 | img = Image.open(os.path.join(src_path, file)).convert('RGB') 14 | img.save("./images/" + file) 15 | img = img.resize((64, 64), Image.ANTIALIAS) 16 | data = np.array(img) 17 | data = np.transpose(data, (2, 0, 1)) 18 | datas.append(data) 19 | return np.array(datas) 20 | 21 | 22 | def video_frame_face_extractor(path, output): 23 | import dlib 24 | face_detector = dlib.cnn_face_detection_model_v1('./mmod_human_face_detector.dat') 25 | video_fd = cv2.VideoCapture(path) 26 | if not video_fd.isOpened(): 27 | print('Skpped: {}'.format(path)) 28 | 29 | frame_index = 0 30 | success, frame = video_fd.read() 31 | while success: 32 | frame_path = os.path.join(output + '/frame/%s_%d.jpg' % (path.split('/')[-1], frame_index)) 33 | cv2.imwrite(frame_path, frame) 34 | img_path = os.path.join(output + '/face/%s_%d.jpg' % (path.split('/')[-1], frame_index)) 35 | height, width = frame.shape[:2] 36 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 37 | faces = face_detector(gray, 1) 38 | if len(faces): 39 | # For now only take biggest face 40 | face = faces[0].rect 41 | x, y, size = get_boundingbox(face, width, height) 42 | # generate cropped image 43 | cropped_face = frame[y:y + size, x:x + size] 44 | cv2.imwrite(img_path, cropped_face) 45 | 46 | frame_index += 1 47 | success, frame = video_fd.read() 48 | 49 | video_fd.release() 50 | 51 | 52 | def list_file(path, label): 53 | list = [] 54 | for file in os.listdir(path): 55 | list.append([path + '/' + file, label]) 56 | 57 | return list 58 | 59 | 60 | def dataset_size(path): 61 | for file in os.listdir(path + '/0/'): 62 | img = cv2.imread(path + '/0/' + file) 63 | print(np.array(img).shape) 64 | 65 | 66 | def frame_range(src_dir): 67 | Celeb_real 
= list_file(src_dir + '/Celeb-real', 1) 68 | Celeb_synthesis = list_file(src_dir + '/Celeb-synthesis', 0) 69 | YouTube_real = list_file(src_dir + '/YouTube-real', 1) 70 | 71 | frame_m = [] 72 | 73 | for [file, _] in Celeb_real: 74 | video = cv2.VideoCapture(os.path.join(src_dir, '/Celeb-real', file)) 75 | frame_num = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 76 | frame_m.append(frame_num) 77 | print(file) 78 | print(frame_num) 79 | print('---------------') 80 | for [file, _] in Celeb_synthesis: 81 | video = cv2.VideoCapture(os.path.join(src_dir, '/Celeb-synthesis', file)) 82 | frame_num = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 83 | frame_m.append(frame_num) 84 | print(file) 85 | print(frame_num) 86 | print('---------------') 87 | for [file, _] in YouTube_real: 88 | video = cv2.VideoCapture(os.path.join(src_dir, '/YouTube-real', file)) 89 | frame_num = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 90 | frame_m.append(frame_num) 91 | print(file) 92 | print(frame_num) 93 | print('---------------') 94 | frame_m = np.array(frame_m) 95 | print('====================') 96 | print(np.max(np.array(frame_m))) 97 | print(np.min(np.array(frame_m))) 98 | print(frame_m.shape) 99 | print(np.median(frame_m)) 100 | print(np.mean(frame_m)) 101 | 102 | 103 | def build_AUC_loss(outputs, labels, gamma, power_p): 104 | posi_idx = tf.where(tf.equal(labels, 1.0)) 105 | neg_idx = tf.where(tf.equal(labels, -1.0)) 106 | prdictions = tf.nn.softmax(outputs) 107 | posi_predict = tf.gather(prdictions, posi_idx) 108 | posi_size = tf.shape(posi_predict)[0] 109 | neg_predict = tf.gather(prdictions, neg_idx) 110 | neg_size = tf.shape(posi_predict)[0] 111 | posi_neg_diff = tf.reshape( 112 | -(tf.matmul(posi_predict, tf.ones([1, neg_size])) - 113 | tf.matmul(tf.ones([posi_size, 1]), tf.reshape(neg_predict, [-1, neg_size])) - gamma), 114 | [-1, 1]) 115 | posi_neg_diff = tf.where(tf.greater(posi_neg_diff, 0), posi_neg_diff, tf.zeros([posi_size * neg_size, 1])) 116 | posi_neg_diff = tf.pow(posi_neg_diff, power_p) 117 | loss_approx_auc = tf.reduce_mean(posi_neg_diff) 118 | return loss_approx_auc 119 | 120 | 121 | def auc_loss(y_pred, y_true, gamma, p=2): 122 | pos = tf.boolean_mask(y_pred, tf.cast(y_true, tf.bool)) 123 | neg = tf.boolean_mask(y_pred, ~tf.cast(y_true, tf.bool)) 124 | pos = tf.expand_dims(pos, 0) 125 | neg = tf.expand_dims(neg, 1) 126 | difference = tf.zeros_like(pos * neg) + pos - neg - gamma 127 | masked = tf.boolean_mask(difference, difference < 0.0) 128 | return tf.reduce_sum(tf.pow(-masked, p)) 129 | 130 | 131 | def AUC_loss(y_pred, y_true, device, gamma, p=2): 132 | pred = torch.sigmoid(y_pred) 133 | pos = pred[torch.where(y_true == 0)] 134 | neg = pred[torch.where(y_true == 1)] 135 | pos = torch.unsqueeze(pos, 0) 136 | neg = torch.unsqueeze(neg, 1) 137 | diff = torch.zeros_like(pos * neg, device=device) + pos - neg - gamma 138 | masked = diff[torch.where(diff < 0.0)] 139 | return torch.mean(torch.pow(-masked, p)) 140 | 141 | 142 | # def AUC_loss(outputs, labels, device, gamma, p=2): 143 | # predictions = torch.sigmoid(outputs) 144 | # pos_predict = predictions[torch.where(labels == 0)] 145 | # neg_predict = predictions[torch.where(labels == 1)] 146 | # pos_size = pos_predict.shape[0] 147 | # neg_size = neg_predict.shape[0] 148 | # # if pos_size == 0 or neg_size == 0: 149 | # # return 0 150 | # # else: 151 | # if pos_size != 0 and neg_size != 0: 152 | # pos_neg_diff = -(torch.matmul(pos_predict, torch.ones([1, neg_size], device=device)) - 153 | # torch.matmul(torch.ones([pos_size, 1], device=device), 154 | # 
torch.reshape(neg_predict, [-1, neg_size])) 155 | # - gamma) 156 | # pos_neg_diff = torch.reshape(pos_neg_diff, [-1, 1]) 157 | # pos_neg_diff = torch.where(torch.gt(pos_neg_diff, 0), pos_neg_diff, torch.zeros([pos_size * neg_size, 1], 158 | # device=device)) 159 | # elif neg_size == 0: 160 | # pos_neg_diff = -(pos_predict - gamma) 161 | # pos_neg_diff = torch.where(torch.gt(pos_neg_diff, 0), pos_neg_diff, torch.zeros([pos_size, 1], device=device)) 162 | # else: 163 | # pos_neg_diff = -(-neg_predict - gamma) 164 | # pos_neg_diff = torch.where(torch.gt(pos_neg_diff, 0), pos_neg_diff, torch.zeros([neg_size, 1], device=device)) 165 | # 166 | # pos_neg_diff = torch.pow(pos_neg_diff, p) 167 | # 168 | # loss_approx_auc = torch.mean(pos_neg_diff) 169 | # return loss_approx_auc 170 | 171 | 172 | # def AUC_loss(y_, y, device, gamma, p=2): 173 | # X = y_[torch.where(y == 0)] 174 | # Y = y_[torch.where(y == 1)] 175 | # loss = torch.zeros(1, requires_grad=True, device=device) 176 | # if X.shape[0] == 0: 177 | # Y = torch.max(Y, 1)[0] 178 | # for j in Y: 179 | # if -j < gamma: 180 | # loss = (-(- j - gamma)) ** p + loss 181 | # if Y.shape[0] == 0: 182 | # X = torch.max(X, 1)[0] 183 | # for i in X: 184 | # if i < gamma: 185 | # loss = (-(i - gamma)) ** p + loss 186 | # if X.shape[0] != 0 and Y.shape[0] != 0: 187 | # X = torch.max(X, 1)[0] 188 | # Y = torch.max(Y, 1)[0] 189 | # for i in X: 190 | # for j in Y: 191 | # if i - j < gamma: 192 | # loss = (-(i - j - gamma)) ** p + loss 193 | # return loss 194 | 195 | 196 | def merge_labels_to_ckpt(ck_path: str, train_file: str): 197 | """Merge labels to a checkpoint file. 198 | 199 | Args: 200 | ck_path(str): path to checkpoint file 201 | train_file(str): path to train set index file, eg. train.csv 202 | 203 | Return: 204 | This function will create a {ck_path}_patched.pth file. 205 | """ 206 | # load model 207 | print('Loading checkpoint') 208 | ckpt = torch.load(ck_path) 209 | 210 | # load train files 211 | print('Loading dataset') 212 | raw_data = pandas.read_csv(train_file) 213 | train_set = Dataset(raw_data.to_numpy()) 214 | 215 | # patch file name 216 | print('Patching') 217 | patch_path = ck_path.replace('.pth', '') + '_patched.pth' 218 | 219 | ck_dict = {'label_map': train_set.labels} 220 | names = ['epoch', 'model_state_dict', 'optimizer_state_dict'] 221 | for name in names: 222 | ck_dict[name] = ckpt[name] 223 | 224 | torch.save(ck_dict, patch_path) 225 | print('Patched checkpoint has been saved to {}'.format(patch_path)) 226 | 227 | 228 | def tensor2im(input_image, imtype=np.uint8): 229 | """"将tensor的数据类型转成numpy类型,并反归一化. 
230 | 231 | Parameters: 232 | input_image (tensor) -- 输入的图像tensor数组 233 | imtype (type) -- 转换后的numpy的数据类型 234 | """ 235 | mean = [0.485, 0.456, 0.406] # 自己设置的 236 | std = [0.229, 0.224, 0.225] # 自己设置的 237 | if not isinstance(input_image, np.ndarray): 238 | if isinstance(input_image, torch.Tensor): # get the data from a variable 239 | image_tensor = input_image.data 240 | else: 241 | return input_image 242 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array 243 | if image_numpy.shape[0] == 1: # grayscale to RGB 244 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 245 | for i in range(len(mean)): 246 | image_numpy[i] = image_numpy[i] * std[i] + mean[i] 247 | image_numpy = image_numpy * 255 248 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) # post-processing: tranpose and scaling 249 | else: # if it is a numpy array, do nothing 250 | image_numpy = input_image 251 | return image_numpy.astype(imtype) 252 | 253 | 254 | def save_img(im, path): 255 | """im可是没经过任何处理的tensor类型的数据,将数据存储到path中 256 | 257 | Parameters: 258 | im (tensor) -- 输入的图像tensor数组 259 | path (str) -- 图像保存的路径 260 | size (int) -- 一行有size张图,最好是2的倍数 261 | """ 262 | # im_grid = torchvision.utils.make_grid(im, size) #将batchsize的图合成一张图 263 | im_numpy = tensor2im(im) # 转成numpy类型并反归一化 264 | im_array = Image.fromarray(im_numpy) 265 | im_array.save(path) 266 | 267 | 268 | def new_path(path): 269 | if not os.path.exists(path): 270 | try: 271 | os.mkdir(path) 272 | except Exception: 273 | os.makedirs(path) 274 | 275 | 276 | def read_npy(src): 277 | arr = np.load(src) 278 | for i in arr: 279 | print(i) 280 | 281 | 282 | def parse_args(): 283 | parser = argparse.ArgumentParser(usage='python3 tools.py -i path/to/train.csv -r path/to/checkpoint') 284 | parser.add_argument('-i', '--data_path', help='path to your dataset index file') 285 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint', default=None) 286 | args = parser.parse_args() 287 | return args 288 | 289 | 290 | if __name__ == '__main__': 291 | args = parse_args() 292 | # xcp = '/home/asus/Code/checkpoint/ff/xcept/nb-model_type-xception_ep-19.pth' 293 | # meso = '' 294 | # msi = '/home/asus/Code/checkpoint/ff/msin/weights.h5' 295 | # cap = '/home/asus/Code/checkpoint/ff/cap/capsule_18.pt' 296 | # model_pred('/home/asus/ffdf/test/1', 'xception', xcp) 297 | # model_pred('/home/asus/ffdf_40/test/1', 'cap', cap) 298 | # model_pred('/home/asus/ffdf/test/1', 'msi', msi) 299 | # model_pred('/home/asus/ffdf_40/test/0', 'msi', msi) 300 | read_npy('/Users/pu/Downloads/images/c23/0/tcap.txt.npy') 301 | read_npy('/Users/pu/Downloads/images/c23/0/tmsi.txt.npy') 302 | read_npy('/Users/pu/Downloads/images/c23/0/txcep.txt.npy') 303 | print('=================') 304 | read_npy('/Users/pu/Downloads/images/c23/1/tcap.txt.npy') 305 | read_npy('/Users/pu/Downloads/images/c23/1/tmsi.txt.npy') 306 | read_npy('/Users/pu/Downloads/images/c23/1/txcep.txt.npy') 307 | print('=================') 308 | read_npy('/Users/pu/Downloads/images/c40/0/tcap.txt.npy') 309 | read_npy('/Users/pu/Downloads/images/c40/0/tmsi.txt.npy') 310 | read_npy('/Users/pu/Downloads/images/c40/0/txcep.txt.npy') 311 | print('=================') 312 | read_npy('/Users/pu/Downloads/images/c40/1/tcap.txt.npy') 313 | read_npy('/Users/pu/Downloads/images/c40/1/tmsi.txt.npy') 314 | read_npy('/Users/pu/Downloads/images/c40/1/txcep.txt.npy') 315 | -------------------------------------------------------------------------------- /utils/train_cpvr.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | from tensorboardX import SummaryWriter 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | 9 | from dataset import Dataset 10 | from templates import get_templates 11 | 12 | MODEL_DIR = './models/' 13 | BACKBONE = 'xcp' 14 | MAPTYPE = 'reg' 15 | BATCH_SIZE = 15 16 | MAX_EPOCHS = 100 17 | STEPS_PER_EPOCH = 1000 18 | LEARNING_RATE = 0.0001 19 | WEIGHT_DECAY = 0.001 20 | 21 | CONFIGS = { 22 | 'xcp': { 23 | 'img_size': (299, 299), 24 | 'map_size': (19, 19), 25 | 'norms': [[0.5] * 3, [0.5] * 3] 26 | }, 27 | 'vgg': { 28 | 'img_size': (299, 299), 29 | 'map_size': (19, 19), 30 | 'norms': [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]] 31 | } 32 | } 33 | CONFIG = CONFIGS[BACKBONE] 34 | 35 | if BACKBONE == 'xcp': 36 | from xception import Model 37 | elif BACKBONE == 'vgg': 38 | from vgg import Model 39 | 40 | torch.backends.deterministic = True 41 | SEED = 1 42 | random.seed(SEED) 43 | torch.manual_seed(SEED) 44 | torch.cuda.manual_seed_all(SEED) 45 | 46 | DATA_TRAIN = Dataset('train', BATCH_SIZE, CONFIG['img_size'], CONFIG['map_size'], CONFIG['norms'], SEED) 47 | 48 | DATA_EVAL = Dataset('eval', BATCH_SIZE, CONFIG['img_size'], CONFIG['map_size'], CONFIG['norms'], SEED) 49 | 50 | TEMPLATES = None 51 | if MAPTYPE in ['tmp', 'pca_tmp']: 52 | TEMPLATES = get_templates() 53 | 54 | MODEL_NAME = '{0}_{1}'.format(BACKBONE, MAPTYPE) 55 | MODEL_DIR = MODEL_DIR + MODEL_NAME + '/' 56 | 57 | MODEL = Model(MAPTYPE, TEMPLATES, 2, False) 58 | 59 | OPTIM = optim.Adam(MODEL.model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) 60 | MODEL.model.cuda() 61 | LOSS_CSE = nn.CrossEntropyLoss().cuda() 62 | LOSS_L1 = nn.L1Loss().cuda() 63 | MAXPOOL = nn.MaxPool2d(19).cuda() 64 | 65 | 66 | def calculate_losses(batch): 67 | img = batch['img'] 68 | msk = batch['msk'] 69 | lab = batch['lab'] 70 | x, mask, vec = MODEL.model(img) 71 | loss_l1 = LOSS_L1(mask, msk) 72 | loss_cse = LOSS_CSE(x, lab) 73 | loss = loss_l1 + loss_cse 74 | pred = torch.max(x, dim=1)[1] 75 | acc = (pred == lab).float().mean() 76 | return {'loss': loss, 'loss_l1': loss_l1, 'loss_cse': loss_cse, 'acc': acc} 77 | 78 | 79 | def process_batch(batch, mode): 80 | if mode == 'train': 81 | MODEL.model.train() 82 | losses = calculate_losses(batch) 83 | OPTIM.zero_grad() 84 | losses['loss'].backward() 85 | OPTIM.step() 86 | elif mode == 'eval': 87 | MODEL.model.eval() 88 | with torch.no_grad(): 89 | losses = calculate_losses(batch) 90 | return losses 91 | 92 | 93 | SUMMARY_WRITER = SummaryWriter(MODEL_DIR + 'logs/') 94 | 95 | 96 | def write_tfboard(item, itr, name): 97 | SUMMARY_WRITER.add_scalar('{0}'.format(name), item, itr) 98 | 99 | 100 | def run_step(e, s): 101 | batch = DATA_TRAIN.get_batch() 102 | losses = process_batch(batch, 'train') 103 | 104 | if s % 10 == 0: 105 | print('\r{0} - '.format(s) + ', '.join( 106 | ['{0}: {1:.3f}'.format(_, losses[_].cpu().detach().numpy()) for _ in losses]), end='') 107 | if s % 100 == 0: 108 | print('\n', end='') 109 | [write_tfboard(losses[_], e * STEPS_PER_EPOCH + s, _) for _ in losses] 110 | 111 | 112 | def run_epoch(e): 113 | print('Epoch: {0}'.format(e)) 114 | for s in range(STEPS_PER_EPOCH): 115 | run_step(e, s) 116 | MODEL.save(e + 1, OPTIM, MODEL_DIR) 117 | 118 | 119 | LAST_EPOCH = 0 120 | for e in range(LAST_EPOCH, MAX_EPOCHS): 121 | run_epoch(e) 122 | -------------------------------------------------------------------------------- 
/utils/val.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | "720", 4 | "672" 5 | ], 6 | [ 7 | "939", 8 | "115" 9 | ], 10 | [ 11 | "284", 12 | "263" 13 | ], 14 | [ 15 | "402", 16 | "453" 17 | ], 18 | [ 19 | "820", 20 | "818" 21 | ], 22 | [ 23 | "762", 24 | "832" 25 | ], 26 | [ 27 | "834", 28 | "852" 29 | ], 30 | [ 31 | "922", 32 | "898" 33 | ], 34 | [ 35 | "104", 36 | "126" 37 | ], 38 | [ 39 | "106", 40 | "198" 41 | ], 42 | [ 43 | "159", 44 | "175" 45 | ], 46 | [ 47 | "416", 48 | "342" 49 | ], 50 | [ 51 | "857", 52 | "909" 53 | ], 54 | [ 55 | "599", 56 | "585" 57 | ], 58 | [ 59 | "443", 60 | "514" 61 | ], 62 | [ 63 | "566", 64 | "617" 65 | ], 66 | [ 67 | "472", 68 | "511" 69 | ], 70 | [ 71 | "325", 72 | "492" 73 | ], 74 | [ 75 | "816", 76 | "649" 77 | ], 78 | [ 79 | "583", 80 | "558" 81 | ], 82 | [ 83 | "933", 84 | "925" 85 | ], 86 | [ 87 | "419", 88 | "824" 89 | ], 90 | [ 91 | "465", 92 | "482" 93 | ], 94 | [ 95 | "565", 96 | "589" 97 | ], 98 | [ 99 | "261", 100 | "254" 101 | ], 102 | [ 103 | "992", 104 | "980" 105 | ], 106 | [ 107 | "157", 108 | "245" 109 | ], 110 | [ 111 | "571", 112 | "746" 113 | ], 114 | [ 115 | "947", 116 | "951" 117 | ], 118 | [ 119 | "926", 120 | "900" 121 | ], 122 | [ 123 | "493", 124 | "538" 125 | ], 126 | [ 127 | "468", 128 | "470" 129 | ], 130 | [ 131 | "915", 132 | "895" 133 | ], 134 | [ 135 | "362", 136 | "354" 137 | ], 138 | [ 139 | "440", 140 | "364" 141 | ], 142 | [ 143 | "640", 144 | "638" 145 | ], 146 | [ 147 | "827", 148 | "817" 149 | ], 150 | [ 151 | "793", 152 | "768" 153 | ], 154 | [ 155 | "837", 156 | "890" 157 | ], 158 | [ 159 | "004", 160 | "982" 161 | ], 162 | [ 163 | "192", 164 | "134" 165 | ], 166 | [ 167 | "745", 168 | "777" 169 | ], 170 | [ 171 | "299", 172 | "145" 173 | ], 174 | [ 175 | "742", 176 | "775" 177 | ], 178 | [ 179 | "586", 180 | "223" 181 | ], 182 | [ 183 | "483", 184 | "370" 185 | ], 186 | [ 187 | "779", 188 | "794" 189 | ], 190 | [ 191 | "971", 192 | "564" 193 | ], 194 | [ 195 | "273", 196 | "807" 197 | ], 198 | [ 199 | "991", 200 | "064" 201 | ], 202 | [ 203 | "664", 204 | "668" 205 | ], 206 | [ 207 | "823", 208 | "584" 209 | ], 210 | [ 211 | "656", 212 | "666" 213 | ], 214 | [ 215 | "557", 216 | "560" 217 | ], 218 | [ 219 | "471", 220 | "455" 221 | ], 222 | [ 223 | "042", 224 | "084" 225 | ], 226 | [ 227 | "979", 228 | "875" 229 | ], 230 | [ 231 | "316", 232 | "369" 233 | ], 234 | [ 235 | "091", 236 | "116" 237 | ], 238 | [ 239 | "023", 240 | "923" 241 | ], 242 | [ 243 | "702", 244 | "612" 245 | ], 246 | [ 247 | "904", 248 | "046" 249 | ], 250 | [ 251 | "647", 252 | "622" 253 | ], 254 | [ 255 | "958", 256 | "956" 257 | ], 258 | [ 259 | "606", 260 | "567" 261 | ], 262 | [ 263 | "632", 264 | "548" 265 | ], 266 | [ 267 | "927", 268 | "912" 269 | ], 270 | [ 271 | "350", 272 | "349" 273 | ], 274 | [ 275 | "595", 276 | "597" 277 | ], 278 | [ 279 | "727", 280 | "729" 281 | ] 282 | ] -------------------------------------------------------------------------------- /utils/xcp_reg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import os 6 | import sys 7 | 8 | 9 | class SeparableConv2d(nn.Module): 10 | def __init__(self, c_in, c_out, ks, stride=1, padding=0, dilation=1, bias=False): 11 | super(SeparableConv2d, self).__init__() 12 | self.c = nn.Conv2d(c_in, c_in, ks, stride, padding, dilation, groups=c_in, bias=bias) 13 | self.pointwise = nn.Conv2d(c_in, c_out, 1, 
1, 0, 1, 1, bias=bias) 14 | 15 | def forward(self, x): 16 | x = self.c(x) 17 | x = self.pointwise(x) 18 | return x 19 | 20 | 21 | class Block(nn.Module): 22 | def __init__(self, c_in, c_out, reps, stride=1, start_with_relu=True, grow_first=True): 23 | super(Block, self).__init__() 24 | 25 | self.skip = None 26 | self.skip_bn = None 27 | if c_out != c_in or stride != 1: 28 | self.skip = nn.Conv2d(c_in, c_out, 1, stride=stride, bias=False) 29 | self.skip_bn = nn.BatchNorm2d(c_out) 30 | 31 | self.relu = nn.ReLU(inplace=True) 32 | 33 | rep = [] 34 | c = c_in 35 | if grow_first: 36 | rep.append(self.relu) 37 | rep.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)) 38 | rep.append(nn.BatchNorm2d(c_out)) 39 | c = c_out 40 | 41 | for i in range(reps - 1): 42 | rep.append(self.relu) 43 | rep.append(SeparableConv2d(c, c, 3, stride=1, padding=1, bias=False)) 44 | rep.append(nn.BatchNorm2d(c)) 45 | 46 | if not grow_first: 47 | rep.append(self.relu) 48 | rep.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)) 49 | rep.append(nn.BatchNorm2d(c_out)) 50 | 51 | if not start_with_relu: 52 | rep = rep[1:] 53 | else: 54 | rep[0] = nn.ReLU(inplace=False) 55 | 56 | if stride != 1: 57 | rep.append(nn.MaxPool2d(3, stride, 1)) 58 | self.rep = nn.Sequential(*rep) 59 | 60 | def forward(self, inp): 61 | x = self.rep(inp) 62 | 63 | if self.skip is not None: 64 | y = self.skip(inp) 65 | y = self.skip_bn(y) 66 | else: 67 | y = inp 68 | 69 | x += y 70 | return x 71 | 72 | 73 | class RegressionMap(nn.Module): 74 | def __init__(self, c_in): 75 | super(RegressionMap, self).__init__() 76 | self.c = SeparableConv2d(c_in, 1, 3, stride=1, padding=1, bias=False) 77 | self.s = nn.Sigmoid() 78 | 79 | def forward(self, x): 80 | mask = self.c(x) 81 | mask = self.s(mask) 82 | return mask, None 83 | 84 | 85 | class TemplateMap(nn.Module): 86 | def __init__(self, c_in, templates): 87 | super(TemplateMap, self).__init__() 88 | self.c = Block(c_in, 364, 2, 2, start_with_relu=True, grow_first=False) 89 | self.l = nn.Linear(364, 10) 90 | self.relu = nn.ReLU(inplace=True) 91 | 92 | self.templates = templates 93 | 94 | def forward(self, x): 95 | v = self.c(x) 96 | v = self.relu(v) 97 | v = F.adaptive_avg_pool2d(v, (1, 1)) 98 | v = v.view(v.size(0), -1) 99 | v = self.l(v) 100 | mask = torch.mm(v, self.templates.reshape(10, 361)) 101 | mask = mask.reshape(x.shape[0], 1, 19, 19) 102 | 103 | return mask, v 104 | 105 | 106 | class PCATemplateMap(nn.Module): 107 | def __init__(self, templates): 108 | super(PCATemplateMap, self).__init__() 109 | self.templates = templates 110 | 111 | def forward(self, x): 112 | fe = x.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]) 113 | fe = torch.transpose(fe, 1, 2) 114 | mu = torch.mean(fe, 2, keepdim=True) 115 | fea_diff = fe - mu 116 | 117 | cov_fea = torch.bmm(fea_diff, torch.transpose(fea_diff, 1, 2)) 118 | B = self.templates.reshape(1, 10, 361).repeat(x.shape[0], 1, 1) 119 | D = torch.bmm(torch.bmm(B, cov_fea), torch.transpose(B, 1, 2)) 120 | eigen_value, eigen_vector = D.symeig(eigenvectors=True) 121 | index = torch.tensor([9]).cuda() 122 | eigen = torch.index_select(eigen_vector, 2, index) 123 | 124 | v = eigen.squeeze(-1) 125 | mask = torch.mm(v, self.templates.reshape(10, 361)) 126 | mask = mask.reshape(x.shape[0], 1, 19, 19) 127 | return mask, v 128 | 129 | 130 | class Xception(nn.Module): 131 | """ 132 | Xception optimized for the ImageNet dataset, as specified in 133 | https://arxiv.org/pdf/1610.02357.pdf 134 | """ 135 | 136 | def __init__(self, 
maptype, templates, num_classes=1000): 137 | super(Xception, self).__init__() 138 | self.num_classes = num_classes 139 | 140 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) 141 | self.bn1 = nn.BatchNorm2d(32) 142 | self.relu = nn.ReLU(inplace=True) 143 | 144 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False) 145 | self.bn2 = nn.BatchNorm2d(64) 146 | 147 | self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) 148 | self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) 149 | self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) 150 | self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 151 | self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 152 | self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 153 | self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 154 | self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 155 | self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 156 | self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 157 | self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 158 | self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) 159 | 160 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) 161 | self.bn3 = nn.BatchNorm2d(1536) 162 | 163 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) 164 | self.bn4 = nn.BatchNorm2d(2048) 165 | 166 | self.last_linear = nn.Linear(2048, num_classes) 167 | 168 | if maptype == 'none': 169 | self.map = lambda feat: (1, None) # no attention map: features pass through unchanged 170 | elif maptype == 'reg': 171 | self.map = RegressionMap(728) 172 | elif maptype == 'tmp': 173 | self.map = TemplateMap(728, templates) 174 | elif maptype == 'pca_tmp': 175 | self.map = PCATemplateMap(templates) 176 | else: 177 | print('Unknown map type: `{0}`'.format(maptype)) 178 | sys.exit() 179 | 180 | def features(self, input): 181 | x = self.conv1(input) 182 | x = self.bn1(x) 183 | x = self.relu(x) 184 | 185 | x = self.conv2(x) 186 | x = self.bn2(x) 187 | x = self.relu(x) 188 | 189 | x = self.block1(x) 190 | x = self.block2(x) 191 | x = self.block3(x) 192 | x = self.block4(x) 193 | x = self.block5(x) 194 | x = self.block6(x) 195 | x = self.block7(x) 196 | mask, vec = self.map(x) 197 | x = x * mask 198 | x = self.block8(x) 199 | x = self.block9(x) 200 | x = self.block10(x) 201 | x = self.block11(x) 202 | x = self.block12(x) 203 | x = self.conv3(x) 204 | x = self.bn3(x) 205 | x = self.relu(x) 206 | 207 | x = self.conv4(x) 208 | x = self.bn4(x) 209 | return x, mask, vec 210 | 211 | def logits(self, features): 212 | x = self.relu(features) 213 | x = F.adaptive_avg_pool2d(x, (1, 1)) 214 | x = x.view(x.size(0), -1) 215 | x = self.last_linear(x) 216 | return x 217 | 218 | def forward(self, input): 219 | x, mask, vec = self.features(input) 220 | x = self.logits(x) 221 | return x, mask, vec 222 | 223 | 224 | def init_weights(m): 225 | classname = m.__class__.__name__ 226 | if classname.find('SeparableConv2d') != -1: 227 | m.c.weight.data.normal_(0.0, 0.01) 228 | if m.c.bias is not None: 229 | m.c.bias.data.fill_(0) 230 | m.pointwise.weight.data.normal_(0.0, 0.01) 231 | if m.pointwise.bias is not None: 232 | m.pointwise.bias.data.fill_(0) 233 | elif classname.find('Conv') != -1 or classname.find('Linear') != -1: 234 | m.weight.data.normal_(0.0, 0.01) 235 | if m.bias is not None: 236 | m.bias.data.fill_(0) 237 | elif classname.find('BatchNorm') != -1: 238 | m.weight.data.normal_(1.0, 
0.01) 239 | m.bias.data.fill_(0) 240 | elif classname.find('LSTM') != -1: 241 | for name, param in m.named_parameters(): 242 | if 'weight' in name: 243 | param.data.normal_(0.0, 0.01) 244 | elif 'bias' in name: 245 | param.data.fill_(0) 246 | 247 | 248 | class Model: 249 | def __init__(self, maptype='none', templates=None, num_classes=2, load_pretrain=True): 250 | model = Xception(maptype, templates, num_classes=num_classes) 251 | if load_pretrain: 252 | state_dict = torch.load('./xception-b5690688.pth') 253 | for name, weights in state_dict.items(): 254 | if 'pointwise' in name: 255 | state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) 256 | del state_dict['fc.weight'] 257 | del state_dict['fc.bias'] 258 | model.load_state_dict(state_dict, False) 259 | else: 260 | model.apply(init_weights) 261 | self.model = model 262 | 263 | def save(self, epoch, optim, model_dir): 264 | state = {'net': self.model.state_dict(), 'optim': optim.state_dict()} 265 | torch.save(state, '{0}/{1:06d}.tar'.format(model_dir, epoch)) 266 | print('Saved model `{0}`'.format(epoch)) 267 | 268 | def load(self, epoch, model_dir): 269 | filename = '{0}{1:06d}.tar'.format(model_dir, epoch) 270 | print('Loading model from {0}'.format(filename)) 271 | if os.path.exists(filename): 272 | state = torch.load(filename) 273 | self.model.load_state_dict(state['net']) 274 | else: 275 | print('Failed to load model from {0}'.format(filename)) 276 | -------------------------------------------------------------------------------- /xception/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/xception/__init__.py -------------------------------------------------------------------------------- /xception/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Author: Andreas Rössler 4 | """ 5 | import os, sys 6 | # sys.path.append('../') 7 | import argparse 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from xception.xception import xception 13 | import math 14 | import torchvision 15 | 16 | 17 | def return_pytorch04_xception(pretrained=True): 18 | # Raises warning "src not broadcastable to dst" but that's fine 19 | model = xception(pretrained=False) 20 | if pretrained: 21 | # Load model in torch 0.4+ 22 | model.fc = model.last_linear 23 | del model.last_linear 24 | # import pdb; pdb.set_trace() 25 | state_dict = torch.load(os.path.dirname(__file__) + '/xception.pth') 26 | # './trained_model/xception.pth') 27 | for name, weights in state_dict.items(): 28 | if 'pointwise' in name: 29 | state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) 30 | model.load_state_dict(state_dict) 31 | model.last_linear = model.fc 32 | del model.fc 33 | return model 34 | 35 | 36 | class TransferModel(nn.Module): 37 | """ 38 | Simple transfer learning model that takes an imagenet pretrained model with 39 | a fc layer as base model and retrains a new fc layer for num_out_classes 40 | """ 41 | 42 | def __init__(self, modelchoice, num_out_classes=2, dropout=0.0): 43 | super(TransferModel, self).__init__() 44 | self.modelchoice = modelchoice 45 | if modelchoice == 'xception': 46 | self.model = return_pytorch04_xception() 47 | # Replace fc 48 | num_ftrs = self.model.last_linear.in_features 49 | if not dropout: 50 | self.model.last_linear = nn.Linear(num_ftrs, num_out_classes) 51 | else: 52 | print('Using 
dropout', dropout) 53 | self.model.last_linear = nn.Sequential( 54 | nn.Dropout(p=dropout), 55 | nn.Linear(num_ftrs, num_out_classes) 56 | ) 57 | elif modelchoice == 'resnet50' or modelchoice == 'resnet18': 58 | if modelchoice == 'resnet50': 59 | self.model = torchvision.models.resnet50(pretrained=True) 60 | if modelchoice == 'resnet18': 61 | self.model = torchvision.models.resnet18(pretrained=True) 62 | # Replace fc 63 | num_ftrs = self.model.fc.in_features 64 | if not dropout: 65 | self.model.fc = nn.Linear(num_ftrs, num_out_classes) 66 | else: 67 | self.model.fc = nn.Sequential( 68 | nn.Dropout(p=dropout), 69 | nn.Linear(num_ftrs, num_out_classes) 70 | ) 71 | else: 72 | raise Exception('Choose valid model, e.g. resnet50') 73 | 74 | def set_trainable_up_to(self, boolean, layername="Conv2d_4a_3x3"): 75 | """ 76 | Freezes all layers below a specific layer and sets the following layers 77 | to true if boolean else only the fully connected final layer 78 | :param boolean: 79 | :param layername: depends on network, for inception e.g. Conv2d_4a_3x3 80 | :return: 81 | """ 82 | # Stage-1: freeze all the layers 83 | if layername is None: 84 | for i, param in self.model.named_parameters(): 85 | param.requires_grad = True 86 | return 87 | else: 88 | for i, param in self.model.named_parameters(): 89 | param.requires_grad = False 90 | if boolean: 91 | # Make all layers following the layername layer trainable 92 | ct = [] 93 | found = False 94 | for name, child in self.model.named_children(): 95 | if layername in ct: 96 | found = True 97 | for params in child.parameters(): 98 | params.requires_grad = True 99 | ct.append(name) 100 | if not found: 101 | raise Exception('Layer not found, cant finetune!'.format( 102 | layername)) 103 | else: 104 | if self.modelchoice == 'xception': 105 | # Make fc trainable 106 | for param in self.model.last_linear.parameters(): 107 | param.requires_grad = True 108 | 109 | else: 110 | # Make fc trainable 111 | for param in self.model.fc.parameters(): 112 | param.requires_grad = True 113 | 114 | def forward(self, x): 115 | x = self.model(x) 116 | return x 117 | 118 | 119 | def model_selection(modelname, num_out_classes, 120 | dropout=None): 121 | """ 122 | :param modelname: 123 | :return: model, image size, pretraining, input_list 124 | """ 125 | if modelname == 'xception': 126 | return TransferModel(modelchoice='xception', 127 | num_out_classes=num_out_classes), 299, \ 128 | True, ['image'], None 129 | elif modelname == 'resnet18': 130 | return TransferModel(modelchoice='resnet18', dropout=dropout, 131 | num_out_classes=num_out_classes), \ 132 | 224, True, ['image'], None 133 | else: 134 | raise NotImplementedError(modelname) 135 | 136 | 137 | if __name__ == '__main__': 138 | model, image_size, *_ = model_selection('resnet18', num_out_classes=2) 139 | print(model) 140 | model = model.cuda() 141 | from torchsummary import summary 142 | 143 | input_s = (3, image_size, image_size) 144 | print(summary(model, input_s)) 145 | -------------------------------------------------------------------------------- /xception/xception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 3 | 4 | @author: tstandley 5 | Adapted by cadene 6 | 7 | Creates an Xception Model as defined in: 8 | 9 | Francois Chollet 10 | Xception: Deep Learning with Depthwise Separable Convolutions 11 | https://arxiv.org/pdf/1610.02357.pdf 12 | 13 | This weights ported from the Keras 
implementation. Achieves the following performance on the validation set: 14 | 15 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 16 | 17 | REMEMBER to set your image size to 3x299x299 for both test and validation 18 | 19 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 20 | std=[0.5, 0.5, 0.5]) 21 | 22 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 23 | """ 24 | import math 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | import torch.utils.model_zoo as model_zoo 29 | from torch.nn import init 30 | 31 | pretrained_settings = { 32 | 'xception': { 33 | 'imagenet': { 34 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth', 35 | 'input_space': 'RGB', 36 | 'input_size': [3, 299, 299], 37 | 'input_range': [0, 1], 38 | 'mean': [0.5, 0.5, 0.5], 39 | 'std': [0.5, 0.5, 0.5], 40 | 'num_classes': 1000, 41 | 'scale': 0.8975 42 | # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 43 | } 44 | } 45 | } 46 | 47 | 48 | class SeparableConv2d(nn.Module): 49 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 50 | super(SeparableConv2d, self).__init__() 51 | 52 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, 53 | bias=bias) 54 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 55 | 56 | def forward(self, x): 57 | x = self.conv1(x) 58 | x = self.pointwise(x) 59 | return x 60 | 61 | 62 | class Block(nn.Module): 63 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): 64 | super(Block, self).__init__() 65 | 66 | if out_filters != in_filters or strides != 1: 67 | self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False) 68 | self.skipbn = nn.BatchNorm2d(out_filters) 69 | else: 70 | self.skip = None 71 | 72 | self.relu = nn.ReLU(inplace=True) 73 | rep = [] 74 | 75 | filters = in_filters 76 | if grow_first: 77 | rep.append(self.relu) 78 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 79 | rep.append(nn.BatchNorm2d(out_filters)) 80 | filters = out_filters 81 | 82 | for i in range(reps - 1): 83 | rep.append(self.relu) 84 | rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) 85 | rep.append(nn.BatchNorm2d(filters)) 86 | 87 | if not grow_first: 88 | rep.append(self.relu) 89 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 90 | rep.append(nn.BatchNorm2d(out_filters)) 91 | 92 | if not start_with_relu: 93 | rep = rep[1:] 94 | else: 95 | rep[0] = nn.ReLU(inplace=False) 96 | 97 | if strides != 1: 98 | rep.append(nn.MaxPool2d(3, strides, 1)) 99 | self.rep = nn.Sequential(*rep) 100 | 101 | def forward(self, inp): 102 | x = self.rep(inp) 103 | 104 | if self.skip is not None: 105 | skip = self.skip(inp) 106 | skip = self.skipbn(skip) 107 | else: 108 | skip = inp 109 | 110 | x += skip 111 | return x 112 | 113 | 114 | class Xception(nn.Module): 115 | """ 116 | Xception optimized for the ImageNet dataset, as specified in 117 | https://arxiv.org/pdf/1610.02357.pdf 118 | """ 119 | 120 | def __init__(self, num_classes=1000): 121 | """ Constructor 122 | Args: 123 | num_classes: number of classes 124 | """ 125 | super(Xception, self).__init__() 126 | self.num_classes = num_classes 127 | 128 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, 
bias=False) 129 | self.bn1 = nn.BatchNorm2d(32) 130 | self.relu = nn.ReLU(inplace=True) 131 | 132 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False) 133 | self.bn2 = nn.BatchNorm2d(64) 134 | # do relu here 135 | 136 | self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) 137 | self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) 138 | self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) 139 | 140 | self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 141 | self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 142 | self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 143 | self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 144 | 145 | self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 146 | self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 147 | self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 148 | self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 149 | 150 | self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) 151 | 152 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) 153 | self.bn3 = nn.BatchNorm2d(1536) 154 | 155 | # do relu here 156 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) 157 | self.bn4 = nn.BatchNorm2d(2048) 158 | 159 | self.fc = nn.Linear(2048, num_classes) 160 | 161 | # #------- init weights -------- 162 | # for m in self.modules(): 163 | # if isinstance(m, nn.Conv2d): 164 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 165 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 166 | # elif isinstance(m, nn.BatchNorm2d): 167 | # m.weight.data.fill_(1) 168 | # m.bias.data.zero_() 169 | # #----------------------------- 170 | 171 | def features(self, input): 172 | x = self.conv1(input) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | 176 | x = self.conv2(x) 177 | x = self.bn2(x) 178 | x = self.relu(x) 179 | 180 | x = self.block1(x) 181 | x = self.block2(x) 182 | x = self.block3(x) 183 | x = self.block4(x) 184 | x = self.block5(x) 185 | x = self.block6(x) 186 | x = self.block7(x) 187 | x = self.block8(x) 188 | x = self.block9(x) 189 | x = self.block10(x) 190 | x = self.block11(x) 191 | x = self.block12(x) 192 | 193 | x = self.conv3(x) 194 | x = self.bn3(x) 195 | x = self.relu(x) 196 | 197 | x = self.conv4(x) 198 | x = self.bn4(x) 199 | return x 200 | 201 | def logits(self, features): 202 | x = self.relu(features) 203 | 204 | x = F.adaptive_avg_pool2d(x, (1, 1)) 205 | x = x.view(x.size(0), -1) 206 | x = self.last_linear(x) 207 | return x 208 | 209 | def forward(self, input): 210 | x = self.features(input) 211 | # eric 212 | # print('1', x.shape) 213 | x = self.logits(x) 214 | # eric 215 | # print('2', x.shape) 216 | return x 217 | 218 | 219 | def xception(num_classes=1000, pretrained='imagenet'): 220 | model = Xception(num_classes=num_classes) 221 | if pretrained: 222 | settings = pretrained_settings['xception'][pretrained] 223 | assert num_classes == settings['num_classes'], \ 224 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 225 | 226 | model = Xception(num_classes=num_classes) 227 | model.load_state_dict(model_zoo.load_url(settings['url'])) 228 | 229 | model.input_space = settings['input_space'] 230 | model.input_size = settings['input_size'] 231 | model.input_range = settings['input_range'] 232 | model.mean = settings['mean'] 233 | model.std = 
settings['std'] 234 | 235 | # TODO: ugly 236 | model.last_linear = model.fc 237 | del model.fc 238 | return model 239 | --------------------------------------------------------------------------------
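For quick reference, the snippet below is a minimal usage sketch (not part of the repository) showing how the Xception backbone defined in `xception/xception.py` can be driven through `model_selection` from `xception/models.py` for a binary real/fake prediction on a single face crop. It assumes the ImageNet-pretrained `xception.pth` weights file is present in `xception/` (as `return_pytorch04_xception` expects), and `face_crop.png` is a placeholder input path; preprocessing follows the 299x299 input size and mean/std 0.5 normalization stated in the docstring above.

```python
# Minimal sketch: classify one face crop with the Xception transfer model.
import torch
from PIL import Image
from torchvision import transforms

from xception.models import model_selection

# Build the Xception-based transfer model with a 2-way head (real vs. fake).
model, image_size, *_ = model_selection('xception', num_out_classes=2)
model.eval()

# Preprocessing per the notes above: 299x299 input, normalized with mean/std 0.5.
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

img = Image.open('face_crop.png').convert('RGB')  # placeholder input path
batch = preprocess(img).unsqueeze(0)              # shape: (1, 3, 299, 299)

with torch.no_grad():
    logits = model(batch)                 # shape: (1, 2)
    probs = torch.softmax(logits, dim=1)  # class probabilities
print(probs)
```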