├── README.md
├── capsule
│   ├── __init__.py
│   ├── model_big.py
│   ├── test_binary_ffpp.py
│   └── train_capsule.py
├── datasets
│   ├── __init__.py
│   ├── cvfunctional.py
│   ├── cvtransforms.py
│   ├── dataloader_imagenet_dct.py
│   ├── dataset_imagenet.py
│   ├── dataset_imagenet_dct.py
│   ├── dataset_imagenet_dct_cupy.py
│   ├── imagenet2lmdb.py
│   └── vision.py
├── dct
│   ├── __init__.py
│   ├── imagenet
│   │   ├── __init__.py
│   │   ├── gate.py
│   │   ├── gumbel.py
│   │   ├── resnet.py
│   │   ├── resnet_autsubset_inputgate.py
│   │   └── resnet_resized.py
│   └── utils.py
├── demo
│   ├── demo.py
│   └── video
│       └── id0_id1_0002.mp4
├── fwa
│   ├── __init__.py
│   └── classifier.py
├── imgs
│   ├── imbalanced performance.png
│   └── overview.png
├── main.py
├── make_train_test.py
├── meso
│   ├── __init__.py
│   ├── eval_meso.py
│   ├── meso.py
│   └── train_mesonet.py
├── models
│   ├── __init__.py
│   ├── convGRU.py
│   ├── convlstm.py
│   ├── model.py
│   └── resnet.py
├── utils
│   ├── __init__.py
│   ├── auccur.py
│   ├── aucloss.py
│   ├── cam.py
│   ├── config.py
│   ├── dataloader.py
│   ├── drawpics.py
│   ├── eval.py
│   ├── ff.py
│   ├── focalloss.py
│   ├── gradcam.py
│   ├── mmod_human_face_detector.dat
│   ├── test.json
│   ├── tools.py
│   ├── train.json
│   ├── train_cpvr.py
│   ├── val.json
│   └── xcp_reg.py
└── xception
    ├── __init__.py
    ├── models.py
    └── xception.py
/README.md:
--------------------------------------------------------------------------------
1 | # Learning a Deep Dual-level Network for Robust DeepFake Detection
2 |
4 |
5 | Wenbo Pu, Jing Hu, Xin Wang, Yuezun Li, Shu Hu, Bin Zhu, Rui Song, Qi Song, Xi Wu, Siwei Lyu
6 | _________________
7 |
8 | This repository is the official implementation of our paper "Learning a Deep Dual-level Network for Robust DeepFake Detection", which has been accepted by **Pattern Recognition**.
9 |
10 | ## Overview
11 |
12 | 
13 |
14 | ## Imbalanced Performance
15 |
16 |
17 | ![imbalanced performance](imgs/imbalanced%20performance.png)
18 |
19 | ## Info
20 |
21 | This repository provides training and testing code for our method as well as Xception [6], FWA [7], MesoNet [8], Capsule [9], and others. Xception and FWA can be trained and tested via `main.py`, while the other methods live in their individual folders, e.g. Capsule in `capsule/`.
22 |
23 | Besides the model proposed in our paper, we also provide several variants of it, including ViT, ResViT, and DCTNet [10] as replacements for the ResNet backbone, and CRNN as a replacement for the RNN.
24 |
25 | We also implemented Face X-ray style data augmentation (it is not used in the paper, but we found that it can improve performance); if you are interested, see `utils/dataloader.py`. The idea is sketched below.
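
The augmentation follows the blended-image idea behind Face X-ray: paste a donor face into a real frame with a soft mask and label the result as fake. The sketch below only illustrates that blending step; the function name and mask handling are assumptions, not the exact code in `utils/dataloader.py`.

```python
import cv2
import numpy as np

def blend_face(target, donor, face_mask):
    """Alpha-blend a donor face region into a target frame (illustrative only).

    target, donor: aligned HxWx3 uint8 face images
    face_mask:     HxW mask in [0, 1] covering the face region
    """
    # Soften the mask boundary so the blending seam is less visible
    mask = cv2.GaussianBlur(face_mask.astype(np.float32), (15, 15), 5)[..., None]
    blended = mask * donor.astype(np.float32) + (1.0 - mask) * target.astype(np.float32)
    return blended.astype(np.uint8)  # used as an extra "fake" training sample
```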
26 |
27 | The implementation of the AUC loss proposed in our paper can be found in `utils/aucloss.py`; a generic sketch of the idea follows.
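
For reference, the sketch below shows a common pairwise AUC surrogate (a squared-hinge ranking loss over fake/real score pairs). It only illustrates the general idea; the exact formulation and the role of `gamma` in `utils/aucloss.py` follow the paper.

```python
import torch
import torch.nn as nn

class PairwiseAUCLoss(nn.Module):
    """Illustrative pairwise AUC surrogate, NOT the exact loss in utils/aucloss.py."""

    def __init__(self, gamma=0.15):
        super().__init__()
        self.gamma = gamma  # margin by which fake scores should exceed real scores

    def forward(self, scores, labels):
        # scores: (N,) predicted "fake" scores; labels: (N,) with 1 = fake, 0 = real
        pos, neg = scores[labels == 1], scores[labels == 0]
        if pos.numel() == 0 or neg.numel() == 0:
            return scores.new_zeros(())
        # Penalize every (fake, real) pair whose score margin falls below gamma
        margin = pos.unsqueeze(1) - neg.unsqueeze(0) - self.gamma
        return torch.clamp(-margin, min=0).pow(2).mean()
```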
28 |
29 | Our checkpoint can be found [here](https://drive.google.com/file/d/144ol1u4Kz4HwOsG3qvEeVqH8bpqCvaOU/view?usp=sharing).
30 |
31 |
32 | ## Requirements
33 |
34 | - PyTorch 1.4.0
35 | - Ubuntu 16.04
36 | - CUDA 10.0
37 | - Python 3.6
38 | - Dlib 19.0
39 |
40 | ## Usage
41 |
42 | - We provide a demo to show how our model works; see `demo/demo.py`:
43 | ```shell
44 | python demo.py --restore_from <checkpoint_path> --path <video_path>
45 | ```
46 |
47 | - To train and test a model, use
48 |
49 | ```shell
50 | python main.py -i <input_path> -r <restore_from> -g <gpu_id>
51 | ```
52 |
53 | - More parameters, including the gamma of the AUC loss, can be found and adjusted in `main.py`.
54 |
55 | ## Training data preparation
56 |
57 | We provide a script, `make_train_test.py`, to generate training and test data for this repository. It can preprocess the FaceForensics++, Celeb-DF, and DFDC datasets using [MTCNN](https://github.com/ipazc/mtcnn) or [Dlib](https://github.com/davisking/dlib/); a minimal face-cropping sketch is shown below.
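
As a rough illustration of the preprocessing, the snippet below crops the largest detected face from every n-th frame of a video with facenet-pytorch's MTCNN (the same detector used in `demo/demo.py`). Paths, sampling rate, and crop size are placeholder assumptions; `make_train_test.py` remains the authoritative script.

```python
import cv2
from facenet_pytorch.models.mtcnn import MTCNN

detector = MTCNN(select_largest=True, post_process=False)

def extract_faces(video_path, out_dir, every_n=10, size=300):
    """Save the largest face from every n-th frame (illustrative only)."""
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if idx % every_n == 0:
            boxes, _ = detector.detect(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if boxes is not None:
                x1, y1, x2, y2 = [max(int(v), 0) for v in boxes[0]]
                face = cv2.resize(frame[y1:y2, x1:x2], (size, size))
                cv2.imwrite('%s/%06d.png' % (out_dir, idx), face)
        idx += 1
    cap.release()
```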
58 |
59 |
60 | ## Citation
61 |
62 | Please kindly consider citing our paper in your publications.
63 |
64 | ```bib
65 | @article{PU2022108832,
66 | title = {Learning a deep dual-level network for robust DeepFake detection},
67 | journal = {Pattern Recognition},
68 | volume = {130},
69 | pages = {108832},
70 | year = {2022},
71 | issn = {0031-3203},
72 | doi = {10.1016/j.patcog.2022.108832},
73 | url = {https://www.sciencedirect.com/science/article/pii/S0031320322003132},
74 | author = {Wenbo Pu and Jing Hu and Xin Wang and Yuezun Li and Shu Hu and Bin Zhu and Rui Song and Qi Song and Xi Wu and Siwei Lyu}
75 | }
76 | ```
77 | _________________
78 |
79 | ## Notice
80 |
81 | This repository is NOT for commercial use. It is provided "as is", and we are not responsible for any consequences of using this code.
82 |
83 |
84 | ## Thanks
85 |
86 | [6] [FaceForensics++: Learning to Detect Manipulated Facial Images](https://github.com/ondyari/FaceForensics)
87 | [7] [Exposing DeepFake Videos By Detecting Face Warping Artifacts](https://github.com/yuezunli/CVPRW2019_Face_Artifacts)
88 | [8] [MesoNet: a Compact Facial Video Forgery Detection Network](https://github.com/DariusAf/MesoNet)
89 | [9] [Use of a Capsule Network to Detect Fake Images and Videos](https://github.com/raohashim/DFD)
90 | [10] [Learning in the Frequency Domain](https://github.com/calmevtime/DCTNet)
91 |
--------------------------------------------------------------------------------
/capsule/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/capsule/__init__.py
--------------------------------------------------------------------------------
/capsule/model_big.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2018, National Institute of Informatics
3 | All rights reserved.
4 | Author: Huy H. Nguyen
5 | -----------------------------------------------------
6 | Script for Capsule-Forensics model
7 | """
8 |
9 | import sys
10 |
11 | sys.setrecursionlimit(15000)
12 | import torch
13 | import torch.nn.functional as F
14 | from torch import nn
15 | import torch.backends.cudnn as cudnn
16 | from torch.autograd import Variable
17 | import torchvision.models as models
18 |
19 | NO_CAPS = 10
20 |
21 |
22 | class StatsNet(nn.Module):
23 | def __init__(self):
24 | super(StatsNet, self).__init__()
25 |
26 | def forward(self, x):
27 | x = x.view(x.data.shape[0], x.data.shape[1], x.data.shape[2] * x.data.shape[3])
28 |
29 | mean = torch.mean(x, 2)
30 | std = torch.std(x, 2)
31 |
32 | return torch.stack((mean, std), dim=1)
33 |
34 |
35 | class View(nn.Module):
36 | def __init__(self, *shape):
37 | super(View, self).__init__()
38 | self.shape = shape
39 |
40 | def forward(self, input):
41 | return input.view(self.shape)
42 |
43 |
44 | class VggExtractor(nn.Module):
45 | def __init__(self, train=False):
46 | super(VggExtractor, self).__init__()
47 |
48 | self.vgg_1 = self.Vgg(models.vgg19(pretrained=True), 0, 18)
49 | if train:
50 | self.vgg_1.train(mode=True)
51 | self.freeze_gradient()
52 | else:
53 | self.vgg_1.eval()
54 |
55 | def Vgg(self, vgg, begin, end):
56 | features = nn.Sequential(*list(vgg.features.children())[begin:(end + 1)])
57 | return features
58 |
59 | def freeze_gradient(self, begin=0, end=9):
60 | for i in range(begin, end + 1):
61 | self.vgg_1[i].requires_grad = False
62 |
63 | def forward(self, input):
64 | return self.vgg_1(input)
65 |
66 |
67 | class FeatureExtractor(nn.Module):
68 | def __init__(self):
69 | super(FeatureExtractor, self).__init__()
70 |
71 | self.capsules = nn.ModuleList([
72 | nn.Sequential(
73 | nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
74 | nn.BatchNorm2d(64),
75 | nn.ReLU(),
76 | nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
77 | nn.BatchNorm2d(16),
78 | nn.ReLU(),
79 | StatsNet(),
80 |
81 | nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2),
82 | nn.BatchNorm1d(8),
83 | nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1),
84 | nn.BatchNorm1d(1),
85 | View(-1, 8),
86 | )
87 | for _ in range(NO_CAPS)]
88 | )
89 |
90 | def squash(self, tensor, dim):
91 | squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
92 | scale = squared_norm / (1 + squared_norm)
93 | return scale * tensor / (torch.sqrt(squared_norm))
94 |
95 | def forward(self, x):
96 | # outputs = [capsule(x.detach()) for capsule in self.capsules]
97 | # outputs = [capsule(x.clone()) for capsule in self.capsules]
98 | outputs = [capsule(x) for capsule in self.capsules]
99 | output = torch.stack(outputs, dim=-1)
100 |
101 | return self.squash(output, dim=-1)
102 |
103 |
104 | class RoutingLayer(nn.Module):
105 | def __init__(self, gpu_id, num_input_capsules, num_output_capsules, data_in, data_out, num_iterations):
106 | super(RoutingLayer, self).__init__()
107 |
108 | self.gpu_id = gpu_id
109 | self.num_iterations = num_iterations
110 | self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in))
111 |
112 | def squash(self, tensor, dim):
113 | squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
114 | scale = squared_norm / (1 + squared_norm)
115 | return scale * tensor / (torch.sqrt(squared_norm))
116 |
117 | def forward(self, x, random, dropout):
118 | # x[b, data, in_caps]
119 |
120 | x = x.transpose(2, 1)
121 | # x[b, in_caps, data]
122 |
123 | if random:
124 | noise = Variable(0.01 * torch.randn(*self.route_weights.size()))
125 | if self.gpu_id >= 0:
126 | noise = noise.cuda(self.gpu_id)
127 | route_weights = self.route_weights + noise
128 | else:
129 | route_weights = self.route_weights
130 |
131 | priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None]
132 |
133 | # route_weights [out_caps , 1 , in_caps , data_out , data_in]
134 | # x [ 1 , b , in_caps , data_in , 1 ]
135 | # priors [out_caps , b , in_caps , data_out, 1 ]
136 |
137 | priors = priors.transpose(1, 0)
138 | # priors[b, out_caps, in_caps, data_out, 1]
139 |
140 | if dropout > 0.0:
141 | drop = Variable(torch.FloatTensor(*priors.size()).bernoulli(1.0 - dropout))
142 | if self.gpu_id >= 0:
143 | drop = drop.cuda(self.gpu_id)
144 | priors = priors * drop
145 |
146 | logits = Variable(torch.zeros(*priors.size()))
147 | # logits[b, out_caps, in_caps, data_out, 1]
148 |
149 | if self.gpu_id >= 0:
150 | logits = logits.cuda(self.gpu_id)
151 |
152 | num_iterations = self.num_iterations
153 |
154 | for i in range(num_iterations):
155 | probs = F.softmax(logits, dim=2)
156 | outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3)
157 |
158 | if i != self.num_iterations - 1:
159 | delta_logits = priors * outputs
160 | logits = logits + delta_logits
161 |
162 | # outputs[b, out_caps, 1, data_out, 1]
163 | outputs = outputs.squeeze()
164 |
165 | if len(outputs.shape) == 3:
166 | outputs = outputs.transpose(2, 1).contiguous()
167 | else:
168 | outputs = outputs.unsqueeze_(dim=0).transpose(2, 1).contiguous()
169 | # outputs[b, data_out, out_caps]
170 |
171 | return outputs
172 |
173 |
174 | class CapsuleNet(nn.Module):
175 | def __init__(self, num_class, gpu_id):
176 | super(CapsuleNet, self).__init__()
177 |
178 | self.num_class = num_class
179 | self.fea_ext = FeatureExtractor()
180 | self.fea_ext.apply(self.weights_init)
181 |
182 | self.routing_stats = RoutingLayer(gpu_id=gpu_id, num_input_capsules=NO_CAPS, num_output_capsules=num_class,
183 | data_in=8, data_out=4, num_iterations=2)
184 |
185 | def weights_init(self, m):
186 | classname = m.__class__.__name__
187 | if classname.find('Conv') != -1:
188 | m.weight.data.normal_(0.0, 0.02)
189 | elif classname.find('BatchNorm') != -1:
190 | m.weight.data.normal_(1.0, 0.02)
191 | m.bias.data.fill_(0)
192 |
193 | def forward(self, x, random=False, dropout=0.0):
194 |
195 | z = self.fea_ext(x)
196 | z = self.routing_stats(z, random, dropout=dropout)
197 | # z[b, data, out_caps]
198 |
199 | # classes = F.softmax(z, dim=-1)
200 |
201 | # class_ = classes.detach()
202 | # class_ = class_.mean(dim=1)
203 |
204 | # return classes, class_
205 |
206 | classes = F.softmax(z, dim=-1)
207 | class_ = classes.detach()
208 | class_ = class_.mean(dim=1)
209 |
210 | return z, class_
211 |
212 |
213 | class CapsuleLoss(nn.Module):
214 | def __init__(self, gpu_id):
215 | super(CapsuleLoss, self).__init__()
216 | self.cross_entropy_loss = nn.CrossEntropyLoss()
217 |
218 | if gpu_id >= 0:
219 | self.cross_entropy_loss.cuda(gpu_id)
220 |
221 | def forward(self, classes, labels):
222 | loss_t = self.cross_entropy_loss(classes[:, 0, :], labels)
223 |
224 | for i in range(classes.size(1) - 1):
225 | loss_t = loss_t + self.cross_entropy_loss(classes[:, i + 1, :], labels)
226 |
227 | return loss_t
228 |
--------------------------------------------------------------------------------
/capsule/test_binary_ffpp.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019, National Institute of Informatics
3 | All rights reserved.
4 | Author: Huy H. Nguyen
5 | -----------------------------------------------------
6 | Script for testing Capsule-Forensics-v2 on FaceForensics++ database (Real, DeepFakes, Face2Face, FaceSwap)
7 | """
8 |
9 | import sys
10 | sys.setrecursionlimit(15000)
11 | import os
12 | import torch
13 | import torch.backends.cudnn as cudnn
14 | import numpy as np
15 | from torch.autograd import Variable
16 | import torch.utils.data
17 | import torchvision.datasets as dset
18 | from torch.utils.data import DataLoader
19 | import torchvision.transforms as transforms
20 | from tqdm import tqdm
21 | import argparse
22 | from sklearn import metrics
23 | from scipy.optimize import brentq
24 | from scipy.interpolate import interp1d
25 | from sklearn.metrics import roc_curve
26 | import model_big
27 | import pandas
28 |
29 | from utils.dataloader import FrameDataset
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--dataset', default ='databases/faceforensicspp', help='path to dataset')
33 | parser.add_argument('--test_set', default ='test', help='test set')
34 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
35 | parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
36 | parser.add_argument('--imageSize', type=int, default=300, help='the height / width of the input image to network')
37 | parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID')
38 | parser.add_argument('--outf', default='checkpoints/binary_faceforensicspp', help='folder to output model checkpoints')
39 | parser.add_argument('--random', action='store_true', default=False, help='enable randomness for routing matrix')
40 | parser.add_argument('--id', type=int, default=21, help='checkpoint ID')
41 |
42 | opt = parser.parse_args()
43 | print(opt)
44 |
45 | if __name__ == '__main__':
46 |
47 | # text_writer = open(os.path.join(opt.outf, 'test.txt'), 'w')
48 |
49 | transform_fwd = transforms.Compose([
50 | transforms.Resize((opt.imageSize, opt.imageSize)),
51 | transforms.ToTensor(),
52 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
53 | ])
54 |
55 |
56 | # dataset_test = dset.ImageFolder(root=os.path.join(opt.dataset, opt.test_set), transform=transform_fwd)
57 | # assert dataset_test
58 | # dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=opt.batchSize, shuffle=False, num_workers=int(opt.workers))
59 | # dataloaders = {}
60 | # for name in ['train', 'test']:
61 | # raw_data = pandas.read_csv(os.path.join(opt.dataset, '%s.csv' % name))
62 | # dataloaders[name] = DataLoader(FrameDataset(raw_data.to_numpy()), **config.dataset_params)
63 | raw_data = pandas.read_csv(os.path.join(opt.dataset, 'test.csv'))
64 | dataloader_test = DataLoader(FrameDataset(raw_data.to_numpy()),
65 | batch_size=opt.batchSize,
66 | shuffle=True,
67 | num_workers=4,
68 | pin_memory=False)
69 | vgg_ext = model_big.VggExtractor()
70 | capnet = model_big.CapsuleNet(2, opt.gpu_id)
71 |
72 |     capnet.load_state_dict(torch.load(os.path.join(opt.outf, 'capsule_%d.pt' % opt.id)))
73 | capnet.eval()
74 |
75 | if opt.gpu_id >= 0:
76 | vgg_ext.cuda(opt.gpu_id)
77 | capnet.cuda(opt.gpu_id)
78 |
79 |
80 | ##################################################################################
81 |
82 | tol_label = np.array([], dtype=np.float)
83 | tol_pred = np.array([], dtype=np.float)
84 | tol_pred_prob = np.array([], dtype=np.float)
85 |
86 | count = 0
87 | loss_test = 0
88 |
89 | for img_data, labels_data in tqdm(dataloader_test):
90 |
91 | labels_data[labels_data > 1] = 1
92 | img_label = labels_data.numpy().astype(np.float)
93 |
94 | if opt.gpu_id >= 0:
95 | img_data = img_data.cuda(opt.gpu_id)
96 | labels_data = labels_data.cuda(opt.gpu_id)
97 |
98 | input_v = Variable(img_data)
99 |
100 | x = vgg_ext(input_v)
101 | classes, class_ = capnet(x, random=opt.random)
102 |
103 | output_dis = class_.data.cpu()
104 | output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)
105 |
106 | for i in range(output_dis.shape[0]):
107 | if output_dis[i,1] >= output_dis[i,0]:
108 | output_pred[i] = 1.0
109 | else:
110 | output_pred[i] = 0.0
111 |
112 | tol_label = np.concatenate((tol_label, img_label))
113 | tol_pred = np.concatenate((tol_pred, output_pred))
114 |
115 | pred_prob = torch.softmax(output_dis, dim=1)
116 | tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:, 1].data.numpy()))
117 |
118 | count += 1
119 |
120 | acc_test = metrics.accuracy_score(tol_label, tol_pred)
121 | auc_test = metrics.roc_auc_score(tol_label, tol_pred_prob)
122 | f1_test = metrics.f1_score(tol_label, tol_pred)
123 | recall_test = metrics.recall_score(tol_label, tol_pred)
124 | precision = metrics.precision_score(tol_label, tol_pred)
125 | loss_test /= count
126 |
127 | fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
128 | np.save('./m/cap/f_fpr.npy', fpr)
129 | np.save('./m/cap/f_tpr.npy', tpr)
130 | # eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
131 | #
132 | # fnr = 1 - tpr
133 | # hter = (fpr + fnr)/2
134 | # print('[Epoch %d] Train loss: %.4f acc: %.2f | Test loss: %.4f acc: %.2f auc: %.2f'
135 | # % (opt.id, acc_test * 100, loss_test, acc_test * 100, auc_test * 100))
136 | print('[Epoch %d] Test acc: %.2f AUC: %.2f f1: %.2f recall:%.2f precision:%.2f'
137 | % (opt.id, acc_test*100, auc_test*100, f1_test, recall_test, precision))
138 | # text_writer.write('%d,%.2f,%.2f\n'% (opt.id, acc_test*100, eer*100))
139 | #
140 | # text_writer.flush()
141 | # text_writer.close()
142 |
--------------------------------------------------------------------------------
/capsule/train_capsule.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019, National Institute of Informatics
3 | All rights reserved.
4 | Author: Huy H. Nguyen
5 | -----------------------------------------------------
6 | Script for training Capsule-Forensics-v2 on FaceForensics++ database (Real, DeepFakes, Face2Face, FaceSwap)
7 | """
8 |
9 | import sys
10 |
11 | sys.setrecursionlimit(15000)
12 | import os
13 | import random
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | import numpy as np
17 | from torch.autograd import Variable
18 | from torch.optim import Adam
19 | # import torchvision.transforms as transforms
20 | from torch.utils.data import DataLoader
21 | from tqdm import tqdm
22 | import argparse
23 | from sklearn import metrics
24 | import model_big
25 | import pandas
26 |
27 | from utils.dataloader import FrameDataset
28 |
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--dataset', default='', help='path to root dataset')
32 | # parser.add_argument('--train_set', default='train', help='train set')
33 | # parser.add_argument('--val_set', default='validation', help='validation set')
34 | # parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
35 | parser.add_argument('--batch_size', type=int, default=32, help='batch size')
36 | # parser.add_argument('--imageSize', type=int, default=300, help='the height / width of the input image to network')
37 | parser.add_argument('--niter', type=int, default=20, help='number of epochs to train for')
38 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
39 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam')
40 | parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID')
41 | parser.add_argument('--resume', type=int, default=0, help="choose a epochs to resume from (0 to train from scratch)")
42 | parser.add_argument('--outf', default='capsule/binary_faceforensicspp', help='folder to output model checkpoints')
43 | parser.add_argument('--disable_random', action='store_true', default=False,
44 | help='disable randomness for routing matrix')
45 | parser.add_argument('--dropout', type=float, default=0.05, help='dropout percentage')
46 | parser.add_argument('--manualSeed', type=int, help='manual seed')
47 |
48 | opt = parser.parse_args()
49 | print(opt)
50 |
51 | opt.random = not opt.disable_random
52 |
53 | if __name__ == "__main__":
54 |
55 | if opt.manualSeed is None:
56 | opt.manualSeed = random.randint(1, 10000)
57 | print("Random Seed: ", opt.manualSeed)
58 | random.seed(opt.manualSeed)
59 | torch.manual_seed(opt.manualSeed)
60 |
61 | if opt.gpu_id >= 0:
62 | torch.cuda.manual_seed_all(opt.manualSeed)
63 | cudnn.benchmark = True
64 |
65 | if opt.resume > 0:
66 | text_writer = open('train_capsule.csv', 'a')
67 | else:
68 | text_writer = open('train_capsule.csv', 'w')
69 |
70 | vgg_ext = model_big.VggExtractor()
71 | capnet = model_big.CapsuleNet(2, opt.gpu_id)
72 | capsule_loss = model_big.CapsuleLoss(opt.gpu_id)
73 |
74 | if opt.gpu_id >= 0:
75 | capnet.cuda(opt.gpu_id)
76 | vgg_ext.cuda(opt.gpu_id)
77 | capsule_loss.cuda(opt.gpu_id)
78 |
79 |     # capnet.load_state_dict(torch.load('/home/asus/Code/pvc/capsule_8.pt'))  # hard-coded warm-start checkpoint; enable and adjust the path for your own setup
80 |
81 | optimizer = Adam(capnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
82 |
83 | if opt.resume > 0:
84 | capnet.load_state_dict(torch.load(os.path.join(opt.outf, 'capsule_' + str(opt.resume) + '.pt')))
85 | capnet.train(mode=True)
86 | optimizer.load_state_dict(torch.load(os.path.join(opt.outf, 'optim_' + str(opt.resume) + '.pt')))
87 |
88 | if opt.gpu_id >= 0:
89 | for state in optimizer.state.values():
90 | for k, v in state.items():
91 | if isinstance(v, torch.Tensor):
92 | state[k] = v.cuda(opt.gpu_id)
93 |
94 |
95 | #
96 | # transform_fwd = transforms.Compose([
97 | # transforms.Resize((opt.imageSize, opt.imageSize)),
98 | # transforms.ToTensor(),
99 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
100 | # ])
101 |
102 | # dataset_train = dset.ImageFolder(root=os.path.join(opt.dataset, opt.train_set), transform=transform_fwd)
103 | # assert dataset_train
104 | # dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=opt.batchSize, shuffle=True,
105 | # num_workers=int(opt.workers))
106 | #
107 | # dataset_val = dset.ImageFolder(root=os.path.join(opt.dataset, opt.val_set), transform=transform_fwd)
108 | # assert dataset_val
109 | # dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=opt.batchSize, shuffle=False,
110 | # num_workers=int(opt.workers))
111 | dataloaders = {}
112 | for name in ['train', 'test']:
113 | raw_data = pandas.read_csv(os.path.join(opt.dataset, '%s.csv' % name))
114 | dataloaders[name] = DataLoader(FrameDataset(raw_data.to_numpy()),
115 | batch_size=opt.batch_size,
116 | shuffle=True,
117 | num_workers=4,
118 | pin_memory=False)
119 |
120 | for epoch in range(opt.resume + 1, opt.niter + 1):
121 | count = 0
122 | loss_train = 0
123 | loss_test = 0
124 |
125 | tol_label = np.array([], dtype=np.float)
126 | tol_pred = np.array([], dtype=np.float)
127 |
128 | for img_data, labels_data in tqdm(dataloaders['train']):
129 |
130 | labels_data[labels_data > 1] = 1
131 | img_label = labels_data.numpy().astype(np.float)
132 | optimizer.zero_grad()
133 |
134 | if opt.gpu_id >= 0:
135 | img_data = img_data.cuda(opt.gpu_id)
136 | labels_data = labels_data.cuda(opt.gpu_id)
137 |
138 | input_v = Variable(img_data)
139 | x = vgg_ext(input_v)
140 | classes, class_ = capnet(x, random=opt.random, dropout=opt.dropout)
141 |
142 | loss_dis = capsule_loss(classes, Variable(labels_data, requires_grad=False))
143 | loss_dis_data = loss_dis.item()
144 |
145 | loss_dis.backward()
146 | optimizer.step()
147 |
148 | output_dis = class_.data.cpu().numpy()
149 | output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)
150 |
151 | for i in range(output_dis.shape[0]):
152 | if output_dis[i, 1] >= output_dis[i, 0]:
153 | output_pred[i] = 1.0
154 | else:
155 | output_pred[i] = 0.0
156 |
157 | tol_label = np.concatenate((tol_label, img_label))
158 | tol_pred = np.concatenate((tol_pred, output_pred))
159 |
160 | loss_train += loss_dis_data
161 | count += 1
162 |
163 | acc_train = metrics.accuracy_score(tol_label, tol_pred)
164 | loss_train /= count
165 |
166 | ########################################################################
167 |
168 | # do checkpointing & validation
169 | torch.save(capnet.state_dict(), os.path.join(opt.outf, 'capsule_%d.pt' % epoch))
170 | torch.save(optimizer.state_dict(), os.path.join(opt.outf, 'optim_%d.pt' % epoch))
171 |
172 | capnet.eval()
173 |
174 | tol_label = np.array([], dtype=np.float)
175 | tol_pred = np.array([], dtype=np.float)
176 |
177 | count = 0
178 |
179 | for img_data, labels_data in dataloaders['test']:
180 |
181 | labels_data[labels_data > 1] = 1
182 | img_label = labels_data.numpy().astype(np.float)
183 |
184 | if opt.gpu_id >= 0:
185 | img_data = img_data.cuda(opt.gpu_id)
186 | labels_data = labels_data.cuda(opt.gpu_id)
187 |
188 | input_v = Variable(img_data)
189 |
190 | x = vgg_ext(input_v)
191 | classes, class_ = capnet(x, random=False)
192 |
193 | loss_dis = capsule_loss(classes, Variable(labels_data, requires_grad=False))
194 | loss_dis_data = loss_dis.item()
195 | output_dis = class_.data.cpu().numpy()
196 |
197 | output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)
198 |
199 | for i in range(output_dis.shape[0]):
200 | if output_dis[i, 1] >= output_dis[i, 0]:
201 | output_pred[i] = 1.0
202 | else:
203 | output_pred[i] = 0.0
204 |
205 | tol_label = np.concatenate((tol_label, img_label))
206 | tol_pred = np.concatenate((tol_pred, output_pred))
207 |
208 | loss_test += loss_dis_data
209 | count += 1
210 |
211 | acc_test = metrics.accuracy_score(tol_label, tol_pred)
212 | auc_test = metrics.roc_auc_score(tol_label, tol_pred)
213 | f1_test = metrics.f1_score(tol_label, tol_pred)
214 | recall_test = metrics.recall_score(tol_label, tol_pred)
215 | precision = metrics.precision_score(tol_label, tol_pred)
216 |
217 | loss_test /= count
218 |
219 | print('[Epoch %d] Train loss: %.4f acc: %.2f | Test loss: %.4f acc: %.2f auc: %.2f'
220 | % (epoch, loss_train, acc_train * 100, loss_test, acc_test * 100, auc_test * 100))
221 |
222 | text_writer.write('%d,%.4f,%.2f,%.4f,%.2f,%.2f,%.2f,%.2f,%.2f\n'
223 | % (epoch, loss_train, acc_train * 100, loss_test, acc_test * 100, auc_test * 100,
224 | f1_test * 100, recall_test * 100, precision * 100))
225 |
226 | text_writer.flush()
227 | capnet.train(mode=True)
228 |
229 | text_writer.close()
230 |
--------------------------------------------------------------------------------
/datasets/dataset_imagenet_dct_cupy.py:
--------------------------------------------------------------------------------
1 | # Optimized for DCT
2 | # Upsampling in the compressed domain
3 | import os
4 | import sys
5 | from datasets.vision import VisionDataset
6 | from PIL import Image
7 | import cv2
8 | import os.path
9 | import numpy as np
10 | import torch
11 | from turbojpeg import TurboJPEG
12 | from datasets import train_y_mean_resized, train_y_std_resized, train_cb_mean_resized, train_cb_std_resized, \
13 | train_cr_mean_resized, train_cr_std_resized
14 |
15 | def has_file_allowed_extension(filename, extensions):
16 | """Checks if a file is an allowed extension.
17 |
18 | Args:
19 | filename (string): path to a file
20 | extensions (tuple of strings): extensions to consider (lowercase)
21 |
22 | Returns:
23 | bool: True if the filename ends with one of given extensions
24 | """
25 | return filename.lower().endswith(extensions)
26 |
27 |
28 | def is_image_file(filename):
29 | """Checks if a file is an allowed image extension.
30 |
31 | Args:
32 | filename (string): path to a file
33 |
34 | Returns:
35 | bool: True if the filename ends with a known image extension
36 | """
37 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
38 |
39 |
40 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
41 | images = []
42 | dir = os.path.expanduser(dir)
43 | if not ((extensions is None) ^ (is_valid_file is None)):
44 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
45 | if extensions is not None:
46 | def is_valid_file(x):
47 | return has_file_allowed_extension(x, extensions)
48 | for target in sorted(class_to_idx.keys()):
49 | d = os.path.join(dir, target)
50 | if not os.path.isdir(d):
51 | continue
52 | for root, _, fnames in sorted(os.walk(d)):
53 | for fname in sorted(fnames):
54 | path = os.path.join(root, fname)
55 | if is_valid_file(path):
56 | item = (path, class_to_idx[target])
57 | images.append(item)
58 |
59 | return images
60 |
61 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
62 |
63 | def pil_loader(path):
64 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
65 | with open(path, 'rb') as f:
66 | img = Image.open(f)
67 | return img.convert('RGB')
68 |
69 | def accimage_loader(path):
70 | import accimage
71 | try:
72 | return accimage.Image(path)
73 | except IOError:
74 | # Potentially a decoding problem, fall back to PIL.Image
75 | return pil_loader(path)
76 |
77 | def opencv_loader(path, colorSpace='YCrCb'):
78 | image = cv2.imread(str(path))
79 | # cv2.imwrite('/mnt/ssd/kai.x/work/code/iftc/datasets/cvtransforms/test/raw.jpg', image)
80 | if colorSpace == "YCrCb":
81 | image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
82 | # cv2.imwrite('/mnt/ssd/kai.x/work/code/iftc/datasets/cvtransforms/test/ycbcr.jpg', image)
83 | elif colorSpace == 'RGB':
84 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
85 | return image
86 |
87 | def default_loader(path, backend='opencv', colorSpace='YCrCb'):
88 | from torchvision import get_image_backend
89 | if backend == 'opencv':
90 | return opencv_loader(path, colorSpace=colorSpace)
91 | elif get_image_backend() == 'accimage' and backend == 'acc':
92 | return accimage_loader(path)
93 | elif backend == 'pil':
94 | return pil_loader(path)
95 | else:
96 | raise NotImplementedError
97 |
98 | class DatasetFolderDCT(VisionDataset):
99 | def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None, subset=0):
100 | super(DatasetFolderDCT, self).__init__(root)
101 | self.transform = transform
102 | self.target_transform = target_transform
103 | classes, class_to_idx = self._find_classes(self.root)
104 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
105 | if len(samples) == 0:
106 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
107 | "Supported extensions are: " + ",".join(extensions)))
108 |
109 | self.loader = loader
110 | self.extensions = extensions
111 |
112 | self.classes = classes
113 | self.class_to_idx = class_to_idx
114 | self.samples = samples
115 | self.targets = [s[1] for s in samples]
116 | self.subset = list(map(int, subset.split(','))) if subset else []
117 |
118 | def _find_classes(self, dir):
119 | if sys.version_info >= (3, 5):
120 | # Faster and available in Python 3.5 and above
121 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
122 | else:
123 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
124 | classes.sort()
125 | class_to_idx = {classes[i]: i for i in range(len(classes))}
126 | return classes, class_to_idx
127 |
128 | def __getitem__(self, index):
129 | path, target = self.samples[index]
130 | # sample = self.loader(path, backend='opencv', colorSpace='YCrCb')
131 |
132 | sample = self.loader(path, backend='opencv', colorSpace='BGR')
133 |
134 | # with open(path, 'rb') as src:
135 | # buffer = src.read()
136 | # dct_y_bak, dct_cb_bak, dct_cr_bak = loads(buffer)
137 |
138 | if self.transform is not None:
139 | dct_y, dct_cb, dct_cr = self.transform(sample)
140 |
141 | # sample_resize = sample.resize((224*2, 224*2), resample=0)
142 | # PIL to numpy
143 | # sample = np.asarray(sample, dtype="uint8")
144 | # RGB to BGR
145 | # sample = sample[:, :, ::-1]
146 | # JPEG Encode
147 | # sample = np.ascontiguousarray(sample, dtype="uint8")
148 | # sample = self.jpeg.encode(sample, quality=100, jpeg_subsample=2)
149 | # dct_y, dct_cb, dct_cr = loads(sample) # 28
150 |
151 | # sample_resize = np.asarray(sample_resize)
152 | # sample_resize = sample_resize[:, :, ::-1]
153 | # sample_resize = np.ascontiguousarray(sample_resize, dtype="uint8")
154 | # sample_resize = self.jpeg.encode(sample_resize, quality=100)
155 | # _, dct_cb_resize, dct_cr_resize = loads(sample_resize) # 28
156 | # dct_cb_resize, dct_cr_resize = torch.from_numpy(dct_cb_resize).permute(2, 0, 1).float(), \
157 | # torch.from_numpy(dct_cr_resize).permute(2, 0, 1).float()
158 |
159 | # dct_y_unnormalized, dct_cb_unnormalized, dct_cr_unnormalized = loads(sample, normalized=False) # 28
160 | # dct_y_normalized, dct_cb_normalized, dct_cr_normalized = loads(sample, normalized=True) # 28
161 | # total_y = (dct_y-dct_y_bak).sum()
162 | # total_cb = (dct_cb-dct_cb_bak).sum()
163 | # total_cr = (dct_cr-dct_cr_bak).sum()
164 | # print('{}, {}, {}'.format(total_y, total_cb, total_cr))
165 | # dct_y, dct_cb, dct_cr = torch.from_numpy(dct_y).permute(2, 0, 1).float(), \
166 | # torch.from_numpy(dct_cb).permute(2, 0, 1).float(), \
167 | # torch.from_numpy(dct_cr).permute(2, 0, 1).float()
168 |
169 | # transform = transforms.Resize(28, interpolation=2)
170 | # dct_cb_resize2 = [transform(Image.fromarray(dct_c.numpy())) for dct_c in dct_cb]
171 |
172 | if self.subset:
173 | dct_y, dct_cb, dct_cr = dct_y[self.subset[0]:self.subset[1]], dct_cb[self.subset[0]:self.subset[1]], \
174 | dct_cr[self.subset[0]:self.subset[1]]
175 |
176 | return dct_y, dct_cb, dct_cr, target
177 |
178 | def __len__(self):
179 | return len(self.samples)
180 |
181 | class ImageFolderDCT(DatasetFolderDCT):
182 | def __init__(self, root, transform=None, target_transform=None,
183 | loader=default_loader, is_valid_file=None, subset=None):
184 | super(ImageFolderDCT, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
185 | transform=transform,
186 | target_transform=target_transform,
187 | is_valid_file=is_valid_file, subset=subset)
188 | self.imgs = self.samples
189 |
190 |
191 | if __name__ == '__main__':
192 | dataset = 'imagenet'
193 |
194 | import torch
195 | import datasets.cvtransforms as transforms
196 | import matplotlib.pyplot as plt
197 | from sklearn.preprocessing import minmax_scale
198 |
199 | # jpeg_encoder = TurboJPEG('/home/kai.x/work/local/lib/libturbojpeg.so')
200 | jpeg_encoder = TurboJPEG('/usr/lib/libturbojpeg.so')
201 | if dataset == 'imagenet':
202 | input_normalize = []
203 | input_normalize_y = transforms.Normalize(mean=train_y_mean_resized,
204 | std=train_y_std_resized)
205 | input_normalize_cb = transforms.Normalize(mean=train_cb_mean_resized,
206 | std=train_cb_std_resized)
207 | input_normalize_cr = transforms.Normalize(mean=train_cr_mean_resized,
208 | std=train_cr_std_resized)
209 | input_normalize.append(input_normalize_y)
210 | input_normalize.append(input_normalize_cb)
211 | input_normalize.append(input_normalize_cr)
212 | val_loader = torch.utils.data.DataLoader(
213 | # ImageFolderDCT('/mnt/ssd/kai.x/dataset/ILSVRC2012/val', transforms.Compose([
214 | ImageFolderDCT('/storage-t1/user/kaixu/datasets/ILSVRC2012/val', transforms.Compose([
215 | transforms.ToYCrCb(),
216 | transforms.TransformDCT(),
217 | transforms.UpsampleDCT(T=896, debug=False),
218 | transforms.CenterCropDCT(112),
219 | transforms.ToTensorDCT(),
220 | transforms.NormalizeDCT(
221 | train_y_mean_resized, train_y_std_resized,
222 | train_cb_mean_resized, train_cb_std_resized,
223 | train_cr_mean_resized, train_cr_std_resized),
224 | ])),
225 | batch_size=1, shuffle=False,
226 | num_workers=1, pin_memory=False)
227 |
228 | train_dataset = ImageFolderDCT('/storage-t1/user/kaixu/datasets/ILSVRC2012/train', transforms.Compose([
229 | transforms.RandomResizedCrop(224),
230 | transforms.RandomHorizontalFlip(),
231 | transforms.ToYCrCb(),
232 | transforms.ChromaSubsample(),
233 | transforms.UpsampleDCT(size=224, T=896, cuda=True, debug=False),
234 | transforms.ToTensorDCT(),
235 | transforms.NormalizeDCT(
236 | train_y_mean_resized, train_y_std_resized,
237 | train_cb_mean_resized, train_cb_std_resized,
238 | train_cr_mean_resized, train_cr_std_resized),
239 | ]))
240 |
241 | train_loader = torch.utils.data.DataLoader(
242 | train_dataset,
243 | batch_size=1, shuffle=False,
244 | num_workers=1, pin_memory=False)
245 |
246 | from torchvision.utils import save_image
247 | dct_y_mean_total, dct_y_std_total = [], []
248 | # for batch_idx, (dct_y, dct_cb, dct_cr, targets) in enumerate(val_loader):
249 | for batch_idx, (dct_y, dct_cb, dct_cr, targets) in enumerate(train_loader):
250 | coef = dct_y.numpy()
251 | dct_y_mean, dct_y_std = [], []
252 |
253 | for c in coef:
254 | c = c.reshape((64, -1))
255 | dct_y_mean.append([np.mean(x) for x in c])
256 | dct_y_std.append([np.std(x) for x in c])
257 |
258 | dct_y_mean_np = np.asarray(dct_y_mean).mean(axis=0)
259 | dct_y_std_np = np.asarray(dct_y_std).mean(axis=0)
260 | dct_y_mean_total.append(dct_y_mean_np)
261 | dct_y_std_total.append(dct_y_std_np)
262 | # print('The mean of dct_y is: {}'.format(dct_y_mean_np))
263 | # print('The std of dct_y is: {}'.format(dct_y_std_np))
264 |
265 | print('The mean of dct_y is: {}'.format(np.asarray(dct_y_mean_total).mean(axis=0)))
266 | print('The std of dct_y is: {}'.format(np.asarray(dct_y_std_total).mean(axis=0)))
267 |
268 |
269 |
--------------------------------------------------------------------------------
/datasets/imagenet2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import os, sys
4 | import os.path as osp
5 | from PIL import Image
6 | import six
7 | import string
8 |
9 | import lmdb
10 | import pickle
11 | import msgpack
12 | import tqdm
13 | import pyarrow as pa
14 | import bz2
15 |
16 | import torch
17 | import torch.utils.data as data
18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19 | import datasets.cvtransforms as transforms
20 | from datasets.dataset_imagenet_dct import ImageFolderDCT
21 |
22 | class ImageFolderLMDB(data.Dataset):
23 | def __init__(self, db_path, transform=None, target_transform=None):
24 | self.db_path = db_path
25 | self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
26 | readonly=True, lock=False,
27 | readahead=False, meminit=False)
28 | with self.env.begin(write=False) as txn:
29 | # self.length = txn.stat()['entries'] - 1
30 | self.length = txn.get(b'__len__')
31 | self.keys = msgpack.loads(txn.get(b'__keys__'))
32 |
33 | self.transform = transform
34 | self.target_transform = target_transform
35 |
36 | def __getitem__(self, index):
37 | img, target = None, None
38 | env = self.env
39 | with env.begin(write=False) as txn:
40 | byteflow = txn.get(self.keys[index])
41 | unpacked = msgpack.loads(byteflow)
42 |
43 | # load image
44 | imgbuf = unpacked[0]
45 | buf = six.BytesIO()
46 | buf.write(imgbuf)
47 | buf.seek(0)
48 | img = Image.open(buf).convert('RGB')
49 |
50 | # load label
51 | target = unpacked[1]
52 |
53 | if self.transform is not None:
54 | img = self.transform(img)
55 |
56 | if self.target_transform is not None:
57 | target = self.target_transform(target)
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return self.length
63 |
64 | def __repr__(self):
65 | return self.__class__.__name__ + ' (' + self.db_path + ')'
66 |
67 |
68 | class ImageFolderLMDB_old(data.Dataset):
69 | def __init__(self, db_path, transform=None, target_transform=None):
70 | import lmdb
71 | self.db_path = db_path
72 | self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
73 | readonly=True, lock=False,
74 | readahead=False, meminit=False)
75 | with self.env.begin(write=False) as txn:
76 | self.length = txn.stat()['entries'] - 1
77 | self.keys = msgpack.loads(txn.get(b'__keys__'))
78 | # cache_file = '_cache_' + db_path.replace('/', '_')
79 | # if os.path.isfile(cache_file):
80 | # self.keys = pickle.load(open(cache_file, "rb"))
81 | # else:
82 | # with self.env.begin(write=False) as txn:
83 | # self.keys = [key for key, _ in txn.cursor()]
84 | # pickle.dump(self.keys, open(cache_file, "wb"))
85 | self.transform = transform
86 | self.target_transform = target_transform
87 |
88 | def __getitem__(self, index):
89 | img, target = None, None
90 | env = self.env
91 | with env.begin(write=False) as txn:
92 | byteflow = txn.get(self.keys[index])
93 | unpacked = msgpack.loads(byteflow)
94 | imgbuf = unpacked[0][b'data']
95 | buf = six.BytesIO()
96 | buf.write(imgbuf)
97 | buf.seek(0)
98 | img = Image.open(buf).convert('RGB')
99 | target = unpacked[1]
100 |
101 | if self.transform is not None:
102 | img = self.transform(img)
103 |
104 | if self.target_transform is not None:
105 | target = self.target_transform(target)
106 |
107 | return img, target
108 |
109 | def __len__(self):
110 | return self.length
111 |
112 | def __repr__(self):
113 | return self.__class__.__name__ + ' (' + self.db_path + ')'
114 |
115 |
116 | def raw_reader(path):
117 | with open(path, 'rb') as f:
118 | bin_data = f.read()
119 | return bin_data
120 |
121 |
122 | def dumps_pyarrow(obj):
123 | """
124 | Serialize an object.
125 |
126 | Returns:
127 | Implementation-dependent bytes-like object
128 | """
129 | return pa.serialize(obj).to_buffer()
130 |
131 | def folder2lmdb(dpath, name="train", write_frequency=1):
132 | directory = osp.expanduser(osp.join(dpath, name))
133 | print("Loading dataset from %s" % directory)
134 |
135 |     dataset = ImageFolderDCT(directory, transforms.Compose([
136 | transforms.DCTFlatten2D(),
137 | transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False),
138 | transforms.ToTensorDCT(),
139 | transforms.SubsetDCT(channels=32),
140 | ]), backend='dct')
141 |
142 | data_loader = torch.utils.data.DataLoader(
143 | dataset,
144 | num_workers=0,
145 | )
146 |
147 | lmdb_path = osp.join(dpath, "%s.lmdb" % name)
148 | isdir = os.path.isdir(lmdb_path)
149 |
150 | print("Generate LMDB to %s" % lmdb_path)
151 | db = lmdb.open(lmdb_path, subdir=isdir,
152 | map_size=1281167*224*224*32*10, readonly=False,
153 | # map_size=1099511627776 * 2, readonly=False,
154 | meminit=False, map_async=True)
155 |
156 | txn = db.begin(write=True)
157 | for idx, (image, label) in enumerate(data_loader):
158 | image = image.numpy()
159 | label = label.numpy()
160 | txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((bz2.compress(image), label)))
161 | if idx % write_frequency == 0:
162 | print("[%d/%d]" % (idx, len(data_loader)))
163 | txn.commit()
164 | txn = db.begin(write=True)
165 |
166 | # finish iterating through dataset
167 | txn.commit()
168 | keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
169 | with db.begin(write=True) as txn:
170 | txn.put(b'__keys__', dumps_pyarrow(keys))
171 | txn.put(b'__len__', dumps_pyarrow(len(keys)))
172 |
173 | print("Flushing database ...")
174 | db.sync()
175 | db.close()
176 |
177 |
178 | if __name__ == "__main__":
179 | folder2lmdb("/ILSVRC2012")
--------------------------------------------------------------------------------
/datasets/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, torch._six.string_classes):
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, "transforms") and self.transforms is not None:
41 | body += [repr(self.transforms)]
42 | lines = [head] + [" " * self._repr_indent + line for line in body]
43 | return '\n'.join(lines)
44 |
45 | def _format_transform_repr(self, transform, head):
46 | lines = transform.__repr__().splitlines()
47 | return (["{}{}".format(head, lines[0])] +
48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
49 |
50 | def extra_repr(self):
51 | return ""
52 |
53 |
54 | class StandardTransform(object):
55 | def __init__(self, transform=None, target_transform=None):
56 | self.transform = transform
57 | self.target_transform = target_transform
58 |
59 | def __call__(self, input, target):
60 | if self.transform is not None:
61 | input = self.transform(input)
62 | if self.target_transform is not None:
63 | target = self.target_transform(target)
64 | return input, target
65 |
66 | def _format_transform_repr(self, transform, head):
67 | lines = transform.__repr__().splitlines()
68 | return (["{}{}".format(head, lines[0])] +
69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
70 |
71 | def __repr__(self):
72 | body = [self.__class__.__name__]
73 | if self.transform is not None:
74 | body += self._format_transform_repr(self.transform,
75 | "Transform: ")
76 | if self.target_transform is not None:
77 | body += self._format_transform_repr(self.target_transform,
78 | "Target transform: ")
79 |
80 | return '\n'.join(body)
81 |
--------------------------------------------------------------------------------
/dct/__init__.py:
--------------------------------------------------------------------------------
1 | subset_channel_index_square = {
2 | 1:
3 | [[0],[],[]],
4 |
5 | 6:
6 | [
7 | [0,1,
8 | 8,9],
9 | [0],
10 | [0]
11 | ],
12 |
13 | 12:
14 | [
15 | [0, 1, 2,
16 | 8, 9, 10,
17 | 16, 17],
18 | [0, 1],
19 | [0, 1]
20 | ],
21 |
22 | 24:
23 | [
24 | [0, 1, 2, 3,
25 | 8, 9, 10, 11,
26 | 16, 17, 18, 19,
27 | 24, 25, 26, 27],
28 | [0, 1,
29 | 8, 9],
30 | [0, 1,
31 | 8, 9]
32 | ],
33 |
34 | 32:
35 | [
36 | [0, 1, 2, 3, 4,
37 | 8, 9, 10, 11, 12,
38 | 16, 17, 18, 19, 20,
39 | 24, 25, 26, 27,
40 | 32, 33, 34],
41 | [0, 1, 2,
42 | 8, 9],
43 | [0, 1, 2,
44 | 8, 9]
45 | ],
46 |
47 | 48:
48 | [
49 | [0, 1, 2, 3, 4, 5,
50 | 8, 9, 10, 11, 12, 13,
51 | 16, 17, 18, 19, 20, 21,
52 | 24, 25, 26, 27, 28, 29,
53 | 32, 33, 34, 35,
54 | 40, 41, 42, 43],
55 | [0, 1, 2,
56 | 8, 9, 10,
57 | 16, 17],
58 | [0, 1, 2,
59 | 8, 9, 10,
60 | 16, 17]
61 | ],
62 |
63 | 64:
64 | [
65 | [0, 1, 2, 3, 4, 5, 6,
66 | 8, 9, 10, 11, 12, 13, 14,
67 | 16, 17, 18, 19, 20, 21,
68 | 24, 25, 26, 27, 28, 29,
69 | 32, 33, 34, 35, 36, 37,
70 | 40, 41, 42, 43, 44, 45,
71 | 48, 49, 50, 51, 52, 53],
72 | [0, 1, 2,
73 | 8, 9, 10,
74 | 16, 17,
75 | 24, 25],
76 | [0, 1, 2,
77 | 8, 9, 10,
78 | 16, 17,
79 | 24, 25],
80 | ]
81 | }
82 |
83 | subset_channel_index_learned = {
84 | 1:
85 | [[0], [], []],
86 |
87 |
88 |
89 | 24:
90 | [
91 | [0, 1, 2, 3, 4, 5,
92 | 8, 9, 10,
93 | 16, 17, 18,
94 | 24,
95 | 32],
96 | [0, 1, 3,
97 | 8,
98 | 24],
99 | [0, 1, 3,
100 | 8,
101 | 24]
102 | ]
103 | }
104 |
105 | subset_channel_index_triangle = {
106 | 1:
107 | [[0], [], []],
108 |
109 | 12:
110 | [
111 | [0, 1, 2,
112 | 8, 9,
113 | 16],
114 | [0, 1,
115 | 8],
116 | [0, 1,
117 | 8]
118 | ],
119 |
120 | 24:
121 | [
122 | [0, 1, 2, 3, 4,
123 | 8, 9, 10, 11,
124 | 16, 17,
125 | 24,],
126 | [0, 1, 2,
127 | 8, 9,
128 | 16],
129 | [0, 1, 2,
130 | 8, 9,
131 | 16]
132 | ],
133 |
134 | 48:
135 | [
136 | [0, 1, 2, 3, 4, 5, 6,
137 | 8, 9, 10, 11, 12, 13,
138 | 16, 17, 18, 19, 20,
139 | 24, 25, 26, 27,
140 | 32, 33, 34,
141 | 40, 41,
142 | 48],
143 | [0, 1, 2, 3,
144 | 8, 9, 10,
145 | 16, 17,
146 | 24],
147 | [0, 1, 2, 3,
148 | 8, 9, 10,
149 | 16, 17,
150 | 24]
151 | ],
152 |
153 | 64:
154 | [
155 | [0, 1, 2, 3, 4, 5, 6, 7,
156 | 8, 9, 10, 11, 12, 13, 14,
157 | 16, 17, 18, 19, 20, 21,
158 | 24, 25, 26, 27, 28,
159 | 32, 33, 34, 35,
160 | 40, 41, 42,
161 | 48,
162 | ],
163 | [0, 1, 2, 3, 4,
164 | 8, 9, 10, 11,
165 | 16, 17, 18,
166 | 24, 25,
167 | 32],
168 | [0, 1, 2, 3, 4,
169 | 8, 9, 10, 11,
170 | 16, 17, 18,
171 | 24, 25,
172 | 32],
173 | ]
174 | }
175 |
--------------------------------------------------------------------------------
/dct/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | # from .resnext import *
4 | from .resnet import *
5 | # from .resnet_autosubset_inputgate import resnet50_autosubset_inputgate
6 | # from .resnext_attention import *
7 | # from .mobilenetv2_autosubset_alllayer import mobilenetv2dct_autosubset_alllayers
8 |
9 |
--------------------------------------------------------------------------------
/dct/imagenet/gate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from dct.imagenet.gumbel import GumbleSoftmax
4 |
5 | class GateModule(nn.Module):
6 | def __init__(self, in_ch, kernel_size=28, doubleGate=False, dwLA=False):
7 | super(GateModule, self).__init__()
8 |
9 | self.doubleGate, self.dwLA = doubleGate, dwLA
10 | self.inp_gs = GumbleSoftmax()
11 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
12 | self.in_ch = in_ch
13 |
14 | if dwLA:
15 | if doubleGate:
16 | self.inp_att = nn.Sequential(
17 | nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, stride=1, padding=0, groups=in_ch,
18 | bias=True),
19 | nn.BatchNorm2d(in_ch),
20 | nn.ReLU6(inplace=True),
21 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True),
22 | nn.Sigmoid()
23 | )
24 |
25 | self.inp_gate = nn.Sequential(
26 | nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, stride=1, padding=0, groups=in_ch, bias=True),
27 | nn.BatchNorm2d(in_ch),
28 | nn.ReLU6(inplace=True),
29 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True),
30 | nn.BatchNorm2d(in_ch),
31 | )
32 | self.inp_gate_l = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, stride=1, padding=0, groups=in_ch,
33 | bias=True)
34 | else:
35 | if doubleGate:
36 | reduction = 4
37 | self.inp_att = nn.Sequential(
38 | nn.Conv2d(in_ch, in_ch // reduction, kernel_size=1, stride=1, padding=0, bias=True),
39 | nn.ReLU6(inplace=True),
40 | nn.Conv2d(in_ch // reduction, in_ch, kernel_size=1, stride=1, padding=0, bias=True),
41 | nn.Sigmoid()
42 | )
43 |
44 | self.inp_gate = nn.Sequential(
45 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True),
46 | nn.BatchNorm2d(in_ch),
47 | nn.ReLU6(inplace=True),
48 | )
49 | self.inp_gate_l = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, stride=1, padding=0, groups=in_ch, bias=True)
50 |
51 | def forward(self, y, cb, cr, temperature=1.):
52 | if self.doubleGate:
53 | if self.dwLA:
54 | hatten_d1 = self.inp_att(x)
55 | hatten_d2 = self.inp_gate(x)
56 | hatten_d2 = self.inp_gate_l(hatten_d2)
57 | else:
58 | hatten_y, hatten_cb, hatten_cr = self.avg_pool(y), self.avg_pool(cb), self.avg_pool(cr)
59 | hatten = torch.cat((hatten_y, hatten_cb, hatten_cr), dim=1)
60 |
61 | hatten_d1 = self.inp_att(hatten)
62 | hatten_d2 = self.inp_gate(hatten)
63 | hatten_d2 = self.inp_gate_l(hatten_d2)
64 |
65 | hatten_d2 = hatten_d2.reshape(hatten_d2.size(0), self.in_ch, 2, 1)
66 | hatten_d2 = self.inp_gs(hatten_d2, temp=temperature, force_hard=True)
67 | else:
68 | if self.dwLA:
69 | hatten_d2 = self.inp_gate(x)
70 | hatten_d2 = self.inp_gate_l(hatten_d2)
71 | else:
72 | hatten_y, hatten_cb, hatten_cr = self.avg_pool(y), self.avg_pool(cb), self.avg_pool(cr)
73 | hatten_d2 = torch.cat((hatten_y, hatten_cb, hatten_cr), dim=1)
74 | hatten_d2 = self.inp_gate(hatten_d2)
75 | hatten_d2 = self.inp_gate_l(hatten_d2)
76 |
77 | hatten_d2 = hatten_d2.reshape(hatten_d2.size(0), self.in_ch, 2, 1)
78 | hatten_d2 = self.inp_gs(hatten_d2, temp=temperature, force_hard=True)
79 |
80 | if self.doubleGate:
81 | x = x * hatten_d1 * hatten_d2[:, :, 1].unsqueeze(2)
82 | else:
83 | y = y * hatten_d2[:, :64, 1].unsqueeze(2)
84 | cb = cb * hatten_d2[:, 64:128, 1].unsqueeze(2)
85 | cr = cr * hatten_d2[:, 128:, 1].unsqueeze(2)
86 |
87 | return y, cb, cr, hatten_d2[:, :, 1]
88 |
89 |
90 |
91 | class GateModule192(nn.Module):
92 | def __init__(self, act='relu'):
93 | super(GateModule192, self).__init__()
94 |
95 | self.inp_gs = GumbleSoftmax()
96 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
97 | self.in_ch = in_ch = 192
98 | if act == 'relu':
99 | relu = nn.ReLU
100 | elif act == 'relu6':
101 | relu = nn.ReLU6
102 | else: raise NotImplementedError
103 |
104 | self.inp_gate = nn.Sequential(
105 | nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0, bias=True),
106 | nn.BatchNorm2d(in_ch),
107 | relu(inplace=True),
108 | )
109 | self.inp_gate_l = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, stride=1, padding=0, groups=in_ch, bias=True)
110 |
111 |
112 | def forward(self, x, temperature=1.):
113 | hatten = self.avg_pool(x)
114 | hatten_d = self.inp_gate(hatten)
115 | hatten_d = self.inp_gate_l(hatten_d)
116 | hatten_d = hatten_d.reshape(hatten_d.size(0), self.in_ch, 2, 1)
117 | hatten_d = self.inp_gs(hatten_d, temp=temperature, force_hard=True)
118 |
119 | x = x * hatten_d[:, :, 1].unsqueeze(2)
120 |
121 | return x, hatten_d[:, :, 1]
122 |
--------------------------------------------------------------------------------
/dct/imagenet/gumbel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 |
5 | class GumbleSoftmax(torch.nn.Module):
6 | def __init__(self, hard=False):
7 | super(GumbleSoftmax, self).__init__()
8 | self.hard = hard
9 |
10 | def sample_gumbel(self, shape, eps=1e-10):
11 | """Sample from Gumbel(0, 1)"""
12 | noise = torch.rand(shape)
13 | noise.add_(eps).log_().neg_()
14 | noise.add_(eps).log_().neg_()
15 | if self.gpu:
16 | return noise.cuda()
17 | else:
18 | return noise
19 |
20 | def sample_gumbel_like(self, template_tensor, eps=1e-10):
21 | uniform_samples_tensor = template_tensor.clone().uniform_()
22 | gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps))
23 | return gumble_samples_tensor
24 |
25 | def gumbel_softmax_sample(self, logits, temperature):
26 | """ Draw a sample from the Gumbel-Softmax distribution"""
27 | dim = logits.size(2)
28 | gumble_samples_tensor = self.sample_gumbel_like(logits.data)
29 | gumble_trick_log_prob_samples = logits + gumble_samples_tensor
30 | soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, dim)
31 | return soft_samples
32 |
33 | def gumbel_softmax(self, logits, temperature, hard=False):
34 | """Sample from the Gumbel-Softmax distribution and optionally discretize.
35 | Args:
36 |         logits: [batch_size, n_class] unnormalized log-probs
37 | temperature: non-negative scalar
38 | hard: if True, take argmax, but differentiate w.r.t. soft sample y
39 | Returns:
40 | [batch_size, n_class] sample from the Gumbel-Softmax distribution.
41 | If hard=True, then the returned sample will be one-hot, otherwise it will
42 |         be a probability distribution that sums to 1 across classes
43 | """
44 | y = self.gumbel_softmax_sample(logits, temperature)
45 | if hard:
46 | # block layer
47 | # _, max_value_indexes = y.data.max(1, keepdim=True)
48 | # y_hard = logits.data.clone().zero_().scatter_(1, max_value_indexes, 1)
49 | # block channel
50 | _, max_value_indexes = y.data.max(2, keepdim=True)
51 | y_hard = logits.data.clone().zero_().scatter_(2, max_value_indexes, 1)
52 | y = Variable(y_hard - y.data) + y
53 | return y
54 |
55 | def forward(self, logits, temp=1, force_hard=False):
56 | samplesize = logits.size()
57 |
58 | if self.training and not force_hard:
59 |             return self.gumbel_softmax(logits, temperature=temp, hard=False)
60 |         else:
61 |             return self.gumbel_softmax(logits, temperature=temp, hard=True)
--------------------------------------------------------------------------------
/dct/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 | def constant_init(module, val, bias=0):
6 | nn.init.constant_(module.weight, val)
7 | if hasattr(module, 'bias') and module.bias is not None:
8 | nn.init.constant_(module.bias, bias)
9 |
10 |
11 | def xavier_init(module, gain=1, bias=0, distribution='normal'):
12 | assert distribution in ['uniform', 'normal']
13 | if distribution == 'uniform':
14 | nn.init.xavier_uniform_(module.weight, gain=gain)
15 | else:
16 | nn.init.xavier_normal_(module.weight, gain=gain)
17 | if hasattr(module, 'bias') and module.bias is not None:
18 | nn.init.constant_(module.bias, bias)
19 |
20 |
21 | def normal_init(module, mean=0, std=1, bias=0):
22 | nn.init.normal_(module.weight, mean, std)
23 | if hasattr(module, 'bias') and module.bias is not None:
24 | nn.init.constant_(module.bias, bias)
25 |
26 |
27 | def uniform_init(module, a=0, b=1, bias=0):
28 | nn.init.uniform_(module.weight, a, b)
29 | if hasattr(module, 'bias') and module.bias is not None:
30 | nn.init.constant_(module.bias, bias)
31 |
32 |
33 | def kaiming_init(module,
34 | a=0,
35 | mode='fan_out',
36 | nonlinearity='relu',
37 | bias=0,
38 | distribution='normal'):
39 | assert distribution in ['uniform', 'normal']
40 | if distribution == 'uniform':
41 | nn.init.kaiming_uniform_(
42 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
43 | else:
44 | nn.init.kaiming_normal_(
45 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
46 | if hasattr(module, 'bias') and module.bias is not None:
47 | nn.init.constant_(module.bias, bias)
48 |
49 |
50 | def caffe2_xavier_init(module, bias=0):
51 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
52 | # Acknowledgment to FAIR's internal code
53 | kaiming_init(
54 | module,
55 | a=1,
56 | mode='fan_in',
57 | nonlinearity='leaky_relu',
58 | distribution='uniform')
59 |
60 |
61 | def get_upsample_filter(size):
62 | """Make a 2D bilinear kernel suitable for upsampling"""
63 | factor = (size + 1) // 2
64 | if size % 2 == 1:
65 | center = factor - 1
66 | else:
67 | center = factor - 0.5
68 | og = np.ogrid[:size, :size]
69 | filter = (1 - abs(og[0] - center) / factor) * \
70 | (1 - abs(og[1] - center) / factor)
71 | return torch.from_numpy(filter).float()
72 |
--------------------------------------------------------------------------------
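
These initializers follow the mmcv-style helpers. A small sketch, not part of the repository, showing `kaiming_init` on a plain convolution and `get_upsample_filter` seeding a depthwise transposed convolution with bilinear weights:

```python
import torch.nn as nn
from dct.utils import kaiming_init, get_upsample_filter

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)
kaiming_init(conv)                       # Kaiming-normal weights (fan_out), zero bias

# Bilinear 2x upsampling: one (4, 4) kernel shared across 16 channels via groups=16
up = nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1, groups=16, bias=False)
kernel = get_upsample_filter(4)          # (4, 4) bilinear interpolation kernel
up.weight.data.copy_(kernel.expand(16, 1, 4, 4))
```
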
/demo/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import cv2
6 | from facenet_pytorch.models.mtcnn import MTCNN
7 | from sklearn.metrics import accuracy_score
8 | from models.model import Baseline
9 |
10 |
11 | def load_model(restore_from, device):
12 | model = Baseline(use_gru=True, bi_branch=True)
13 |
14 | model.to(device)
15 |
16 | device_count = torch.cuda.device_count()
17 | # if device_count > 1:
18 | # print('Using {} GPUs'.format(device_count))
19 | model = nn.DataParallel(model)
20 |
21 | if restore_from is not None:
22 | ckpt = torch.load(restore_from, map_location='cpu')
23 | model.load_state_dict(ckpt['model_state_dict'])
24 | print('Model is loaded from %s' % restore_from)
25 |
26 | model.eval()
27 |
28 | return model
29 |
30 | def _bbox_in_img(img, bbox):
31 | """
32 | check whether the bbox is inner an image.
33 | :param img: (3-d np.ndarray), image
34 | :param bbox: (list) [x, y, width, height]
35 | :return: (bool), whether bbox in image size.
36 | """
37 | if not isinstance(img, np.ndarray):
38 | raise ValueError("input image should be ndarray!")
39 | if len(img.shape) != 3:
40 | raise ValueError("input image should be (w,h,c)!")
41 | h = img.shape[0]
42 | w = img.shape[1]
43 | x_in = 0 <= bbox[0] <= w
44 | y_in = 0 <= bbox[1] <= h
45 | x1_in = 0 <= bbox[0] + bbox[2] <= w
46 | y1_in = 0 <= bbox[1] + bbox[3] <= h
47 | return x_in and y_in and x1_in and y1_in
48 |
49 |
50 | def _enlarged_bbox(bbox, expand):
51 | """
52 | enlarge a bbox by given expand param.
53 | :param bbox: [x, y, width, height]
54 | :param expand: (tuple) (h,w), expanded pixels in height and width. if (int), same value in both side.
55 | :return: enlarged bbox
56 | """
57 | if isinstance(expand, int):
58 | expand = (expand, expand)
59 | s_0, s_1 = bbox[1], bbox[0]
60 | e_0, e_1 = bbox[1] + bbox[3], bbox[0] + bbox[2]
61 | x = s_1 - expand[1]
62 | y = s_0 - expand[0]
63 | x1 = e_1 + expand[1]
64 | y1 = e_0 + expand[0]
65 | width = x1 - x
66 | height = y1 - y
67 | return x, y, width, height
68 |
69 |
70 | def _box_mode_cvt(bbox):
71 | """
72 | convert box from FCOS([xyxy], float) output to [x, y, width, height](int).
73 | :param bbox: (dict), an output from FCOS([x, y, x1, y1], float).
74 | :return: (list[int]), a box with [x, y, width, height] format.
75 | """
76 | if bbox is None:
77 | raise ValueError("There is no box in the dict!")
78 | # FCOS box format is [x, y, x1, y1]
79 | w = bbox[2] - bbox[0]
80 | h = bbox[3] - bbox[1]
81 | cvt_box = [int(bbox[0]), int(bbox[1]), max(int(w), 0), max(int(h), 0)]
82 | return cvt_box
83 |
84 |
85 | def crop_bbox(img, bbox):
86 | """
87 | crop an image by giving exact bbox.
88 | :param img:
89 | :param bbox: [x, y, width, height]
90 | :return: cropped image
91 | """
92 | if not _bbox_in_img(img, bbox):
93 | raise ValueError("bbox is out of image size!img size: {0}, bbox size: {1}".format(img.shape, bbox))
94 | s_0 = bbox[1]
95 | s_1 = bbox[0]
96 | e_0 = bbox[1] + bbox[3]
97 | e_1 = bbox[0] + bbox[2]
98 | cropped_img = img[s_0:e_0, s_1:e_1, :]
99 | return cropped_img
100 |
101 | def face_boxes_post_process(img, box, expand_ratio):
102 | """
103 | enlarge and crop the face patch from image
104 | :param img: ndarray, 1 frame from video
105 | :param box: output of MTCNN
106 | :param expand_ratio: default: 1.3
107 | :return:
108 | """
109 | box = [max(b, 0) for b in box]
110 | box_xywh = _box_mode_cvt(box)
111 | expand_w = int((box_xywh[2] * (expand_ratio - 1)) / 2)
112 | expand_h = int((box_xywh[3] * (expand_ratio - 1)) / 2)
113 | enlarged_box = _enlarged_bbox(box_xywh, (expand_h, expand_w))
114 | try:
115 | res = crop_bbox(img, enlarged_box)
116 | except ValueError:
117 | try:
118 | res = crop_bbox(img, box_xywh)
119 | except ValueError:
120 | return img
121 | return res
122 |
123 | def detect_face(frame, face_detector):
124 | boxes, _ = face_detector.detect(frame)
125 | if boxes is not None:
126 | best_box = boxes[0, :]
127 | best_face = face_boxes_post_process(frame, best_box, expand_ratio=1.33)
128 | return best_face
129 | else:
130 | return None
131 |
132 |
133 | def load_data(path, device):
134 | transform = transforms.Compose([
135 | transforms.ToTensor(),
136 | ])
137 | face_detector = MTCNN(margin=0, keep_all=False, select_largest=False, thresholds=[0.6, 0.7, 0.7],
138 | min_face_size=60, factor=0.8, device=device).eval()
139 | video_fd = cv2.VideoCapture(path)
140 | if not video_fd.isOpened():
141 | print('problem of reading video')
142 | return
143 |
144 | frame_index = 0
145 | faces = []
146 | success, frame = video_fd.read()
147 | while success:
148 |         cropped_face = detect_face(frame, face_detector)
149 |         if cropped_face is not None:
150 |             cropped_face = cv2.resize(cropped_face, (64, 64))
151 |             cropped_face = transform(cropped_face)
152 |             faces.append(cropped_face)
153 | frame_index += 1
154 | success, frame = video_fd.read()
155 | video_fd.release()
156 | print('video frame length:', frame_index)
157 | faces = torch.stack(faces, dim=0)
158 | faces = torch.unsqueeze(faces, 0)
159 |     y = torch.ones(faces.size(1)).type(torch.IntTensor)  # one label per detected face, matching the frame-level predictions
160 | return faces, y
161 |
162 |
163 | def main(args):
164 | frame_y_gd = []
165 | y_pred = []
166 | frame_y_pred = []
167 | use_cuda = torch.cuda.is_available()
168 | device = torch.device('cuda' if use_cuda else 'cpu')
169 | model = load_model(args.restore_from, device)
170 | data, y = load_data(args.path, device)
171 | X = data.to(device)
172 | y_, cnn_y = model(X)
173 | y_ = torch.sigmoid(y_)
174 | frame_y_ = torch.sigmoid(cnn_y)
175 | frame_y_gd += y.detach().numpy().tolist()
176 |     frame_y_pred += frame_y_.detach().cpu().numpy().tolist()
177 | frame_y_pred = torch.tensor(frame_y_pred)
178 | frame_y_pred = [0 if i < 0.5 else 1 for i in frame_y_pred]
179 | test_frame_acc = accuracy_score(frame_y_gd, frame_y_pred)
180 | print('video is fake:', (y_ >= 0.5).item())
181 | print('frame level acc:', test_frame_acc)
182 |
183 |
184 | if __name__ == '__main__':
185 | import argparse
186 |
187 | parser = argparse.ArgumentParser()
188 | parser.add_argument('--restore_from', type=str, default='./bi-model_type-baseline_gru_auc_0.150000_ep-10.pth')
189 | parser.add_argument('--path', type=str, default='./video/id0_id1_0002.mp4')
190 | args = parser.parse_args()
191 | main(args)
192 |
--------------------------------------------------------------------------------
/demo/video/id0_id1_0002.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/demo/video/id0_id1_0002.mp4
--------------------------------------------------------------------------------
/fwa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/fwa/__init__.py
--------------------------------------------------------------------------------
/fwa/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision import models
4 | # import torch.nn.functional as F
5 | import math
6 |
7 |
8 | class ResNet(nn.Module):
9 | def __init__(self, layers=18, num_class=2, pretrained=True):
10 | super(ResNet, self).__init__()
11 | if layers == 18:
12 | self.resnet = models.resnet18(pretrained=pretrained)
13 | elif layers == 34:
14 | self.resnet = models.resnet34(pretrained=pretrained)
15 | elif layers == 50:
16 | self.resnet = models.resnet50(pretrained=pretrained)
17 | elif layers == 101:
18 | self.resnet = models.resnet101(pretrained=pretrained)
19 | elif layers == 152:
20 | self.resnet = models.resnet152(pretrained=pretrained)
21 | else:
22 |             raise ValueError('layers should be 18, 34, 50, 101 or 152.')
23 | self.num_class = num_class
24 | if layers in [18, 34]:
25 | self.fc = nn.Linear(512, num_class)
26 | if layers in [50, 101, 152]:
27 | self.fc = nn.Linear(512 * 4, num_class)
28 |
29 | def conv_base(self, x):
30 | x = self.resnet.conv1(x)
31 | x = self.resnet.bn1(x)
32 | x = self.resnet.relu(x)
33 | x = self.resnet.maxpool(x)
34 |
35 | layer1 = self.resnet.layer1(x)
36 | layer2 = self.resnet.layer2(layer1)
37 | layer3 = self.resnet.layer3(layer2)
38 | layer4 = self.resnet.layer4(layer3)
39 | return layer1, layer2, layer3, layer4
40 |
41 | def forward(self, x):
42 | layer1, layer2, layer3, layer4 = self.conv_base(x)
43 | x = self.resnet.avgpool(layer4)
44 | x = x.view(x.size(0), -1)
45 | x = self.fc(x)
46 | return x
47 |
48 |
49 | class SPPNet(nn.Module):
50 | def __init__(self, backbone=101, num_class=2, pool_size=(1, 2, 6), pretrained=True):
51 | # Only resnet is supported in this version
52 | super(SPPNet, self).__init__()
53 | if backbone in [18, 34, 50, 101, 152]:
54 | self.resnet = ResNet(backbone, num_class, pretrained)
55 | else:
56 | raise ValueError('Resnet{} is not supported yet.'.format(backbone))
57 |
58 | if backbone in [18, 34]:
59 | self.c = 512
60 | if backbone in [50, 101, 152]:
61 | self.c = 2048
62 |
63 | self.spp = SpatialPyramidPool2D(out_side=pool_size)
64 | num_features = self.c * (pool_size[0] ** 2 + pool_size[1] ** 2 + pool_size[2] ** 2)
65 | self.classifier = nn.Linear(num_features, num_class)
66 |
67 | def forward(self, x):
68 | _, _, _, x = self.resnet.conv_base(x)
69 | x = self.spp(x)
70 | x = self.classifier(x)
71 | return x
72 |
73 |
74 | class SpatialPyramidPool2D(nn.Module):
75 | """
76 | Args:
77 | out_side (tuple): Length of side in the pooling results of each pyramid layer.
78 |
79 | Inputs:
80 | - `input`: the input Tensor to invert ([batch, channel, width, height])
81 | """
82 |
83 | def __init__(self, out_side):
84 | super(SpatialPyramidPool2D, self).__init__()
85 | self.out_side = out_side
86 |
87 | def forward(self, x):
88 | # batch_size, c, h, w = x.size()
89 | out = None
90 | for n in self.out_side:
91 | w_r, h_r = map(lambda s: math.ceil(s / n), x.size()[2:]) # Receptive Field Size
92 | s_w, s_h = map(lambda s: math.floor(s / n), x.size()[2:]) # Stride
93 | max_pool = nn.MaxPool2d(kernel_size=(w_r, h_r), stride=(s_w, s_h))
94 | y = max_pool(x)
95 | if out is None:
96 | out = y.view(y.size()[0], -1)
97 | else:
98 | out = torch.cat((out, y.view(y.size()[0], -1)), 1)
99 | return out
100 |
--------------------------------------------------------------------------------
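
The spatial pyramid pooling makes the classifier input size independent of the backbone's feature-map resolution: with `pool_size=(1, 2, 6)` and a ResNet-50 backbone (2048 channels), the pooled vector has 2048 * (1 + 4 + 36) features. A minimal sketch, not taken from the repository's training code:

```python
import torch
from fwa.classifier import SPPNet

model = SPPNet(backbone=50, num_class=2, pretrained=False)   # pretrained=False avoids a weight download
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)   # torch.Size([1, 2])
```
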
/imgs/imbalanced performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/imgs/imbalanced performance.png
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/imgs/overview.png
--------------------------------------------------------------------------------
/meso/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/meso/__init__.py
--------------------------------------------------------------------------------
/meso/eval_meso.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
4 |
5 | import numpy as np
6 | from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, roc_curve, roc_auc_score
7 |
8 | from tqdm import tqdm
9 |
10 | from meso.meso import *
11 |
12 | validation_data_path = ''
13 |
14 | os.environ["CUDA_VISIBLE_DEVICES"] = '2, 3'
15 |
16 | img_width, img_height = 64, 64
17 | batch_size = 2000
18 | epochs = 20
19 |
20 | frame_y_gd = []
21 | frame_y_pred = []
22 |
23 | model = MesoInception4()
24 | # model = Meso4()
25 | model.load('')
26 |
27 | test_datagen = ImageDataGenerator(rescale=1. / 255)
28 |
29 | validation_generator = test_datagen.flow_from_directory(
30 | validation_data_path,
31 | target_size=(img_height, img_width),
32 | batch_size=batch_size,
33 | class_mode='binary')
34 |
35 | i = 0
36 | for X, y in tqdm(validation_generator, desc='Validating'):
37 | y_ = model.predict(X)
38 | frame_y_pred += y_.tolist()
39 | frame_y_gd += y.tolist()
40 | i += 1
41 | if i >= 37:
42 | break
43 |
44 | gd = np.array(frame_y_gd)
45 | pred = np.array(frame_y_pred)
46 | pred_pro = pred
47 |
48 | pred = np.rint(pred)
49 | f_fpr, f_tpr, _ = roc_curve(gd, pred_pro)
50 | test_frame_acc = accuracy_score(gd, pred)
51 | test_frame_auc = roc_auc_score(gd, pred_pro)
52 | test_frame_f1 = f1_score(gd, pred)
53 | test_frame_pre = precision_score(gd, pred)
54 | test_frame_recall = recall_score(gd, pred)
55 |
56 | np.save('', f_fpr)
57 | np.save('', f_tpr)
58 |
59 | print('acc, auc, f1_score, precision_score, recall_score')
60 | print(test_frame_acc, test_frame_auc, test_frame_f1, test_frame_pre, test_frame_recall)
61 |
--------------------------------------------------------------------------------
/meso/meso.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import tensorflow as tf
3 | from tensorflow.keras import backend as K
4 | from tensorflow.keras.models import Model as KerasModel
5 | from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Reshape, \
6 | Concatenate, LeakyReLU
7 | from tensorflow.keras.optimizers import Adam
8 |
9 | from sklearn.metrics import roc_auc_score
10 |
11 | IMGWIDTH = 64
12 |
13 |
14 | def getPrecision(y_true, y_pred):
15 | TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) # TP
16 | N = (-1) * K.sum(K.round(K.clip(y_true - K.ones_like(y_true), -1, 0))) # N
17 | TN = K.sum(K.round(K.clip((y_true - K.ones_like(y_true)) * (y_pred - K.ones_like(y_pred)), 0, 1))) # TN
18 | FP = N - TN
19 |         precision = TP / (TP + FP + K.epsilon())  # TP / (TP + FP)
20 | return precision
21 |
22 |
23 | def auroc(y_true, y_pred):
24 | return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
25 |
26 |
27 | def f1(y_true, y_pred):
28 | def recall(y_true, y_pred):
29 | """Recall metric.
30 |
31 | Only computes a batch-wise average of recall.
32 |
33 | Computes the recall, a metric for multi-label classification of
34 | how many relevant items are selected.
35 | """
36 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
37 | possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
38 | recall = true_positives / (possible_positives + K.epsilon())
39 | return recall
40 |
41 | def precision(y_true, y_pred):
42 | """Precision metric.
43 |
44 | Only computes a batch-wise average of precision.
45 |
46 | Computes the precision, a metric for multi-label classification of
47 | how many selected items are relevant.
48 | """
49 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
50 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
51 | precision = true_positives / (predicted_positives + K.epsilon())
52 | return precision
53 |
54 | precision = precision(y_true, y_pred)
55 | recall = recall(y_true, y_pred)
56 | return 2 * ((precision * recall) / (precision + recall + K.epsilon()))
57 |
58 |
59 | class Classifier:
60 | def __init__(self):
61 | self.model = 0
62 |
63 | def predict(self, x):
64 | return self.model.predict(x)
65 |
66 | def fit(self, x, y):
67 | return self.model.train_on_batch(x, y)
68 |
69 | def get_accuracy(self, x, y):
70 | return self.model.test_on_batch(x, y)
71 |
72 | def get_auc(self, x, y):
73 | return auroc(x, y)
74 |
75 | def load(self, path):
76 | self.model.load_weights(path)
77 |
78 |
79 | class Meso1(Classifier):
80 | """
81 | Feature extraction + Classification
82 | """
83 |
84 | def __init__(self, learning_rate=1e-4, dl_rate=1):
85 | self.model = self.init_model(dl_rate)
86 | optimizer = Adam(lr=learning_rate)
87 | self.model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy', auroc])
88 |
89 | def init_model(self, dl_rate):
90 | x = Input(shape=(IMGWIDTH, IMGWIDTH, 3))
91 |
92 | x1 = Conv2D(16, (3, 3), dilation_rate=dl_rate, strides=1, padding='same', activation='relu')(x)
93 | x1 = Conv2D(4, (1, 1), padding='same', activation='relu')(x1)
94 | x1 = BatchNormalization()(x1)
95 | x1 = MaxPooling2D(pool_size=(8, 8), padding='same')(x1)
96 |
97 | y = Flatten()(x1)
98 | y = Dropout(0.5)(y)
99 | y = Dense(1, activation='sigmoid')(y)
100 | return KerasModel(inputs=x, outputs=y)
101 |
102 |
103 | class Meso4(Classifier):
104 | def __init__(self, learning_rate=1e-5):
105 | self.model = self.init_model()
106 | optimizer = Adam(lr=learning_rate)
107 | self.model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy', 'AUC', f1,
108 | 'Recall'])
109 |
110 | def init_model(self):
111 | x = Input(shape=(IMGWIDTH, IMGWIDTH, 3))
112 |
113 | x1 = Conv2D(8, (3, 3), padding='same', activation='relu')(x)
114 | x1 = BatchNormalization()(x1)
115 | x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)
116 |
117 | x2 = Conv2D(8, (5, 5), padding='same', activation='relu')(x1)
118 | x2 = BatchNormalization()(x2)
119 | x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)
120 |
121 | x3 = Conv2D(16, (5, 5), padding='same', activation='relu')(x2)
122 | x3 = BatchNormalization()(x3)
123 | x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)
124 |
125 | x4 = Conv2D(16, (5, 5), padding='same', activation='relu')(x3)
126 | x4 = BatchNormalization()(x4)
127 | x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)
128 |
129 | y = Flatten()(x4)
130 | y = Dropout(0.5)(y)
131 | y = Dense(16)(y)
132 | y = LeakyReLU(alpha=0.1)(y)
133 | y = Dropout(0.5)(y)
134 | y = Dense(1, activation='sigmoid')(y)
135 |
136 | return KerasModel(inputs=x, outputs=y)
137 |
138 |
139 | class MesoInception4(Classifier):
140 | def __init__(self, learning_rate=0.001):
141 | self.model = self.init_model()
142 | optimizer = Adam(lr=learning_rate)
143 | self.model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['accuracy', 'AUC', f1,
144 | 'Recall'])
145 |
146 | def InceptionLayer(self, a, b, c, d):
147 | def func(x):
148 | x1 = Conv2D(a, (1, 1), padding='same', activation='relu')(x)
149 |
150 | x2 = Conv2D(b, (1, 1), padding='same', activation='relu')(x)
151 | x2 = Conv2D(b, (3, 3), padding='same', activation='relu')(x2)
152 |
153 | x3 = Conv2D(c, (1, 1), padding='same', activation='relu')(x)
154 | x3 = Conv2D(c, (3, 3), dilation_rate=2, strides=1, padding='same', activation='relu')(x3)
155 |
156 | x4 = Conv2D(d, (1, 1), padding='same', activation='relu')(x)
157 | x4 = Conv2D(d, (3, 3), dilation_rate=3, strides=1, padding='same', activation='relu')(x4)
158 |
159 | y = Concatenate(axis=-1)([x1, x2, x3, x4])
160 |
161 | return y
162 |
163 | return func
164 |
165 | def init_model(self):
166 | x = Input(shape=(IMGWIDTH, IMGWIDTH, 3))
167 |
168 | x1 = self.InceptionLayer(1, 4, 4, 2)(x)
169 | x1 = BatchNormalization()(x1)
170 | x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)
171 |
172 | x2 = self.InceptionLayer(2, 4, 4, 2)(x1)
173 | x2 = BatchNormalization()(x2)
174 | x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)
175 |
176 | x3 = Conv2D(16, (5, 5), padding='same', activation='relu')(x2)
177 | x3 = BatchNormalization()(x3)
178 | x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)
179 |
180 | x4 = Conv2D(16, (5, 5), padding='same', activation='relu')(x3)
181 | x4 = BatchNormalization()(x4)
182 | x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)
183 |
184 | y = Flatten()(x4)
185 | y = Dropout(0.5)(y)
186 | y = Dense(16)(y)
187 | y = LeakyReLU(alpha=0.1)(y)
188 | y = Dropout(0.5)(y)
189 | y = Dense(1, activation='sigmoid')(y)
190 |
191 | return KerasModel(inputs=x, outputs=y)
192 |
--------------------------------------------------------------------------------
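
A hedged sketch of driving `MesoInception4` directly. The weights path is a placeholder (the evaluation script above also leaves it empty), and the exact Keras/TensorFlow version this repository targets is not pinned, so treat this as illustrative only:

```python
import numpy as np
from meso.meso import MesoInception4

clf = MesoInception4()                          # builds and compiles the Keras model
# clf.load('path/to/weights.h5')                # hypothetical path; load trained weights before real use

dummy = np.random.rand(2, 64, 64, 3).astype('float32')   # IMGWIDTH x IMGWIDTH RGB crops scaled to [0, 1]
probs = clf.predict(dummy)                      # sigmoid scores, shape (2, 1)
print(probs.shape)
```
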
/meso/train_mesonet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
4 | from tensorflow.keras import callbacks
5 |
6 | import time
7 |
8 | from meso.meso import *
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3'
11 | start = time.time()
12 |
13 | img_width, img_height = 64, 64
14 | batch_size = 2000
15 | epochs = 20
16 |
17 | train_data_path = ''
18 | validation_data_path = ''
19 |
20 | # model = Meso4().model
21 | model = MesoInception4().model
22 |
23 | train_datagen = ImageDataGenerator(rescale=1. / 255)
24 |
25 | test_datagen = ImageDataGenerator(rescale=1. / 255)
26 |
27 | train_generator = train_datagen.flow_from_directory(
28 | train_data_path,
29 | target_size=(img_height, img_width),
30 | batch_size=batch_size,
31 | class_mode='binary')
32 |
33 | validation_generator = test_datagen.flow_from_directory(
34 | validation_data_path,
35 | target_size=(img_height, img_width),
36 | batch_size=batch_size,
37 | class_mode='binary')
38 |
39 | log_dir = './tf-log/'
40 | tb_cb = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0)
41 | cbks = [tb_cb]
42 |
43 | model.fit_generator(
44 | train_generator,
45 | epochs=epochs,
46 | validation_data=validation_generator,
47 | callbacks=cbks,
48 | shuffle=True)
49 |
50 | target_dir = './meso/'
51 | if not os.path.exists(target_dir):
52 | os.mkdir(target_dir)
53 | model.save('./meso/model.h5')
54 | model.save_weights('./meso/weights.h5')
55 |
56 | # Calculate execution time
57 | end = time.time()
58 | dur = end - start
59 |
60 | if dur < 60:
61 | print("Execution Time:", dur, "seconds")
62 | elif dur < 3600:
63 | dur = dur / 60
64 | print("Execution Time:", dur, "minutes")
65 | else:
66 | dur = dur / (60 * 60)
67 | print("Execution Time:", dur, "hours")
68 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/models/__init__.py
--------------------------------------------------------------------------------
/models/convGRU.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | from torch.autograd import Variable
5 |
6 |
7 | class ConvGRUCell(nn.Module):
8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias, dtype):
9 | """
10 |         Initialize the ConvGRU cell
11 | :param input_size: (int, int)
12 | Height and width of input tensor as (height, width).
13 | :param input_dim: int
14 | Number of channels of input tensor.
15 | :param hidden_dim: int
16 | Number of channels of hidden state.
17 | :param kernel_size: (int, int)
18 | Size of the convolutional kernel.
19 | :param bias: bool
20 | Whether or not to add the bias.
21 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor
22 | Whether or not to use cuda.
23 | """
24 | super(ConvGRUCell, self).__init__()
25 | self.height, self.width = input_size
26 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
27 | self.hidden_dim = hidden_dim
28 | self.bias = bias
29 | self.dtype = dtype
30 |
31 | self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim,
32 | out_channels=2*self.hidden_dim, # for update_gate,reset_gate respectively
33 | kernel_size=kernel_size,
34 | padding=self.padding,
35 | bias=self.bias)
36 |
37 | self.conv_can = nn.Conv2d(in_channels=input_dim+hidden_dim,
38 | out_channels=self.hidden_dim, # for candidate neural memory
39 | kernel_size=kernel_size,
40 | padding=self.padding,
41 | bias=self.bias)
42 |
43 | def init_hidden(self, batch_size):
44 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).type(self.dtype))
45 |
46 | def forward(self, input_tensor, h_cur):
47 | """
48 |
49 | :param self:
50 | :param input_tensor: (b, c, h, w)
51 | input is actually the target_model
52 | :param h_cur: (b, c_hidden, h, w)
53 | current hidden and cell states respectively
54 | :return: h_next,
55 | next hidden state
56 | """
57 | combined = torch.cat([input_tensor, h_cur], dim=1)
58 | combined_conv = self.conv_gates(combined)
59 |
60 | gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1)
61 | reset_gate = torch.sigmoid(gamma)
62 | update_gate = torch.sigmoid(beta)
63 |
64 | combined = torch.cat([input_tensor, reset_gate*h_cur], dim=1)
65 | cc_cnm = self.conv_can(combined)
66 | cnm = torch.tanh(cc_cnm)
67 |
68 | h_next = (1 - update_gate) * h_cur + update_gate * cnm
69 | return h_next
70 |
71 |
72 | class ConvGRU(nn.Module):
73 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
74 | dtype=torch.cuda.FloatTensor, batch_first=False, bias=True, return_all_layers=False):
75 | """
76 |
77 | :param input_size: (int, int)
78 | Height and width of input tensor as (height, width).
79 | :param input_dim: int e.g. 256
80 | Number of channels of input tensor.
81 | :param hidden_dim: int e.g. 1024
82 | Number of channels of hidden state.
83 | :param kernel_size: (int, int)
84 | Size of the convolutional kernel.
85 | :param num_layers: int
86 |             Number of ConvGRU layers
87 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor
88 | Whether or not to use cuda.
89 | :param alexnet_path: str
90 | pretrained alexnet parameters
91 | :param batch_first: bool
92 | if the first position of array is batch or not
93 | :param bias: bool
94 | Whether or not to add the bias.
95 | :param return_all_layers: bool
96 | if return hidden and cell states for all layers
97 | """
98 | super(ConvGRU, self).__init__()
99 |
100 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
101 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
102 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
103 | if not len(kernel_size) == len(hidden_dim) == num_layers:
104 | raise ValueError('Inconsistent list length.')
105 |
106 | self.height, self.width = input_size
107 | self.input_dim = input_dim
108 | self.hidden_dim = hidden_dim
109 | self.kernel_size = kernel_size
110 | self.dtype = dtype
111 | self.num_layers = num_layers
112 | self.batch_first = batch_first
113 | self.bias = bias
114 | self.return_all_layers = return_all_layers
115 |
116 | cell_list = []
117 | for i in range(0, self.num_layers):
118 | cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
119 | cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
120 | input_dim=cur_input_dim,
121 | hidden_dim=self.hidden_dim[i],
122 | kernel_size=self.kernel_size[i],
123 | bias=self.bias,
124 | dtype=self.dtype))
125 |
126 | # convert python list to pytorch module
127 | self.cell_list = nn.ModuleList(cell_list)
128 |
129 | def forward(self, input_tensor, hidden_state=None):
130 | """
131 |
132 | :param input_tensor: (b, t, c, h, w) or (t,b,c,h,w) depends on if batch first or not
133 | extracted features from alexnet
134 | :param hidden_state:
135 | :return: layer_output_list, last_state_list
136 | """
137 | if not self.batch_first:
138 | # (t, b, c, h, w) -> (b, t, c, h, w)
139 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
140 |
141 | # Implement stateful ConvLSTM
142 | if hidden_state is not None:
143 | raise NotImplementedError()
144 | else:
145 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
146 |
147 | layer_output_list = []
148 | last_state_list = []
149 |
150 | seq_len = input_tensor.size(1)
151 | cur_layer_input = input_tensor
152 |
153 | for layer_idx in range(self.num_layers):
154 | h = hidden_state[layer_idx]
155 | output_inner = []
156 | for t in range(seq_len):
157 | # input current hidden and cell state then compute the next hidden and cell state through ConvLSTMCell forward function
158 | h = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], # (b,t,c,h,w)
159 | h_cur=h)
160 | output_inner.append(h)
161 |
162 | layer_output = torch.stack(output_inner, dim=1)
163 | cur_layer_input = layer_output
164 |
165 | layer_output_list.append(layer_output)
166 | last_state_list.append([h])
167 |
168 | if not self.return_all_layers:
169 | layer_output_list = layer_output_list[-1:]
170 | last_state_list = last_state_list[-1:]
171 |
172 | return layer_output_list, last_state_list
173 |
174 | def _init_hidden(self, batch_size):
175 | init_states = []
176 | for i in range(self.num_layers):
177 | init_states.append(self.cell_list[i].init_hidden(batch_size))
178 | return init_states
179 |
180 | @staticmethod
181 | def _check_kernel_size_consistency(kernel_size):
182 | if not (isinstance(kernel_size, tuple) or
183 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
184 | raise ValueError('`kernel_size` must be tuple or list of tuples')
185 |
186 | @staticmethod
187 | def _extend_for_multilayer(param, num_layers):
188 | if not isinstance(param, list):
189 | param = [param] * num_layers
190 | return param
191 |
192 |
193 | if __name__ == '__main__':
194 | # set CUDA device
195 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
196 |
197 | # detect if CUDA is available or not
198 | use_gpu = torch.cuda.is_available()
199 | if use_gpu:
200 | dtype = torch.cuda.FloatTensor # computation in GPU
201 | else:
202 | dtype = torch.FloatTensor
203 |
204 | height = width = 6
205 | channels = 256
206 | hidden_dim = [32, 64]
207 | kernel_size = (3,3) # kernel size for two stacked hidden layer
208 | num_layers = 2 # number of stacked hidden layer
209 | model = ConvGRU(input_size=(height, width),
210 | input_dim=channels,
211 | hidden_dim=hidden_dim,
212 | kernel_size=kernel_size,
213 | num_layers=num_layers,
214 | dtype=dtype,
215 | batch_first=True,
216 | bias = True,
217 | return_all_layers = False)
218 |
219 | batch_size = 1
220 | time_steps = 1
221 | input_tensor = torch.rand(batch_size, time_steps, channels, height, width) # (b,t,c,h,w)
222 | layer_output_list, last_state_list = model(input_tensor)
--------------------------------------------------------------------------------
/models/convlstm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class ConvLSTMCell(nn.Module):
6 |
7 | def __init__(self, input_dim, hidden_dim, kernel_size, bias):
8 | """
9 | Initialize ConvLSTM cell.
10 |
11 | Parameters
12 | ----------
13 | input_dim: int
14 | Number of channels of input tensor.
15 | hidden_dim: int
16 | Number of channels of hidden state.
17 | kernel_size: (int, int)
18 | Size of the convolutional kernel.
19 | bias: bool
20 | Whether or not to add the bias.
21 | """
22 |
23 | super(ConvLSTMCell, self).__init__()
24 |
25 | self.input_dim = input_dim
26 | self.hidden_dim = hidden_dim
27 |
28 | self.kernel_size = kernel_size
29 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
30 | self.bias = bias
31 |
32 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
33 | out_channels=4 * self.hidden_dim,
34 | kernel_size=self.kernel_size,
35 | padding=self.padding,
36 | bias=self.bias)
37 |
38 | def forward(self, input_tensor, cur_state):
39 | h_cur, c_cur = cur_state
40 |
41 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
42 |
43 | combined_conv = self.conv(combined)
44 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
45 | i = torch.sigmoid(cc_i)
46 | f = torch.sigmoid(cc_f)
47 | o = torch.sigmoid(cc_o)
48 | g = torch.tanh(cc_g)
49 |
50 | c_next = f * c_cur + i * g
51 | h_next = o * torch.tanh(c_next)
52 |
53 | return h_next, c_next
54 |
55 | def init_hidden(self, batch_size, image_size):
56 | height, width = image_size
57 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
58 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
59 |
60 |
61 | class ConvLSTM(nn.Module):
62 |
63 | """
64 |
65 | Parameters:
66 | input_dim: Number of channels in input
67 | hidden_dim: Number of hidden channels
68 | kernel_size: Size of kernel in convolutions
69 | num_layers: Number of LSTM layers stacked on each other
70 | batch_first: Whether or not dimension 0 is the batch or not
71 | bias: Bias or no bias in Convolution
72 | return_all_layers: Return the list of computations for all layers
73 | Note: Will do same padding.
74 |
75 | Input:
76 | A tensor of size B, T, C, H, W or T, B, C, H, W
77 | Output:
78 | A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
79 | 0 - layer_output_list is the list of lists of length T of each output
80 | 1 - last_state_list is the list of last states
81 | each element of the list is a tuple (h, c) for hidden state and memory
82 | Example:
83 | >> x = torch.rand((32, 10, 64, 128, 128))
84 | >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
85 | >> _, last_states = convlstm(x)
86 | >> h = last_states[0][0] # 0 for layer index, 0 for h index
87 | """
88 |
89 | def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
90 | batch_first=False, bias=True, return_all_layers=False):
91 | super(ConvLSTM, self).__init__()
92 |
93 | self._check_kernel_size_consistency(kernel_size)
94 |
95 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
96 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
97 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
98 | if not len(kernel_size) == len(hidden_dim) == num_layers:
99 | raise ValueError('Inconsistent list length.')
100 |
101 | self.input_dim = input_dim
102 | self.hidden_dim = hidden_dim
103 | self.kernel_size = kernel_size
104 | self.num_layers = num_layers
105 | self.batch_first = batch_first
106 | self.bias = bias
107 | self.return_all_layers = return_all_layers
108 |
109 | cell_list = []
110 | for i in range(0, self.num_layers):
111 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
112 |
113 | cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
114 | hidden_dim=self.hidden_dim[i],
115 | kernel_size=self.kernel_size[i],
116 | bias=self.bias))
117 |
118 | self.cell_list = nn.ModuleList(cell_list)
119 |
120 | def forward(self, input_tensor, hidden_state=None):
121 | """
122 |
123 | Parameters
124 | ----------
125 | input_tensor: todo
126 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
127 | hidden_state: todo
128 | None. todo implement stateful
129 |
130 | Returns
131 | -------
132 | last_state_list, layer_output
133 | """
134 | if not self.batch_first:
135 | # (t, b, c, h, w) -> (b, t, c, h, w)
136 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
137 |
138 | b, _, _, h, w = input_tensor.size()
139 |
140 | # Implement stateful ConvLSTM
141 | if hidden_state is not None:
142 | raise NotImplementedError()
143 | else:
144 | # Since the init is done in forward. Can send image size here
145 | hidden_state = self._init_hidden(batch_size=b,
146 | image_size=(h, w))
147 |
148 | layer_output_list = []
149 | last_state_list = []
150 |
151 | seq_len = input_tensor.size(1)
152 | cur_layer_input = input_tensor
153 |
154 | for layer_idx in range(self.num_layers):
155 |
156 | h, c = hidden_state[layer_idx]
157 | output_inner = []
158 | for t in range(seq_len):
159 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
160 | cur_state=[h, c])
161 | output_inner.append(h)
162 |
163 | layer_output = torch.stack(output_inner, dim=1)
164 | cur_layer_input = layer_output
165 |
166 | layer_output_list.append(layer_output)
167 | last_state_list.append([h, c])
168 |
169 | if not self.return_all_layers:
170 | layer_output_list = layer_output_list[-1:]
171 | last_state_list = last_state_list[-1:]
172 |
173 | return layer_output_list, last_state_list
174 |
175 | def _init_hidden(self, batch_size, image_size):
176 | init_states = []
177 | for i in range(self.num_layers):
178 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
179 | return init_states
180 |
181 | @staticmethod
182 | def _check_kernel_size_consistency(kernel_size):
183 | if not (isinstance(kernel_size, tuple) or
184 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
185 | raise ValueError('`kernel_size` must be tuple or list of tuples')
186 |
187 | @staticmethod
188 | def _extend_for_multilayer(param, num_layers):
189 | if not isinstance(param, list):
190 | param = [param] * num_layers
191 | return param
192 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from torchvision import models as Models
5 |
6 | from dct.imagenet.gate import GateModule192
7 | from dct.utils import kaiming_init, constant_init
8 | from models.convlstm import ConvLSTM
9 | from models.convGRU import ConvGRU
10 | from models import resnet
11 |
12 | from dct.imagenet.resnet import *
13 |
14 |
15 | class Baseline(nn.Module):
16 |
17 | def __init__(self, use_gru=False, bi_branch=False, rnn_hidden_layers=3, rnn_hidden_nodes=256,
18 | num_classes=1, bidirectional=False, dct=False, inputgate=False):
19 |
20 | super(Baseline, self).__init__()
21 |
22 | self.rnn_hidden_layers = rnn_hidden_layers
23 | self.rnn_hidden_nodes = rnn_hidden_nodes
24 | self.num_classes = num_classes
25 | self.bi_branch = bi_branch
26 | self.inputgate = inputgate
27 |
28 | if not dct:
29 | pretrained_cnn = Models.resnet50(pretrained=True)
30 | cnn_layers = list(pretrained_cnn.children())[:-1]
31 | else:
32 | pretrained_cnn = ResNetDCT_Upscaled_Static(channels=192, pretrained=True)
33 | cnn_layers = list(pretrained_cnn.children())[:-2]
34 |
35 | self.cnn = nn.Sequential(*cnn_layers)
36 | rnn_params = {
37 | 'input_size': pretrained_cnn.fc.in_features,
38 | 'hidden_size': self.rnn_hidden_nodes,
39 | 'num_layers': self.rnn_hidden_layers,
40 | 'batch_first': True,
41 | 'bidirectional': bidirectional
42 | }
43 |
44 | if bidirectional:
45 | fc_in = 2 * rnn_hidden_nodes
46 | else:
47 | fc_in = rnn_hidden_nodes
48 |
49 | self.rnn = (nn.GRU if use_gru else nn.LSTM)(**rnn_params)
50 |
51 | self.fc_cnn = nn.Linear(fc_in, num_classes)
52 |
53 | self.global_pool = nn.AdaptiveAvgPool2d(16)
54 |
55 | self.fc_rnn = nn.Linear(256, self.num_classes)
56 |
57 | if inputgate:
58 | self.inp_GM = GateModule192()
59 | self._initialize_weights()
60 |
61 | def forward(self, x_3d):
62 |
63 | cnn_embedding_out = []
64 | cnn_pred = []
65 | frame_num = x_3d.size(1)
66 | gates = []
67 |
68 | for t in range(frame_num):
69 | if self.inputgate:
70 | x, gate_activations = self.inp_GM(x_3d[:, t, :, :, :])
71 | gates.append(gate_activations)
72 | x = self.cnn(x_3d[:, t, :, :, :])
73 | x = torch.flatten(x, start_dim=1)
74 | cnn_embedding_out.append(x)
75 |
76 | cnn_embedding_out = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1)
77 |
78 | self.rnn.flatten_parameters()
79 | rnn_out, _ = self.rnn(cnn_embedding_out, None)
80 |
81 | if self.bi_branch:
82 | for t in range(rnn_out.size(1)):
83 | x = rnn_out[:, t, :]
84 | x = self.fc_cnn(x)
85 | cnn_pred.append(x)
86 | cnn_pred = torch.stack(cnn_pred, dim=0).transpose(0, 1)
87 |
88 | x = self.global_pool(rnn_out)
89 | x = torch.flatten(x, start_dim=1)
90 | x = self.fc_rnn(x)
91 |
92 | if self.inputgate:
93 | if self.bi_branch:
94 | return x, cnn_pred.reshape(-1, self.num_classes), torch.stack(gates, dim=0).view(-1, 192, 1)
95 | else:
96 | return x, gates
97 | else:
98 | if self.bi_branch:
99 | return x, cnn_pred.reshape(-1, self.num_classes)
100 | else:
101 | return x
102 |
103 | def _initialize_weights(self):
104 | for name, m in self.named_modules():
105 | if 'inp_gate_l' in str(name):
106 | m.weight.data.normal_(0, 0.001)
107 | m.bias.data[::2].fill_(0.1)
108 | m.bias.data[1::2].fill_(2)
109 | elif 'inp_gate' in str(name):
110 | if isinstance(m, nn.Conv2d):
111 | kaiming_init(m)
112 | elif isinstance(m, nn.BatchNorm2d):
113 | constant_init(m, 1)
114 |
115 |
116 | class CNN(nn.Module):
117 | def __init__(self, bi_branch=False, num_classes=2):
118 | super(CNN, self).__init__()
119 |
120 | self.num_classes = num_classes
121 |
122 |         # Use a pretrained ResNet as the feature extractor and drop its final classifier layer
123 | pretrained_cnn = Models.resnet50(pretrained=True)
124 | cnn_layers = list(pretrained_cnn.children())[:-1]
125 |
126 |         # Strip the ResNet's last fc layer so it only produces features
127 | self.cnn = nn.Sequential(*cnn_layers)
128 |
129 | self.global_pool = nn.AdaptiveAvgPool1d(1)
130 |
131 | self.cnn_out = nn.Sequential(
132 | nn.Linear(2048, 2)
133 | )
134 |
135 | def forward(self, x_3d):
136 | """
137 |         Input is a clip of T frames, shape = (batch_size, t, c, h, w)
138 | """
139 | cnn_embedding_out = []
140 | for t in range(x_3d.size(1)):
141 |             # Extract per-frame features with the CNN
142 | x = self.cnn(x_3d[:, t, :, :, :])
143 | x = torch.flatten(x, start_dim=1)
144 | x = self.cnn_out(x)
145 | cnn_embedding_out.append(x)
146 | cnn_embedding_out = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1)
147 |
148 | x = self.global_pool(cnn_embedding_out)
149 | x = torch.flatten(x, start_dim=1)
150 |
151 | return x
152 |
153 |
154 | class cRNN(nn.Module):
155 | def __init__(self, use_gru=False, bi_branch=False, num_classes=2):
156 | super(cRNN, self).__init__()
157 |
158 | self.num_classes = num_classes
159 | self.use_gru = use_gru
160 |
161 |         # Use a pretrained ResNet as the feature extractor and drop its final classifier layer
162 | pretrained_cnn = Models.resnet50(pretrained=True)
163 | cnn_layers = list(pretrained_cnn.children())[:-2]
164 |
165 | # 把resnet的最后一层fc层去掉,用来提取特征
166 | self.cnn = nn.Sequential(*cnn_layers)
167 |
168 | cRNN_params = {
169 | 'input_dim': 2048,
170 | 'hidden_dim': [256, 256, 512],
171 | 'kernel_size': (1, 1),
172 | 'num_layers': 3,
173 | 'batch_first': True
174 | } if not use_gru else {
175 | 'input_size': (2, 2),
176 | 'input_dim': 2048,
177 | 'hidden_dim': [256, 256, 512],
178 | 'kernel_size': (1, 1),
179 | 'num_layers': 3,
180 | 'batch_first': True
181 | }
182 |
183 | self.cRNN = (ConvGRU if use_gru else ConvLSTM)(**cRNN_params)
184 |
185 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
186 |
187 | self.fc = nn.Sequential(
188 | nn.Linear(512, self.num_classes)
189 | )
190 |
191 | def forward(self, x_3d):
192 | cnn_embedding_out = []
193 | for t in range(x_3d.size(1)):
194 |             # Extract per-frame features with the CNN
195 | x = self.cnn(x_3d[:, t, :, :, :])
196 | cnn_embedding_out.append(x)
197 |
198 | x = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1)
199 |
200 | _, outputs = self.cRNN(x)
201 | x = outputs[0][0] if self.use_gru else outputs[0][1]
202 |
203 | x = self.global_pool(x)
204 | x = torch.flatten(x, 1)
205 | x = self.fc(x)
206 |
207 | return x
208 |
209 |
210 | def get_resnet_3d(num_classes=2, model_depth=10, shortcut_type='B', sample_size=112, sample_duration=16):
211 | assert model_depth in [10, 18, 34, 50, 101, 152, 200]
212 |
213 | if model_depth == 10:
214 | model = resnet.resnet10(
215 | num_classes=num_classes,
216 | shortcut_type=shortcut_type,
217 | sample_size=sample_size,
218 | sample_duration=sample_duration)
219 | elif model_depth == 18:
220 | model = resnet.resnet18(
221 | num_classes=num_classes,
222 | shortcut_type=shortcut_type,
223 | sample_size=sample_size,
224 | sample_duration=sample_duration)
225 | elif model_depth == 34:
226 | model = resnet.resnet34(
227 | num_classes=num_classes,
228 | shortcut_type=shortcut_type,
229 | sample_size=sample_size,
230 | sample_duration=sample_duration)
231 | elif model_depth == 50:
232 | model = resnet.resnet50(
233 | num_classes=num_classes,
234 | shortcut_type=shortcut_type,
235 | sample_size=sample_size,
236 | sample_duration=sample_duration)
237 | elif model_depth == 101:
238 | model = resnet.resnet101(
239 | num_classes=num_classes,
240 | shortcut_type=shortcut_type,
241 | sample_size=sample_size,
242 | sample_duration=sample_duration)
243 | elif model_depth == 152:
244 | model = resnet.resnet152(
245 | num_classes=num_classes,
246 | shortcut_type=shortcut_type,
247 | sample_size=sample_size,
248 | sample_duration=sample_duration)
249 | else:
250 | model = resnet.resnet200(
251 | num_classes=num_classes,
252 | shortcut_type=shortcut_type,
253 | sample_size=sample_size,
254 | sample_duration=sample_duration)
255 |
256 | return model
257 |
--------------------------------------------------------------------------------
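
A minimal sketch of the dual-level `Baseline` model (not a repository script; the clip shape is illustrative, and constructing the model downloads ImageNet ResNet-50 weights because `pretrained=True` is hard-coded). With `bi_branch=True` it returns one video-level logit per clip plus one logit per frame:

```python
import torch
from models.model import Baseline

model = Baseline(use_gru=True, bi_branch=True)   # downloads pretrained ResNet-50 weights
model.eval()

clip = torch.randn(2, 4, 3, 64, 64)              # (batch, frames, channels, height, width)
with torch.no_grad():
    video_logit, frame_logits = model(clip)

print(video_logit.shape)    # torch.Size([2, 1])  -> one logit per video
print(frame_logits.shape)   # torch.Size([8, 1])  -> one logit per frame (2 clips x 4 frames)
```
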
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | from functools import partial
7 |
8 | __all__ = [
9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnet200'
11 | ]
12 |
13 |
14 | def conv3x3x3(in_planes, out_planes, stride=1):
15 | # 3x3x3 convolution with padding
16 | return nn.Conv3d(
17 | in_planes,
18 | out_planes,
19 | kernel_size=3,
20 | stride=stride,
21 | padding=1,
22 | bias=False)
23 |
24 |
25 | def downsample_basic_block(x, planes, stride):
26 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
27 | zero_pads = torch.Tensor(
28 | out.size(0), planes - out.size(1), out.size(2), out.size(3),
29 | out.size(4)).zero_()
30 | if isinstance(out.data, torch.cuda.FloatTensor):
31 | zero_pads = zero_pads.cuda()
32 |
33 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
34 |
35 | return out
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion = 1
40 |
41 | def __init__(self, inplanes, planes, stride=1, downsample=None):
42 | super(BasicBlock, self).__init__()
43 | self.conv1 = conv3x3x3(inplanes, planes, stride)
44 | self.bn1 = nn.BatchNorm3d(planes)
45 | self.relu = nn.ReLU(inplace=True)
46 | self.conv2 = conv3x3x3(planes, planes)
47 | self.bn2 = nn.BatchNorm3d(planes)
48 | self.downsample = downsample
49 | self.stride = stride
50 |
51 | def forward(self, x):
52 | residual = x
53 |
54 | out = self.conv1(x)
55 | out = self.bn1(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class Bottleneck(nn.Module):
71 | expansion = 4
72 |
73 | def __init__(self, inplanes, planes, stride=1, downsample=None):
74 | super(Bottleneck, self).__init__()
75 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
76 | self.bn1 = nn.BatchNorm3d(planes)
77 | self.conv2 = nn.Conv3d(
78 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
79 | self.bn2 = nn.BatchNorm3d(planes)
80 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
81 | self.bn3 = nn.BatchNorm3d(planes * 4)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.downsample = downsample
84 | self.stride = stride
85 |
86 | def forward(self, x):
87 | residual = x
88 |
89 | out = self.conv1(x)
90 | out = self.bn1(out)
91 | out = self.relu(out)
92 |
93 | out = self.conv2(out)
94 | out = self.bn2(out)
95 | out = self.relu(out)
96 |
97 | out = self.conv3(out)
98 | out = self.bn3(out)
99 |
100 | if self.downsample is not None:
101 | residual = self.downsample(x)
102 |
103 | out += residual
104 | out = self.relu(out)
105 |
106 | return out
107 |
108 |
109 | class ResNet(nn.Module):
110 |
111 | def __init__(self,
112 | block,
113 | layers,
114 | sample_size,
115 | sample_duration,
116 | shortcut_type='B',
117 | num_classes=400):
118 | self.inplanes = 64
119 | super(ResNet, self).__init__()
120 | self.conv1 = nn.Conv3d(
121 | 3,
122 | 64,
123 | kernel_size=7,
124 | stride=(1, 2, 2),
125 | padding=(3, 3, 3),
126 | bias=False)
127 | self.bn1 = nn.BatchNorm3d(64)
128 | self.relu = nn.ReLU(inplace=True)
129 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
130 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
131 | self.layer2 = self._make_layer(
132 | block, 128, layers[1], shortcut_type, stride=2)
133 | self.layer3 = self._make_layer(
134 | block, 256, layers[2], shortcut_type, stride=2)
135 | self.layer4 = self._make_layer(
136 | block, 512, layers[3], shortcut_type, stride=2)
137 | last_duration = int(math.ceil(sample_duration / 16))
138 | last_size = int(math.ceil(sample_size / 32))
139 | self.avgpool = nn.AvgPool3d(
140 | (last_duration, last_size, last_size), stride=1)
141 | self.fc = nn.Linear(512 * block.expansion, num_classes)
142 |
143 | for m in self.modules():
144 | if isinstance(m, nn.Conv3d):
145 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
146 | elif isinstance(m, nn.BatchNorm3d):
147 | m.weight.data.fill_(1)
148 | m.bias.data.zero_()
149 |
150 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
151 | downsample = None
152 | if stride != 1 or self.inplanes != planes * block.expansion:
153 | if shortcut_type == 'A':
154 | downsample = partial(
155 | downsample_basic_block,
156 | planes=planes * block.expansion,
157 | stride=stride)
158 | else:
159 | downsample = nn.Sequential(
160 | nn.Conv3d(
161 | self.inplanes,
162 | planes * block.expansion,
163 | kernel_size=1,
164 | stride=stride,
165 | bias=False), nn.BatchNorm3d(planes * block.expansion))
166 |
167 | layers = []
168 | layers.append(block(self.inplanes, planes, stride, downsample))
169 | self.inplanes = planes * block.expansion
170 | for i in range(1, blocks):
171 | layers.append(block(self.inplanes, planes))
172 |
173 | return nn.Sequential(*layers)
174 |
175 | def forward(self, x):
176 | x = self.conv1(x)
177 | x = self.bn1(x)
178 | x = self.relu(x)
179 | x = self.maxpool(x)
180 |
181 | x = self.layer1(x)
182 | x = self.layer2(x)
183 | x = self.layer3(x)
184 | x = self.layer4(x)
185 |
186 | x = self.avgpool(x)
187 |
188 | x = x.view(x.size(0), -1)
189 | x = self.fc(x)
190 |
191 | return x
192 |
193 |
194 | def get_fine_tuning_parameters(model, ft_begin_index):
195 | if ft_begin_index == 0:
196 | return model.parameters()
197 |
198 | ft_module_names = []
199 | for i in range(ft_begin_index, 5):
200 | ft_module_names.append('layer{}'.format(i))
201 | ft_module_names.append('fc')
202 |
203 | parameters = []
204 | for k, v in model.named_parameters():
205 | for ft_module in ft_module_names:
206 | if ft_module in k:
207 | parameters.append({'params': v})
208 | break
209 | else:
210 | parameters.append({'params': v, 'lr': 0.0})
211 |
212 | return parameters
213 |
214 |
215 | def resnet10(**kwargs):
216 |     """Constructs a ResNet-10 model.
217 | """
218 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
219 | return model
220 |
221 |
222 | def resnet18(**kwargs):
223 | """Constructs a ResNet-18 model.
224 | """
225 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
226 | return model
227 |
228 |
229 | def resnet34(**kwargs):
230 | """Constructs a ResNet-34 model.
231 | """
232 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
233 | return model
234 |
235 |
236 | def resnet50(**kwargs):
237 | """Constructs a ResNet-50 model.
238 | """
239 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
240 | return model
241 |
242 |
243 | def resnet101(**kwargs):
244 | """Constructs a ResNet-101 model.
245 | """
246 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
247 | return model
248 |
249 |
250 | def resnet152(**kwargs):
251 |     """Constructs a ResNet-152 model.
252 | """
253 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
254 | return model
255 |
256 |
257 | def resnet200(**kwargs):
258 |     """Constructs a ResNet-200 model.
259 | """
260 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
261 | return model
262 |
--------------------------------------------------------------------------------
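
A hedged sketch of the 3D ResNet variants; the clip shape follows the `sample_duration=16`, `sample_size=112` defaults that `get_resnet_3d` in `models/model.py` passes through:

```python
import torch
from models.resnet import resnet10

model = resnet10(sample_size=112, sample_duration=16, num_classes=2)
model.eval()

clip = torch.randn(1, 3, 16, 112, 112)   # (batch, channels, frames, height, width)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)                      # torch.Size([1, 2])
```
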
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/utils/__init__.py
--------------------------------------------------------------------------------
/utils/auccur.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.data import DataLoader
4 | from tqdm import tqdm
5 | import matplotlib.pyplot as plt
6 | import pandas
7 | import os
8 | import argparse
9 |
10 | from models.model import cRNN, get_resnet_3d, CNN, Baseline
11 | from dataloader import FrameDataset, Dataset
12 | import config
13 | from sklearn.metrics import roc_curve, auc
14 | from sklearn.metrics import accuracy_score, roc_auc_score
15 |
16 |
17 | def test(model: nn.Sequential, test_loader: torch.utils.data.DataLoader, model_type, device):
18 | model.eval()
19 |
20 | print('Size of Test Set: ', len(test_loader.dataset))
21 |
22 |     # Evaluate model performance on the test set
23 | test_loss = 0
24 | y_gd = []
25 | frame_y_gd = []
26 | y_pred = []
27 | frame_y_pred = []
28 |
29 | with torch.no_grad():
30 | if config.net_params.get('our'):
31 | for X, y in tqdm(test_loader, desc='Validating plus frame level'):
32 | X, y = X.to(device), y.to(device)
33 | frame_y = y.view(-1, 1)
34 | frame_y = frame_y.repeat(1, 300)
35 | frame_y = frame_y.flatten()
36 | y_, cnn_y = model(X)
37 |
38 | y_ = y_.argmax(dim=1)
39 | frame_y_ = cnn_y.argmax(dim=1)
40 |
41 | y_gd += y.cpu().numpy().tolist()
42 | y_pred += y_.cpu().numpy().tolist()
43 | frame_y_gd += frame_y.cpu().numpy().tolist()
44 | frame_y_pred += frame_y_.cpu().numpy().tolist()
45 |
46 | test_video_acc = accuracy_score(y_gd, y_pred)
47 | test_video_auc = roc_auc_score(y_gd, y_pred)
48 | test_frame_acc = accuracy_score(frame_y_gd, frame_y_pred)
49 | test_frame_auc = roc_auc_score(frame_y_gd, frame_y_pred)
50 | print('Test video avg loss: %0.4f, acc: %0.2f, auc: %0.2f\n' % (
51 | test_loss, test_video_acc, test_video_auc))
52 | print('Test frame avg loss: %0.4f, acc: %0.2f, auc: %0.2f\n' % (
53 | test_loss, test_frame_acc, test_frame_auc))
54 |
55 |
56 | else:
57 | for X, y in tqdm(test_loader, desc='Validating plus frame level'):
58 | X, y = X.to(device), y.to(device)
59 | cnn_y = model(X)
60 | frame_y_ = cnn_y.argmax(dim=1)
61 | frame_y_gd += y.cpu().numpy().tolist()
62 | frame_y_pred += frame_y_.cpu().numpy().tolist()
63 | test_frame_acc = accuracy_score(frame_y_gd, frame_y_pred)
64 | test_frame_auc = roc_auc_score(frame_y_gd, frame_y_pred)
65 | print('Test frame avg loss: %0.4f, acc: %0.2f, auc: %0.2f\n' % (test_loss, test_frame_acc, test_frame_auc))
66 |
67 | return frame_y_gd, frame_y_pred
68 |
69 |
70 | def parse_args():
71 | parser = argparse.ArgumentParser(usage='python3 main.py -i path/to/data -r path/to/checkpoint')
72 | parser.add_argument('-i', '--data_path', help='path to your datasets', default='/data2/guesthome/wenbop/ffdf_c40')
73 | # parser.add_argument('-i', '--data_path', help='path to your datasets', default='/Users/pu/Desktop/dataset_dlib')
74 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint',
75 | default='/data2/guesthome/wenbop/modules/ff/bi-model_type-baseline_gru_ep-19.pth')
76 | # parser.add_argument('-g', '--gpu', help='visible gpu ids', default='4,5,7')
77 | parser.add_argument('-g', '--gpu', help='visible gpu ids', default='0,1,2,3')
78 | args = parser.parse_args()
79 | return args
80 |
81 |
82 | def draw_auc():
83 | fpr = dict()
84 | tpr = dict()
85 | roc_auc = dict()
86 |
87 | args = parse_args()
88 | data_path = args.data_path
89 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
90 | raw_data = pandas.read_csv(os.path.join(data_path, '%s.csv' % 'test'))
91 | dataloader = DataLoader(Dataset(raw_data.to_numpy()), **config.dataset_params)
92 | use_cuda = torch.cuda.is_available()
93 | device = torch.device('cuda' if use_cuda else 'cpu')
94 | model = Baseline()
95 | device_count = torch.cuda.device_count()
96 | if device_count > 1:
97 |         print('Using {} GPUs'.format(device_count))
98 | model = nn.DataParallel(model)
99 | model.to(device)
100 | ckpt = {}
101 |     # resume from a checkpoint
102 | if args.restore_from is not None:
103 | ckpt = torch.load(args.restore_from)
104 | # model.load_state_dict(ckpt['net'])
105 | model.load_state_dict(ckpt['model_state_dict'])
106 | print('Model is loaded from %s' % (args.restore_from))
107 |
108 | y_test, y_score = test(model, dataloader, 'baseline', device)
109 |
110 |     # test() returns flat lists of labels and predictions, so compute a single ROC curve over them
111 |     fpr[0], tpr[0], _ = roc_curve(y_test, y_score)
112 |     roc_auc[0] = auc(fpr[0], tpr[0])
113 | plt.figure()
114 | lw = 2
115 | plt.plot(fpr[0], tpr[0], color='darkorange',
116 | lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[0])
117 | plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
118 | plt.xlim([0.0, 1.0])
119 | plt.ylim([0.0, 1.05])
120 | plt.xlabel('False Positive Rate')
121 | plt.ylabel('True Positive Rate')
122 | plt.title('Receiver operating characteristic example')
123 | plt.legend(loc="lower right")
124 | plt.show()
125 |
126 |
127 | draw_auc()
128 |
--------------------------------------------------------------------------------
/utils/aucloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class AUCLoss(torch.nn.Module):
6 | def __init__(self, device, gamma=0.15, alpha=0.6, p=2):
7 | super().__init__()
8 | self.gamma = gamma
9 | self.alpha = alpha
10 | self.p = p
11 | self.device = device
12 |
13 | def forward(self, y_pred, y_true):
14 | pred = torch.sigmoid(y_pred)
15 | pos = pred[torch.where(y_true == 0)]
16 | neg = pred[torch.where(y_true == 1)]
17 | pos = torch.unsqueeze(pos, 0)
18 | neg = torch.unsqueeze(neg, 1)
19 | diff = torch.zeros_like(pos * neg, device=self.device) + pos - neg - self.gamma
20 | masked = diff[torch.where(diff < 0.0)]
21 | auc = torch.mean(torch.pow(-masked, self.p))
22 | bce = F.binary_cross_entropy_with_logits(y_pred, y_true)
23 | if masked.shape[0] == 0:
24 | loss = bce
25 | else:
26 | loss = self.alpha * bce + (1 - self.alpha) * auc
27 | return loss
28 |
--------------------------------------------------------------------------------
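
A minimal usage sketch of the AUC loss above (not part of the repository). The tensor shapes and the 0 = fake / 1 = real label convention are assumptions inferred from `utils/tools.py`; `gamma=0.15` matches `utils/config.py`.

```python
# Sketch only: AUCLoss expects raw logits and float 0/1 targets of the same shape.
# It mixes BCE-with-logits and the pairwise WMW-style surrogate through alpha.
import torch
from utils.aucloss import AUCLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = AUCLoss(device=device, gamma=0.15, alpha=0.6, p=2)

logits = torch.randn(8, device=device, requires_grad=True)   # model outputs before sigmoid
labels = torch.randint(0, 2, (8,), device=device).float()    # 0 = fake, 1 = real (assumed convention)

loss = criterion(logits, labels)   # alpha * BCE + (1 - alpha) * AUC surrogate
loss.backward()
```
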
/utils/cam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | from torch.autograd import Function
4 | from torchvision import models
5 | from torchvision import utils
6 | import cv2
7 | import sys
8 | from collections import OrderedDict
9 | import numpy as np
10 | import argparse
11 | import os
12 | import torch.nn as nn
13 | from models.model import Baseline
14 | import config
15 |
16 |
17 | i = 0  # index of the image currently being processed
18 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
19 | use_cuda = torch.cuda.is_available()
20 | device = torch.device('cuda' if use_cuda else 'cpu')
21 | resnet = Baseline(**config.net_params)
22 | device_count = torch.cuda.device_count()
23 | print('Using 4 GPUs')
24 | model = nn.DataParallel(resnet)
25 | resnet.to(device)
26 | ckpt = torch.load('/data2/guesthome/wenbop/modules/ff/bi-model_type-baseline_gru_ep-19.pth')
27 | #model.load_state_dict(ckpt['net'])
28 | resnet.load_state_dict(ckpt['model_state_dict'])
29 |
30 | # resnet = models.resnet50(pretrained=True)  # alternatively, load a separate resnet50 that still keeps its fully connected layer
31 | image = []
32 | class FeatureExtractor():
33 | """ Class for extracting activations and
34 | registering gradients from targetted intermediate layers """
35 | def __init__(self, model, target_layers):
36 | self.model = model
37 | self.target_layers = target_layers
38 | self.gradients = []
39 |
40 | def save_gradient(self, grad):
41 | self.gradients.append(grad)
42 |
43 | def __call__(self, x):
44 | outputs = []
45 | self.gradients = []
46 |         for name, module in self.model._modules.items():  # resnet50 has no .features attribute, so iterate over the top-level modules directly
47 | x = module(x)
48 | #print('name=',name)
49 | #print('x.size()=',x.size())
50 | if name in self.target_layers:
51 | x.register_hook(self.save_gradient)
52 | outputs += [x]
53 | #print('outputs.size()=',x.size())
54 | #print('len(outputs)',len(outputs))
55 | return outputs, x
56 |
57 | class ModelOutputs():
58 | """ Class for making a forward pass, and getting:
59 | 1. The network output.
60 | 2. Activations from intermeddiate targetted layers.
61 | 3. Gradients from intermeddiate targetted layers. """
62 | def __init__(self, model, target_layers,use_cuda):
63 | self.model = model
64 | self.feature_extractor = FeatureExtractor(self.model, target_layers)
65 | self.cuda = use_cuda
66 | def get_gradients(self):
67 | return self.feature_extractor.gradients
68 |
69 | def __call__(self, x):
70 | target_activations, output = self.feature_extractor(x)
71 | output = output.view(output.size(0), -1)
72 | #print('classfier=',output.size())
73 | if self.cuda:
74 | output = output.cpu()
75 | cnn = []
76 | cnn.append(output)
77 | cnn = torch.stack(cnn, dim=0).transpose(0, 1)
78 | rnn_out, _ = resnet.rnn(cnn)
79 |             output = resnet.fc_cnn(rnn_out[:,0,:]).cuda()  # this is why the full resnet model is loaded as well: the model passed in has no fc layer, but it is needed here
80 | else:
81 | cnn = []
82 | cnn.append(output)
83 | cnn = torch.stack(cnn, dim=0).transpose(0, 1)
84 | rnn_out, _ = resnet.rnn(cnn)
85 |             output = resnet.fc_cnn(rnn_out[:,0,:])  # CPU branch of the use_cuda fix above; without it the tensor types mismatch, so this runs on both CPU and GPU
86 | return target_activations, output
87 |
88 | def preprocess_image(img):
89 | means=[0.485, 0.456, 0.406]
90 | stds=[0.229, 0.224, 0.225]
91 |
92 | preprocessed_img = img.copy()[: , :, ::-1]
93 | for i in range(3):
94 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i]
95 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i]
96 | preprocessed_img = \
97 | np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1)))
98 | preprocessed_img = torch.from_numpy(preprocessed_img)
99 | preprocessed_img.unsqueeze_(0)
100 | input = preprocessed_img
101 | input.requires_grad = True
102 | return input
103 |
104 | def show_cam_on_image(img, mask,name):
105 | heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
106 | heatmap = np.float32(heatmap) / 255
107 | cam = heatmap + np.float32(img)
108 | cam = cam / np.max(cam)
109 | cv2.imwrite("cam/cam_{}.jpg".format(name), np.uint8(255 * cam))
110 | class GradCam:
111 | def __init__(self, model, target_layer_names, use_cuda):
112 | self.model = model
113 | self.model.eval()
114 | self.cuda = use_cuda
115 | if self.cuda:
116 | self.model = model.cuda()
117 |
118 | self.extractor = ModelOutputs(self.model, target_layer_names, use_cuda)
119 |
120 | def forward(self, input):
121 | return self.model(input)
122 |
123 | def __call__(self, input, index = None):
124 | if self.cuda:
125 | features, output = self.extractor(input.cuda())
126 | else:
127 | features, output = self.extractor(input)
128 |
129 | if index == None:
130 | index = np.argmax(output.cpu().data.numpy())
131 |
132 | one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
133 | one_hot[0][index] = 1
134 | one_hot = torch.Tensor(torch.from_numpy(one_hot))
135 | one_hot.requires_grad = True
136 | if self.cuda:
137 | one_hot = torch.sum(one_hot.cuda() * output)
138 | else:
139 | one_hot = torch.sum(one_hot * output)
140 |
141 |         self.model.zero_grad()  # this model has no .features / .classifier attributes (adding those calls back raises an error), so zero the gradients of the whole model
142 | #self.model.zero_grad()
143 |         one_hot.backward(retain_graph=True)  # works on torch 0.4+ (retain_variables was renamed to retain_graph); tested fine on 1.0 as well
144 |
145 | grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
146 | #print('grads_val',grads_val.shape)
147 | target = features[-1]
148 | target = target.cpu().data.numpy()[0, :]
149 |
150 | weights = np.mean(grads_val, axis = (2, 3))[0, :]
151 | #print('weights',weights.shape)
152 | cam = np.zeros(target.shape[1 : ], dtype = np.float32)
153 | #print('cam',cam.shape)
154 | #print('features',features[-1].shape)
155 | #print('target',target.shape)
156 | for i, w in enumerate(weights):
157 | cam += w * target[i, :, :]
158 |
159 | cam = np.maximum(cam, 0)
160 | cam = cv2.resize(cam, (224, 224))
161 | cam = cam - np.min(cam)
162 | cam = cam / np.max(cam)
163 | return cam
164 | class GuidedBackpropReLUModel:
165 | def __init__(self, model, use_cuda):
166 |         self.model = model  # as above, a complete network is required here, otherwise the final dimensions will not match
167 | self.model.eval()
168 | self.cuda = use_cuda
169 | if self.cuda:
170 | self.model = model.cuda()
171 | for module in self.model.named_modules():
172 | module[1].register_backward_hook(self.bp_relu)
173 |
174 | def bp_relu(self, module, grad_in, grad_out):
175 | if isinstance(module, nn.ReLU):
176 | return (torch.clamp(grad_in[0], min=0.0),)
177 | def forward(self, input):
178 | return self.model(input)
179 |
180 | def __call__(self, input, index = None):
181 | if self.cuda:
182 | output = self.forward(input.cuda())
183 | else:
184 | output = self.forward(input)
185 | if index == None:
186 | index = np.argmax(output.cpu().data.numpy())
187 | #print(input.grad)
188 | one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
189 | one_hot[0][index] = 1
190 | one_hot = torch.from_numpy(one_hot)
191 | one_hot.requires_grad = True
192 | if self.cuda:
193 | one_hot = torch.sum(one_hot.cuda() * output)
194 | else:
195 | one_hot = torch.sum(one_hot * output)
196 | #self.model.classifier.zero_grad()
197 | one_hot.backward(retain_graph=True)
198 | output = input.grad.cpu().data.numpy()
199 | output = output[0,:,:,:]
200 |
201 | return output
202 |
203 | def get_args():
204 | parser = argparse.ArgumentParser()
205 | parser.add_argument('--use-cuda', action='store_true', default=False,
206 | help='Use NVIDIA GPU acceleration')
207 | parser.add_argument('--image-path', type=str, default='/data2/guesthome/wenbop/ffdf/test/0/',
208 | help='Input image path')
209 | args = parser.parse_args()
210 | args.use_cuda = args.use_cuda and torch.cuda.is_available()
211 | if args.use_cuda:
212 | print("Using GPU for acceleration")
213 | else:
214 | print("Using CPU for computation")
215 |
216 | return args
217 |
218 | if __name__ == '__main__':
219 | """ python grad_cam.py
220 | 1. Loads an image with opencv.
221 | 2. Preprocesses it for VGG19 and converts to a pytorch variable.
222 | 3. Makes a forward pass to find the category index with the highest score,
223 | and computes intermediate activations.
224 | Makes the visualization. """
225 |
226 | args = get_args()
227 |
228 | model = resnet.cnn
229 | grad_cam = GradCam(model , \
230 |                     target_layer_names = ["layer4"], use_cuda=args.use_cuda)  # targeting "layer4"; each layer's name and size can be printed (see the commented prints above), so swap in whichever layer you want to inspect
231 | x=os.walk(args.image_path)
232 | for root, dirs, filename in x:
233 | #print(type(grad_cam))
234 | print(filename)
235 | for s in filename:
236 | image.append(cv2.imread(args.image_path+s,1))
237 | #img = cv2.imread(filename, 1)
238 | for img in image:
239 | img = np.float32(cv2.resize(img, (224, 224))) / 255
240 | input = preprocess_image(img)
241 | input.required_grad = True
242 | print('input.size()=',input.size())
243 | # If None, returns the map for the highest scoring category.
244 | # Otherwise, targets the requested index.
245 | target_index =None
246 |
247 | mask = grad_cam(input, target_index)
248 | i=i+1
249 | show_cam_on_image(img, mask,i)
250 |
251 | gb_model = GuidedBackpropReLUModel(model = resnet, use_cuda=args.use_cuda)
252 | gb = gb_model(input, index=target_index)
253 | if not os.path.exists('gb'):
254 | os.mkdir('gb')
255 | if not os.path.exists('camgb'):
256 | os.mkdir('camgb')
257 | utils.save_image(torch.from_numpy(gb), 'gb/gb_{}.jpg'.format(i))
258 | cam_mask = np.zeros(gb.shape)
259 | for j in range(0, gb.shape[0]):
260 | cam_mask[j, :, :] = mask
261 | cam_gb = np.multiply(cam_mask, gb)
262 | utils.save_image(torch.from_numpy(cam_gb), 'camgb/cam_gb_{}.jpg'.format(i))
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | dataset_params = {
2 | 'shuffle': True,
3 | 'num_workers': 4,
4 | 'pin_memory': False
5 | }
6 | img_h = 224  # input image size used by resnet_3d_params below (224 assumed, matching the 224x224 crops used elsewhere)
7 | net_params = {
8 | 'use_gru': True,
9 | 'bi_branch': True,
10 | 'dct': False,
11 | 'inputgate': False
12 | }
13 |
14 | resnet_3d_params = {
15 | 'num_classes': 2,
16 | 'model_depth': 50,
17 | 'shortcut_type': 'B',
18 | 'sample_size': img_h,
19 | 'sample_duration': 30
20 | }
21 |
22 | models = {
23 | 1: 'baseline',
24 | 2: 'cRNN',
25 | 3: 'end2end',
26 | 4: 'xception',
27 | 5: 'fwa',
28 | 6: 'cnn',
29 | 7: 'res50',
30 | 8: 'res101',
31 | 9: 'res152'
32 | }
33 |
34 | losses = {
35 | 0: 'CE',
36 | 1: 'AUC',
37 | 2: 'focal'
38 | }
39 |
40 |
41 | gamma = 0.15
42 |
43 | model_type = models.get(1)
44 | loss_type = losses.get(1)
45 | learning_rate = 1e-4
46 | epoches = 20
47 | log_interval = 2  # logging interval: print every 2 batches by default
48 | save_interval = 1  # checkpoint interval: save once per epoch by default
49 |
--------------------------------------------------------------------------------
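
These dictionaries are consumed as keyword arguments by the scripts in this folder, e.g. `Baseline(**config.net_params)` in `cam.py`/`gradcam.py` and `DataLoader(..., **config.dataset_params)` in `auccur.py`. A small sketch of that pattern; the CSV file name is an assumption (those scripts build it from their `--data_path` argument), and the import layout is taken from the scripts themselves.

```python
# Sketch only: mirrors how auccur.py and gradcam.py consume utils/config.py
# (same import layout as those scripts is assumed).
import pandas
from torch.utils.data import DataLoader

import config
from dataloader import Dataset
from models.model import Baseline

model = Baseline(**config.net_params)          # use_gru / bi_branch / dct / inputgate flags
raw_data = pandas.read_csv('test.csv')         # index file for the test split (assumed name)
loader = DataLoader(Dataset(raw_data.to_numpy()), **config.dataset_params)
```
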
/utils/drawpics.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from sklearn.metrics import auc
3 | from matplotlib.pyplot import MultipleLocator
4 |
5 | from utils.dataloader import *
6 |
7 |
8 | def draw_compare():
9 | plt.figure()
10 | # plt.title('', fontsize=20)
11 | plt.xlabel('positive to negative')
12 | plt.ylabel('Frame level AUC')
13 |
14 | plt.plot(['1:10', '1:20', '1:30'], [0.72, 0.75, 0.57], label='Meso4', marker='o')
15 | plt.plot(['1:10', '1:20', '1:30'], [0.60, 0.64, 0.50], label='Xception', marker='s')
16 | plt.plot(['1:10', '1:20', '1:30'], [0.82, 0.79, 0.57], label='DSP-FWA', marker='^')
17 | plt.plot(['1:10', '1:20', '1:30'], [0.63, 0.67, 0.56], label='Capsule', marker='*')
18 | plt.plot(['1:10', '1:20', '1:30'], [0.91, 0.92, 0.78], label='Ours', marker='D')
19 |
20 | plt.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left",
21 | mode="expand", borderaxespad=0, ncol=5)
22 | plt.savefig('compare.pdf')
23 |
24 |
25 | def draw_AUC():
26 | f_fpr = np.load('/home/asus/Code/pvc/m/bs/f_fpr.npy')
27 | f_tpr = np.load('/home/asus/Code/pvc/m/bs/f_tpr.npy')
28 | f_roc_auc = auc(f_fpr, f_tpr)
29 | xcp_f_fpr = np.load('/home/asus/Code/pvc/m/xcp/f_fpr.npy')
30 | xcp_f_tpr = np.load('/home/asus/Code/pvc/m/xcp/f_tpr.npy')
31 | xcp_roc_auc = auc(xcp_f_fpr, xcp_f_tpr)
32 | cap_f_fpr = np.load('/home/asus/Code/pvc/m/cap/f_fpr.npy')
33 | cap_f_tpr = np.load('/home/asus/Code/pvc/m/cap/f_tpr.npy')
34 | cap_roc_auc = auc(cap_f_fpr, cap_f_tpr)
35 | ms4_f_fpr = np.load('/home/asus/Code/pvc/m/ms4/f_fpr.npy')
36 | ms4_f_tpr = np.load('/home/asus/Code/pvc/m/ms4/f_tpr.npy')
37 | ms4_roc_auc = auc(ms4_f_fpr, ms4_f_tpr)
38 | msi_f_fpr = np.load('/home/asus/Code/pvc/m/msi/f_fpr.npy')
39 | msi_f_tpr = np.load('/home/asus/Code/pvc/m/msi/f_tpr.npy')
40 | msi_roc_auc = auc(msi_f_fpr, msi_f_tpr)
41 | plt.figure()
42 | lw = 2
43 | plt.plot(f_fpr, f_tpr,
44 | lw=lw, label='Ours ROC curve (area = %0.2f)' % f_roc_auc)
45 | plt.plot(xcp_f_fpr, xcp_f_tpr,
46 | lw=lw, label='Xception ROC curve (area = %0.2f)' % xcp_roc_auc)
47 | plt.plot(cap_f_fpr, cap_f_tpr,
48 | lw=lw, label='Capsule ROC curve (area = %0.2f)' % cap_roc_auc)
49 | plt.plot(ms4_f_fpr, ms4_f_tpr,
50 | lw=lw, label='Meso4 ROC curve (area = %0.2f)' % ms4_roc_auc)
51 | plt.plot(msi_f_fpr, msi_f_tpr,
52 | lw=lw, label='MesoInception4 ROC curve (area = %0.2f)' % msi_roc_auc)
53 | plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
54 | plt.xlim([0.0, 1.0])
55 | plt.ylim([0.0, 1.05])
56 | plt.xlabel('False Positive Rate')
57 | plt.ylabel('True Positive Rate')
58 | plt.legend(loc="lower right")
59 | plt.savefig('df_frame.pdf')
60 |
61 |
62 | def draw_WMW():
63 | # 0.0, 0.2, 0.4, 0.6, 0.8, 1.0
64 | WMW_auc_frame = [0.9032, 0.8772, 0.9383, 0.9293, 0.826, 0.8079]
65 | WMW_acc_frame = [0.957, 0.957, 0.957, 0.957, 0.957, 0.957]
66 | WMW_f1_frame = [0.978, 0.978, 0.978, 0.978, 0.978, 0.978]
67 | WMW_recall_frame = [1, 1, 1, 1, 1, 1]
68 | WMW_auc_video = [0.908, 0.894, 0.9473, 0.9544, 0.7472, 0.8181]
69 | WMW_acc_video = [0.957, 0.957, 0.957, 0.957, 0.957, 0.957]
70 | WMW_f1_video = [0.978, 0.978, 0.978, 0.978, 0.978, 0.978]
71 | WMW_recall_video = [1, 1, 1, 1, 1, 1]
72 |
73 | x = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
74 |
75 | plt.figure()
76 | plt.plot(np.array(x), np.array(WMW_auc_frame), label='AUC')
77 | # plt.plot(np.array(x), np.array(WMW_auc_video), label='Video level AUC')
78 | plt.plot(np.array(x), np.array(WMW_acc_frame), label='ACC')
79 | plt.plot(np.array(x), np.array(WMW_f1_frame), label='F1')
80 | plt.plot(np.array(x), np.array(WMW_recall_frame), label='recall')
81 | plt.xlim(0.0, 1)
82 | plt.legend(loc="lower center")
83 | plt.xlabel('Margin parameter (gamma) of WMW Loss')
84 | plt.ylabel('Metrics score')
85 | plt.savefig('frame_score.pdf')
86 | plt.show()
87 |
88 | plt.figure()
89 | plt.plot(np.array(x), np.array(WMW_auc_video), label='AUC')
90 | # plt.plot(np.array(x), np.array(WMW_auc_video), label='Video level AUC')
91 | plt.plot(np.array(x), np.array(WMW_acc_video), label='ACC')
92 | plt.plot(np.array(x), np.array(WMW_f1_video), label='F1')
93 | plt.plot(np.array(x), np.array(WMW_recall_video), label='recall')
94 | plt.xlim(0.0, 1)
95 | plt.legend(loc="lower center")
96 | plt.xlabel('Margin parameter (gamma) of WMW Loss')
97 | plt.ylabel('Metrics score')
98 | plt.savefig('video_score.pdf')
99 | plt.show()
100 |
101 |
102 | def draw_auc_compare():
103 | name = ['Celeb-30', 'Celeb-20', 'Celeb-10']
104 | our_list = [0.74, 0.95, 0.94]
105 | w_focal = [0.72, 0.95, 0.91]
106 | wo_auc = [0.70, 0.90, 0.87]
107 |
108 | x = np.arange(len(name))
109 | width = 0.25
110 |
111 | plt.bar(x, our_list, width=width, label='Ours')
112 | plt.bar(x + width, w_focal, width=width, label='Ours with FL', tick_label=name)
113 | plt.bar(x + 2 * width, wo_auc, width=width, label='Ours with BCE')
114 |
115 |     # x_major_locator = MultipleLocator(1)
116 |     # set the x-axis tick interval to 1 and keep it in a variable
117 |     y_major_locator = MultipleLocator(0.1)
118 |     ax = plt.gca()
119 |     # ax is the instance holding both axes
120 |     # ax.xaxis.set_major_locator(x_major_locator)
121 |     # make the x-axis major ticks multiples of 1
122 |     ax.yaxis.set_major_locator(y_major_locator)
123 |     # values optionally annotated on the bars (disabled below)
124 | # for a, b in zip(x, our_list):
125 | # plt.text(a, b + 0.1, b, ha='center', va='bottom')
126 | # for a, b in zip(x, w_focal):
127 | # plt.text(a + width, b + 0.1, b, ha='center', va='bottom')
128 | # for a, b in zip(x, wo_auc):
129 | # plt.text(a + 2 * width, b + 0.1, b, ha='center', va='bottom')
130 |
131 | plt.xticks()
132 | plt.ylim([0.5, 1.0])
133 |     plt.legend(loc="upper left")  # keep the legend from overlapping the bars
134 |     # plt.rcParams['font.sans-serif'] = ['SimHei']  # only needed to render Chinese labels
135 | plt.ylabel('AUC score')
136 | # plt.xlabel('line')
137 |     # plt.rcParams['savefig.dpi'] = 300  # saved image pixels
138 |     # plt.rcParams['figure.dpi'] = 300  # display resolution
139 |     # plt.rcParams['figure.figsize'] = (15.0, 8.0)  # figure size
140 | # plt.title("title")
141 | plt.savefig('w_wo_auc.pdf')
142 | plt.show()
143 |
144 | # x = list(range(len(our_list)))
145 | # total_width, n = 0.8, 3
146 | # width = total_width / n
147 | #
148 | # plt.bar(x, our_list, width=width, label='Our')
149 | # for i in range(len(x)):
150 | # x[i] = x[i] + width
151 | # plt.bar(x, w_focal, width=width, label='Our w Focal loss')
152 | # plt.bar(x, wo_auc, width=width, label='Our w/o AUC loss')
153 | # plt.xticks(np.array(x) - width / 3, name_list)
154 | # plt.legend()
155 | # plt.savefig('w_wo_auc.pdf')
156 | # plt.show()
157 |
158 | # x = 3
159 |     # total_width, n = 0.8, 3  # change n to match the number of bar groups
160 | # width = total_width / n
161 | # x = x - (total_width - width) / 2
162 | #
163 | # plt.bar(x, our_list, width=width, label='Ours')
164 | # plt.bar(x + width, w_focal, width=width, label='Ours with Focal loss')
165 | # plt.bar(x + 2 * width, wo_auc, width=width, label='Ours with BCE ')
166 | #
167 | # plt.xticks()
168 |     # plt.legend(loc="upper left")  # keep the legend from overlapping the bars
169 |     # # plt.rcParams['font.sans-serif'] = ['SimHei']  # only needed to render Chinese labels
170 |     # plt.ylabel('AUC score')
171 |     # # plt.xlabel('line')
172 |     # # plt.rcParams['savefig.dpi'] = 300  # saved image pixels
173 |     # # plt.rcParams['figure.dpi'] = 300  # display resolution
174 |     # # plt.rcParams['figure.figsize'] = (15.0, 8.0)  # figure size
175 | # # plt.title("title")
176 | # plt.savefig('w_wo_auc.pdf')
177 | # plt.show()
178 |
179 |
180 | # draw_WMW()
181 | # draw_auc_compare()
182 | # draw_AUC()
183 | draw_compare()
184 |
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from sklearn.metrics import accuracy_score
4 | from tqdm import tqdm
5 | from PIL import Image
6 | import pandas
7 | import os
8 | import argparse
9 | import cv2
10 |
11 | from dataloader import Dataset
12 | from models.model import CNNEncoder, RNNDecoder
13 | import config
14 |
15 | def load_imgs_from_video(path: str)->list:
16 | """Extract images from video.
17 |
18 | Args:
19 | path(str): The path of video.
20 |
21 | Returns:
22 | A list of PIL Image.
23 | """
24 | video_fd = cv2.VideoCapture(path)
25 | video_fd.set(16, True)
26 | # flag 16: 'CV_CAP_PROP_CONVERT_RGB'
27 | # indicating the images should be converted to RGB.
28 |
29 | if not video_fd.isOpened():
30 | raise ValueError('Invalid path! which is: {}'.format(path))
31 |
32 | images = [] # type: list[Image]
33 |
34 | success, frame = video_fd.read()
35 | while success:
36 | images.append(Image.fromarray(frame))
37 | success, frame = video_fd.read()
38 |
39 | return images
40 |
41 | def _eval(checkpoint: str, video_path: str, labels=[])->list:
42 | """Inference the model and return the labels.
43 |
44 | Args:
45 | checkpoint(str): The checkpoint where the model restore from.
46 |         video_path(str): The path of videos.
47 | labels(list): Labels of videos.
48 |
49 | Returns:
50 | A list of labels of the videos.
51 | """
52 | if not os.path.exists(video_path):
53 | raise ValueError('Invalid path! which is: {}'.format(video_path))
54 |
55 | print('Loading model from {}'.format(checkpoint))
56 | use_cuda = torch.cuda.is_available()
57 | device = torch.device('cuda' if use_cuda else 'cpu')
58 |
59 | # Build model
60 | model = nn.Sequential(
61 | CNNEncoder(**config.cnn_encoder_params),
62 | RNNDecoder(**config.rnn_decoder_params)
63 | )
64 | model.to(device)
65 | model.eval()
66 |
67 | # Load model
68 | ckpt = torch.load(checkpoint)
69 | model.load_state_dict(ckpt['model_state_dict'])
70 | print('Model has been loaded from {}'.format(checkpoint))
71 |
72 | label_map = [-1] * config.rnn_decoder_params['num_classes']
73 | # load label map
74 | if 'label_map' in ckpt:
75 | label_map = ckpt['label_map']
76 |
77 | # Do inference
78 | pred_labels = []
79 | video_names = os.listdir(video_path)
80 | with torch.no_grad():
81 | for video in tqdm(video_names, desc='Inferencing'):
82 | # read images from video
83 | images = load_imgs_from_video(os.path.join(video_path, video))
84 | # apply transform
85 | images = [Dataset.transform(None, img) for img in images]
86 | # stack to tensor, batch size = 1
87 | images = torch.stack(images, dim=0).unsqueeze(0)
88 | # do inference
89 | images = images.to(device)
90 | pred_y = model(images) # type: torch.Tensor
91 | pred_y = pred_y.argmax(dim=1).cpu().numpy().tolist()
92 | pred_labels.append([video, pred_y[0], label_map[pred_y[0]]])
93 | print(pred_labels[-1])
94 |
95 | if len(labels) > 0:
96 |         acc = accuracy_score([p[1] for p in pred_labels], labels)  # compare predicted class indices against the given labels
97 | print('Accuracy: %0.2f' % acc)
98 |
99 | # Save results
100 | pandas.DataFrame(pred_labels).to_csv('result.csv', index=False)
101 |     print('Results have been saved to {}'.format('result.csv'))
102 |
103 | return pred_labels
104 |
105 | def parse_args():
106 | parser = argparse.ArgumentParser(usage='python3 eval.py -i path/to/videos -r path/to/checkpoint')
107 | parser.add_argument('-i', '--video_path', help='path to videos')
108 | parser.add_argument('-r', '--checkpoint', help='path to the checkpoint')
109 | args = parser.parse_args()
110 | return args
111 |
112 | if __name__ == "__main__":
113 | args = parse_args()
114 | _eval(args.checkpoint, args.video_path)
115 |
--------------------------------------------------------------------------------
/utils/ff.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """ Downloads FaceForensics++ and Deep Fake Detection public data release
3 | Example usage:
4 | see -h or https://github.com/ondyari/FaceForensics
5 | """
6 | # -*- coding: utf-8 -*-
7 | import argparse
8 | import os
9 | import urllib
10 | import urllib.request
11 | import tempfile
12 | import time
13 | import sys
14 | import json
15 | import random
16 | from tqdm import tqdm
17 | from os.path import join
18 |
19 |
20 | # URLs and filenames
21 | FILELIST_URL = 'misc/filelist.json'
22 | DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json'
23 | DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',]
24 |
25 | # Parameters
26 | DATASETS = {
27 | 'original_youtube_videos': 'misc/downloaded_youtube_videos.zip',
28 | 'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip',
29 | 'original': 'original_sequences/youtube',
30 | 'DeepFakeDetection_original': 'original_sequences/actors',
31 | 'Deepfakes': 'manipulated_sequences/Deepfakes',
32 | 'DeepFakeDetection': 'manipulated_sequences/DeepFakeDetection',
33 | 'Face2Face': 'manipulated_sequences/Face2Face',
34 | 'FaceShifter': 'manipulated_sequences/FaceShifter',
35 | 'FaceSwap': 'manipulated_sequences/FaceSwap',
36 | 'NeuralTextures': 'manipulated_sequences/NeuralTextures'
37 | }
38 | ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes',
39 | 'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap',
40 | 'NeuralTextures']
41 | COMPRESSION = ['raw', 'c23', 'c40']
42 | TYPE = ['videos', 'masks', 'models']
43 | SERVERS = ['EU', 'EU2', 'CA']
44 |
45 |
46 | def parse_args():
47 | parser = argparse.ArgumentParser(
48 | description='Downloads FaceForensics v2 public data release.',
49 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
50 | )
51 | parser.add_argument('-o', '--output_path', type=str, default='./', help='Output directory.')
52 | parser.add_argument('-d', '--dataset', type=str, default='Deepfakes',
53 | help='Which dataset to download, either pristine or '
54 | 'manipulated data or the downloaded youtube '
55 | 'videos.',
56 | choices=list(DATASETS.keys()) + ['all']
57 | )
58 | parser.add_argument('-c', '--compression', type=str, default='c23',
59 | help='Which compression degree. All videos '
60 | 'have been generated with h264 with a varying '
61 | 'codec. Raw (c0) videos are lossless compressed.',
62 | choices=COMPRESSION
63 | )
64 | parser.add_argument('-t', '--type', type=str, default='videos',
65 | help='Which file type, i.e. videos, masks, for our '
66 | 'manipulation methods, models, for Deepfakes.',
67 | choices=TYPE
68 | )
69 | parser.add_argument('-n', '--num_videos', type=int, default=None,
70 |                         help='Select the number of videos to '
71 | "download if you don't want to download the full"
72 | ' dataset.')
73 | parser.add_argument('--server', type=str, default='EU',
74 | help='Server to download the data from. If you '
75 | 'encounter a slow download speed, consider '
76 | 'changing the server.',
77 | choices=SERVERS
78 | )
79 | args = parser.parse_args()
80 |
81 | # URLs
82 | server = args.server
83 | if server == 'EU':
84 | server_url = 'http://canis.vc.in.tum.de:8100/'
85 | elif server == 'EU2':
86 | server_url = 'http://kaldir.vc.in.tum.de/faceforensics/'
87 | elif server == 'CA':
88 | server_url = 'http://falas.cmpt.sfu.ca:8100/'
89 | else:
90 | raise Exception('Wrong server name. Choices: {}'.format(str(SERVERS)))
91 | args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf'
92 | args.base_url = server_url + 'v3/'
93 | args.deepfakes_model_url = server_url + 'v3/manipulated_sequences/' + \
94 | 'Deepfakes/models/'
95 |
96 | return args
97 |
98 |
99 | def download_files(filenames, base_url, output_path, report_progress=True):
100 | os.makedirs(output_path, exist_ok=True)
101 | if report_progress:
102 | filenames = tqdm(filenames)
103 | for filename in filenames:
104 | download_file(base_url + filename, join(output_path, filename))
105 |
106 |
107 | def reporthook(count, block_size, total_size):
108 | global start_time
109 | if count == 0:
110 | start_time = time.time()
111 | return
112 | duration = time.time() - start_time
113 | progress_size = int(count * block_size)
114 | speed = int(progress_size / (1024 * duration))
115 | percent = int(count * block_size * 100 / total_size)
116 | sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" %
117 | (percent, progress_size / (1024 * 1024), speed, duration))
118 | sys.stdout.flush()
119 |
120 |
121 | def download_file(url, out_file, report_progress=False):
122 | out_dir = os.path.dirname(out_file)
123 | if not os.path.isfile(out_file):
124 | fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
125 | f = os.fdopen(fh, 'w')
126 | f.close()
127 | if report_progress:
128 | urllib.request.urlretrieve(url, out_file_tmp,
129 | reporthook=reporthook)
130 | else:
131 | urllib.request.urlretrieve(url, out_file_tmp)
132 | os.rename(out_file_tmp, out_file)
133 | else:
134 | tqdm.write('WARNING: skipping download of existing file ' + out_file)
135 |
136 |
137 | def main(args):
138 | # TOS
139 | print('By pressing any key to continue you confirm that you have agreed '\
140 | 'to the FaceForensics terms of use as described at:')
141 | print(args.tos_url)
142 | print('***')
143 | print('Press any key to continue, or CTRL-C to exit.')
144 | _ = input('')
145 |
146 | # Extract arguments
147 | c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS
148 | c_type = args.type
149 | c_compression = args.compression
150 | num_videos = args.num_videos
151 | output_path = args.output_path
152 | os.makedirs(output_path, exist_ok=True)
153 |
154 | # Check for special dataset cases
155 | for dataset in c_datasets:
156 | dataset_path = DATASETS[dataset]
157 | # Special cases
158 | if 'original_youtube_videos' in dataset:
159 | # Here we download the original youtube videos zip file
160 | print('Downloading original youtube videos.')
161 | if not 'info' in dataset_path:
162 | print('Please be patient, this may take a while (~40gb)')
163 | suffix = ''
164 | else:
165 | suffix = 'info'
166 | download_file(args.base_url + '/' + dataset_path,
167 | out_file=join(output_path,
168 | 'downloaded_videos{}.zip'.format(
169 | suffix)),
170 | report_progress=True)
171 | return
172 |
173 | # Else: regular datasets
174 | print('Downloading {} of dataset "{}"'.format(
175 | c_type, dataset_path
176 | ))
177 |
178 |         # Get filelists and video lengths list from server
179 | if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path:
180 | filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' +
181 | DEEPFEAKES_DETECTION_URL).read().decode("utf-8"))
182 | if 'actors' in dataset_path:
183 | filelist = filepaths['actors']
184 | else:
185 | filelist = filepaths['DeepFakesDetection']
186 | elif 'original' in dataset_path:
187 | # Load filelist from server
188 | file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
189 | FILELIST_URL).read().decode("utf-8"))
190 | filelist = []
191 | for pair in file_pairs:
192 | filelist += pair
193 | else:
194 | # Load filelist from server
195 | file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
196 | FILELIST_URL).read().decode("utf-8"))
197 | # Get filelist
198 | filelist = []
199 | for pair in file_pairs:
200 | filelist.append('_'.join(pair))
201 | if c_type != 'models':
202 | filelist.append('_'.join(pair[::-1]))
203 | # Maybe limit number of videos for download
204 | if num_videos is not None and num_videos > 0:
205 | print('Downloading the first {} videos'.format(num_videos))
206 | filelist = filelist[:num_videos]
207 |
208 | # Server and local paths
209 | dataset_videos_url = args.base_url + '{}/{}/{}/'.format(
210 | dataset_path, c_compression, c_type)
211 | dataset_mask_url = args.base_url + '{}/{}/videos/'.format(
212 | dataset_path, 'masks', c_type)
213 |
214 | if c_type == 'videos':
215 | dataset_output_path = join(output_path, dataset_path, c_compression,
216 | c_type)
217 | print('Output path: {}'.format(dataset_output_path))
218 | filelist = [filename + '.mp4' for filename in filelist]
219 | download_files(filelist, dataset_videos_url, dataset_output_path)
220 | elif c_type == 'masks':
221 | dataset_output_path = join(output_path, dataset_path, c_type,
222 | 'videos')
223 | print('Output path: {}'.format(dataset_output_path))
224 | if 'original' in dataset:
225 | if args.dataset != 'all':
226 | print('Only videos available for original data. Aborting.')
227 | return
228 | else:
229 | print('Only videos available for original data. '
230 | 'Skipping original.\n')
231 | continue
232 | if 'FaceShifter' in dataset:
233 | print('Masks not available for FaceShifter. Aborting.')
234 | return
235 | filelist = [filename + '.mp4' for filename in filelist]
236 | download_files(filelist, dataset_mask_url, dataset_output_path)
237 |
238 | # Else: models for deepfakes
239 | else:
240 | if dataset != 'Deepfakes' and c_type == 'models':
241 | print('Models only available for Deepfakes. Aborting')
242 | return
243 | dataset_output_path = join(output_path, dataset_path, c_type)
244 | print('Output path: {}'.format(dataset_output_path))
245 |
246 | # Get Deepfakes models
247 | for folder in tqdm(filelist):
248 | folder_filelist = DEEPFAKES_MODEL_NAMES
249 |
250 | # Folder paths
251 | folder_base_url = args.deepfakes_model_url + folder + '/'
252 | folder_dataset_output_path = join(dataset_output_path,
253 | folder)
254 | download_files(folder_filelist, folder_base_url,
255 | folder_dataset_output_path,
256 | report_progress=False) # already done
257 |
258 |
259 | if __name__ == "__main__":
260 | args = parse_args()
261 | main(args)
262 |
--------------------------------------------------------------------------------
/utils/focalloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
7 | class BCEFocalLoss(torch.nn.Module):
8 |
9 | def __init__(self, gamma=2, alpha=0.6, reduction='elementwise_mean'):
10 | super().__init__()
11 | self.gamma = gamma
12 | self.alpha = alpha
13 | self.reduction = reduction
14 |
15 | def forward(self, _input, target):
16 | pt = torch.sigmoid(_input)
17 | # pt = _input
18 | alpha = self.alpha
19 | loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
20 | (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
21 | if self.reduction == 'elementwise_mean':
22 | loss = torch.mean(loss)
23 | elif self.reduction == 'sum':
24 | loss = torch.sum(loss)
25 | return loss
26 |
27 |
28 | class FocalLoss(nn.Module):
29 | def __init__(self, gamma=0, alpha=None, size_average=True):
30 | super(FocalLoss, self).__init__()
31 | self.gamma = gamma
32 | self.alpha = alpha
33 |         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
34 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
35 | self.size_average = size_average
36 |
37 | def forward(self, input, target):
38 | if input.dim() > 2:
39 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W
40 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
41 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
42 | target = target.view(-1, 1)
43 |
44 |         logpt = F.log_softmax(input, dim=1)
45 | logpt = logpt.gather(1, target)
46 | logpt = logpt.view(-1)
47 | pt = Variable(logpt.data.exp())
48 |
49 | if self.alpha is not None:
50 | if self.alpha.type() != input.data.type():
51 | self.alpha = self.alpha.type_as(input.data)
52 | at = self.alpha.gather(0, target.data.view(-1))
53 | logpt = logpt * Variable(at)
54 |
55 | loss = -1 * (1 - pt) ** self.gamma * logpt
56 | if self.size_average:
57 | return loss.mean()
58 | else:
59 | return loss.sum()
60 |
--------------------------------------------------------------------------------
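
A short usage sketch for the binary focal loss above (not part of the repository). `BCEFocalLoss` applies the sigmoid itself, so it expects raw logits and float 0/1 targets; `FocalLoss` instead expects per-class logits and integer class indices.

```python
# Sketch only: BCEFocalLoss with raw logits and float 0/1 targets.
import torch
from utils.focalloss import BCEFocalLoss

criterion = BCEFocalLoss(gamma=2, alpha=0.6)
logits = torch.randn(8, requires_grad=True)     # model outputs before sigmoid
targets = torch.randint(0, 2, (8,)).float()     # 0/1 ground-truth labels

loss = criterion(logits, targets)               # focal-weighted binary cross-entropy
loss.backward()
```
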
/utils/gradcam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.autograd import Function
6 |
7 | import numpy as np
8 | import cv2
9 |
10 | from models.model import Baseline
11 |
12 | import os
13 | import argparse
14 | import config
15 |
16 |
17 | class FeatureExtractor():
18 | """ Class for extracting activations and
19 | registering gradients from targetted intermediate layers """
20 |
21 | def __init__(self, model, target_layers):
22 | self.model = model
23 | self.target_layers = target_layers
24 | self.gradients = []
25 |
26 | def save_gradient(self, grad):
27 | self.gradients.append(grad)
28 |
29 | def __call__(self, x):
30 | outputs = []
31 | self.gradients = []
32 | for name, module in self.model._modules.items():
33 | x = module(x)
34 | if name in self.target_layers:
35 | x.register_hook(self.save_gradient)
36 | outputs += [x]
37 | return outputs, x
38 |
39 |
40 | class ModelOutputs():
41 | """ Class for making a forward pass, and getting:
42 | 1. The network output.
43 | 2. Activations from intermeddiate targetted layers.
44 | 3. Gradients from intermeddiate targetted layers. """
45 |
46 | def __init__(self, model, feature_module, target_layers):
47 | self.model = model
48 | self.feature_module = feature_module
49 | self.feature_extractor = FeatureExtractor(self.feature_module, target_layers)
50 |
51 | def get_gradients(self):
52 | return self.feature_extractor.gradients
53 |
54 | def __call__(self, x):
55 | target_activations = []
56 | for name, module in self.model._modules.items():
57 | if module == self.feature_module:
58 | target_activations, x = self.feature_extractor(x)
59 | elif "avgpool" in name.lower():
60 | x = module(x)
61 | x = x.view(x.size(0), -1)
62 | else:
63 | x = module(x)
64 |
65 | return target_activations, x
66 |
67 |
68 | def preprocess_image(img):
69 | means = [0.485, 0.456, 0.406]
70 | stds = [0.229, 0.224, 0.225]
71 |
72 | preprocessed_img = img.copy()[:, :, ::-1]
73 | for i in range(3):
74 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i]
75 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i]
76 | preprocessed_img = \
77 | np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1)))
78 | preprocessed_img = torch.from_numpy(preprocessed_img)
79 | preprocessed_img.unsqueeze_(0)
80 | input = preprocessed_img.requires_grad_(True)
81 | return input
82 |
83 |
84 | def show_cam_on_image(img, mask, path):
85 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
86 | heatmap = np.float32(heatmap) / 255
87 | cam = heatmap + np.float32(img)
88 | cam = cam / np.max(cam)
89 | cv2.imwrite(path + "_cam.jpg", np.uint8(255 * cam))
90 |
91 |
92 | class GradCam:
93 | def __init__(self, model, feature_module, target_layer_names, use_cuda):
94 | self.model = model
95 | self.feature_module = feature_module
96 | self.model.eval()
97 | self.cuda = use_cuda
98 | if self.cuda:
99 | self.model = model.cuda()
100 |
101 | self.extractor = ModelOutputs(self.model, self.feature_module, target_layer_names)
102 |
103 | def forward(self, input):
104 | return self.model(input)
105 |
106 | def __call__(self, input, index=None):
107 | if self.cuda:
108 | features, output = self.extractor(input.cuda())
109 | else:
110 | features, output = self.extractor(input)
111 |
112 | if index == None:
113 | index = np.argmax(output.cpu().data.numpy())
114 |
115 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
116 | one_hot[0][index] = 1
117 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
118 | if self.cuda:
119 | one_hot = torch.sum(one_hot.cuda() * output)
120 | else:
121 | one_hot = torch.sum(one_hot * output)
122 |
123 | self.feature_module.zero_grad()
124 | self.model.zero_grad()
125 | one_hot.backward(retain_graph=True)
126 |
127 | grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
128 |
129 | target = features[-1]
130 | target = target.cpu().data.numpy()[0, :]
131 |
132 | weights = np.mean(grads_val, axis=(2, 3))[0, :]
133 | cam = np.zeros(target.shape[1:], dtype=np.float32)
134 |
135 | for i, w in enumerate(weights):
136 | cam += w * target[i, :, :]
137 |
138 | cam = np.maximum(cam, 0)
139 | cam = cv2.resize(cam, input.shape[2:])
140 | cam = cam - np.min(cam)
141 | cam = cam / np.max(cam)
142 | return cam
143 |
144 |
145 | class GuidedBackpropReLU(Function):
146 |
147 | @staticmethod
148 | def forward(self, input):
149 | positive_mask = (input > 0).type_as(input)
150 | output = torch.addcmul(torch.zeros(input.size()).type_as(input), input, positive_mask)
151 | self.save_for_backward(input, output)
152 | return output
153 |
154 | @staticmethod
155 | def backward(self, grad_output):
156 | input, output = self.saved_tensors
157 | grad_input = None
158 |
159 | positive_mask_1 = (input > 0).type_as(grad_output)
160 | positive_mask_2 = (grad_output > 0).type_as(grad_output)
161 | grad_input = torch.addcmul(torch.zeros(input.size()).type_as(input),
162 | torch.addcmul(torch.zeros(input.size()).type_as(input), grad_output,
163 | positive_mask_1), positive_mask_2)
164 |
165 | return grad_input
166 |
167 |
168 | class GuidedBackpropReLUModel:
169 | def __init__(self, model, use_cuda):
170 | self.model = model
171 | self.model.eval()
172 | self.cuda = use_cuda
173 | if self.cuda:
174 | self.model = model.cuda()
175 |
176 | def recursive_relu_apply(module_top):
177 | for idx, module in module_top._modules.items():
178 | recursive_relu_apply(module)
179 | if module.__class__.__name__ == 'ReLU':
180 | module_top._modules[idx] = GuidedBackpropReLU.apply
181 |
182 | # replace ReLU with GuidedBackpropReLU
183 | recursive_relu_apply(self.model)
184 |
185 | def forward(self, input):
186 | return self.model(input)
187 |
188 | def __call__(self, input, index=None):
189 | if self.cuda:
190 | output = self.forward(input.cuda())
191 | else:
192 | output = self.forward(input)
193 |
194 | if index == None:
195 | index = np.argmax(output.cpu().data.numpy())
196 |
197 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
198 | one_hot[0][index] = 1
199 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
200 | if self.cuda:
201 | one_hot = torch.sum(one_hot.cuda() * output)
202 | else:
203 | one_hot = torch.sum(one_hot * output)
204 |
205 | # self.model.features.zero_grad()
206 | # self.model.classifier.zero_grad()
207 | one_hot.backward(retain_graph=True)
208 |
209 | output = input.grad.cpu().data.numpy()
210 | output = output[0, :, :, :]
211 |
212 | return output
213 |
214 |
215 | def deprocess_image(img):
216 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
217 | img = img - np.mean(img)
218 | img = img / (np.std(img) + 1e-5)
219 | img = img * 0.1
220 | img = img + 0.5
221 | img = np.clip(img, 0, 1)
222 | return np.uint8(img * 255)
223 |
224 |
225 | def parse_args():
226 | parser = argparse.ArgumentParser(usage='python3 main.py -i path/to/data -r path/to/checkpoint')
227 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint',
228 | default='/data2/guesthome/wenbop/modules/new-bi-model_type-baseline_gru_ep-17.pth')
229 | # parser.add_argument('-g', '--gpu', help='visible gpu ids', default='4,5,7')
230 | parser.add_argument('-i', '--image-path', type=str, default='/data2/guesthome/wenbop/ffdf/test/0/',
231 | help='Input image path')
232 | parser.add_argument('-g', '--gpu', help='visible gpu ids', default='0,1,2,3')
233 | args = parser.parse_args()
234 | return args
235 |
236 |
237 | if __name__ == "__main__":
238 | args = parse_args()
239 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
240 | use_cuda = torch.cuda.is_available()
241 | device = torch.device('cuda' if use_cuda else 'cpu')
242 | model = Baseline(**config.net_params)
243 | device_count = torch.cuda.device_count()
244 | if device_count > 1:
245 |         print('Using {} GPUs'.format(device_count))
246 | model = nn.DataParallel(model)
247 | model.to(device)
248 | ckpt = {}
249 |     # resume from a checkpoint
250 | if args.restore_from is not None:
251 | ckpt = torch.load(args.restore_from)
252 | # model.load_state_dict(ckpt['net'])
253 | model.load_state_dict(ckpt['model_state_dict'])
254 | print('Model is loaded from %s' % (args.restore_from))
255 |
256 | model = model.module.cnn
257 | grad_cam = GradCam(model=model, feature_module=model[7], target_layer_names=["2"], use_cuda=True)
258 |
259 | img = cv2.imread(args.image_path, 1)
260 | img = np.float32(cv2.resize(img, (224, 224))) / 255
261 | input = preprocess_image(img)
262 |
263 | # If None, returns the map for the highest scoring category.
264 | # Otherwise, targets the requested index.
265 | target_index = 0
266 | mask = grad_cam(input, target_index)
267 |
268 | show_cam_on_image(img, mask, args.image_path.split('/')[-1])
269 |
270 | gb_model = GuidedBackpropReLUModel(model=model, use_cuda=True)
271 |
272 | gb = gb_model(input, index=target_index)
273 | gb = gb.transpose((1, 2, 0))
274 | cam_mask = cv2.merge([mask, mask, mask])
275 | cam_gb = deprocess_image(cam_mask * gb)
276 | gb = deprocess_image(gb)
277 |
278 | cv2.imwrite('./gram/' + args.image_path.split('/')[-1] + '_gb.jpg', gb)
279 | cv2.imwrite('./gram/' + args.image_path.split('/')[-1] + '_cam_gb.jpg', cam_gb)
280 |
--------------------------------------------------------------------------------
/utils/mmod_human_face_detector.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/utils/mmod_human_face_detector.dat
--------------------------------------------------------------------------------
/utils/test.json:
--------------------------------------------------------------------------------
1 | [
2 | [
3 | "953",
4 | "974"
5 | ],
6 | [
7 | "012",
8 | "026"
9 | ],
10 | [
11 | "078",
12 | "955"
13 | ],
14 | [
15 | "623",
16 | "630"
17 | ],
18 | [
19 | "919",
20 | "015"
21 | ],
22 | [
23 | "367",
24 | "371"
25 | ],
26 | [
27 | "847",
28 | "906"
29 | ],
30 | [
31 | "529",
32 | "633"
33 | ],
34 | [
35 | "418",
36 | "507"
37 | ],
38 | [
39 | "227",
40 | "169"
41 | ],
42 | [
43 | "389",
44 | "480"
45 | ],
46 | [
47 | "821",
48 | "812"
49 | ],
50 | [
51 | "670",
52 | "661"
53 | ],
54 | [
55 | "158",
56 | "379"
57 | ],
58 | [
59 | "423",
60 | "421"
61 | ],
62 | [
63 | "352",
64 | "319"
65 | ],
66 | [
67 | "579",
68 | "701"
69 | ],
70 | [
71 | "488",
72 | "399"
73 | ],
74 | [
75 | "695",
76 | "422"
77 | ],
78 | [
79 | "288",
80 | "321"
81 | ],
82 | [
83 | "705",
84 | "707"
85 | ],
86 | [
87 | "306",
88 | "278"
89 | ],
90 | [
91 | "865",
92 | "739"
93 | ],
94 | [
95 | "995",
96 | "233"
97 | ],
98 | [
99 | "755",
100 | "759"
101 | ],
102 | [
103 | "467",
104 | "462"
105 | ],
106 | [
107 | "314",
108 | "347"
109 | ],
110 | [
111 | "741",
112 | "731"
113 | ],
114 | [
115 | "970",
116 | "973"
117 | ],
118 | [
119 | "634",
120 | "660"
121 | ],
122 | [
123 | "494",
124 | "445"
125 | ],
126 | [
127 | "706",
128 | "479"
129 | ],
130 | [
131 | "186",
132 | "170"
133 | ],
134 | [
135 | "176",
136 | "190"
137 | ],
138 | [
139 | "380",
140 | "358"
141 | ],
142 | [
143 | "214",
144 | "255"
145 | ],
146 | [
147 | "454",
148 | "527"
149 | ],
150 | [
151 | "425",
152 | "485"
153 | ],
154 | [
155 | "388",
156 | "308"
157 | ],
158 | [
159 | "384",
160 | "932"
161 | ],
162 | [
163 | "035",
164 | "036"
165 | ],
166 | [
167 | "257",
168 | "420"
169 | ],
170 | [
171 | "924",
172 | "917"
173 | ],
174 | [
175 | "114",
176 | "102"
177 | ],
178 | [
179 | "732",
180 | "691"
181 | ],
182 | [
183 | "550",
184 | "452"
185 | ],
186 | [
187 | "280",
188 | "249"
189 | ],
190 | [
191 | "842",
192 | "714"
193 | ],
194 | [
195 | "625",
196 | "650"
197 | ],
198 | [
199 | "024",
200 | "073"
201 | ],
202 | [
203 | "044",
204 | "945"
205 | ],
206 | [
207 | "896",
208 | "128"
209 | ],
210 | [
211 | "862",
212 | "047"
213 | ],
214 | [
215 | "607",
216 | "683"
217 | ],
218 | [
219 | "517",
220 | "521"
221 | ],
222 | [
223 | "682",
224 | "669"
225 | ],
226 | [
227 | "138",
228 | "142"
229 | ],
230 | [
231 | "552",
232 | "851"
233 | ],
234 | [
235 | "376",
236 | "381"
237 | ],
238 | [
239 | "000",
240 | "003"
241 | ],
242 | [
243 | "048",
244 | "029"
245 | ],
246 | [
247 | "724",
248 | "725"
249 | ],
250 | [
251 | "608",
252 | "675"
253 | ],
254 | [
255 | "386",
256 | "154"
257 | ],
258 | [
259 | "220",
260 | "219"
261 | ],
262 | [
263 | "801",
264 | "855"
265 | ],
266 | [
267 | "161",
268 | "141"
269 | ],
270 | [
271 | "949",
272 | "868"
273 | ],
274 | [
275 | "880",
276 | "135"
277 | ],
278 | [
279 | "429",
280 | "404"
281 | ]
282 | ]
--------------------------------------------------------------------------------
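
`test.json` (like `train.json` and `val.json`) lists FaceForensics++ identity pairs; treating them as the test split is an assumption based on the file name. A small sketch of how such id pairs expand into original and manipulated video names, mirroring the filelist construction in `utils/ff.py`.

```python
# Sketch only: expand the id pairs into original and manipulated FF++ video names,
# the same way utils/ff.py builds its download filelist.
import json

with open('utils/test.json') as f:
    pairs = json.load(f)

originals = sorted({vid for pair in pairs for vid in pair})                      # e.g. '012', '026'
manipulated = ['_'.join(p) for p in pairs] + ['_'.join(p[::-1]) for p in pairs]  # e.g. '012_026', '026_012'
# manipulated videos live under manipulated_sequences/<method>/<compression>/videos/<name>.mp4
print(len(originals), len(manipulated))
```
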
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas
3 | from PIL import Image
4 |
5 | from utils.dataloader import Dataset
6 | from make_train_test import *
7 | from meso.meso import *
8 |
9 |
10 | def load_datas(src_path, files=[]):
11 | datas = []
12 | for file in files:
13 | img = Image.open(os.path.join(src_path, file)).convert('RGB')
14 | img.save("./images/" + file)
15 | img = img.resize((64, 64), Image.ANTIALIAS)
16 | data = np.array(img)
17 | data = np.transpose(data, (2, 0, 1))
18 | datas.append(data)
19 | return np.array(datas)
20 |
21 |
22 | def video_frame_face_extractor(path, output):
23 | import dlib
24 | face_detector = dlib.cnn_face_detection_model_v1('./mmod_human_face_detector.dat')
25 | video_fd = cv2.VideoCapture(path)
26 | if not video_fd.isOpened():
27 |         print('Skipped: {}'.format(path))
28 |
29 | frame_index = 0
30 | success, frame = video_fd.read()
31 | while success:
32 | frame_path = os.path.join(output + '/frame/%s_%d.jpg' % (path.split('/')[-1], frame_index))
33 | cv2.imwrite(frame_path, frame)
34 | img_path = os.path.join(output + '/face/%s_%d.jpg' % (path.split('/')[-1], frame_index))
35 | height, width = frame.shape[:2]
36 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
37 | faces = face_detector(gray, 1)
38 | if len(faces):
39 | # For now only take biggest face
40 | face = faces[0].rect
41 | x, y, size = get_boundingbox(face, width, height)
42 | # generate cropped image
43 | cropped_face = frame[y:y + size, x:x + size]
44 | cv2.imwrite(img_path, cropped_face)
45 |
46 | frame_index += 1
47 | success, frame = video_fd.read()
48 |
49 | video_fd.release()
50 |
51 |
52 | def list_file(path, label):
53 | list = []
54 | for file in os.listdir(path):
55 | list.append([path + '/' + file, label])
56 |
57 | return list
58 |
59 |
60 | def dataset_size(path):
61 | for file in os.listdir(path + '/0/'):
62 | img = cv2.imread(path + '/0/' + file)
63 | print(np.array(img).shape)
64 |
65 |
66 | def frame_range(src_dir):
67 | Celeb_real = list_file(src_dir + '/Celeb-real', 1)
68 | Celeb_synthesis = list_file(src_dir + '/Celeb-synthesis', 0)
69 | YouTube_real = list_file(src_dir + '/YouTube-real', 1)
70 |
71 | frame_m = []
72 |
73 | for [file, _] in Celeb_real:
74 | video = cv2.VideoCapture(os.path.join(src_dir, '/Celeb-real', file))
75 | frame_num = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
76 | frame_m.append(frame_num)
77 | print(file)
78 | print(frame_num)
79 | print('---------------')
80 | for [file, _] in Celeb_synthesis:
81 | video = cv2.VideoCapture(os.path.join(src_dir, '/Celeb-synthesis', file))
82 | frame_num = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
83 | frame_m.append(frame_num)
84 | print(file)
85 | print(frame_num)
86 | print('---------------')
87 | for [file, _] in YouTube_real:
88 | video = cv2.VideoCapture(os.path.join(src_dir, '/YouTube-real', file))
89 | frame_num = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
90 | frame_m.append(frame_num)
91 | print(file)
92 | print(frame_num)
93 | print('---------------')
94 | frame_m = np.array(frame_m)
95 | print('====================')
96 | print(np.max(np.array(frame_m)))
97 | print(np.min(np.array(frame_m)))
98 | print(frame_m.shape)
99 | print(np.median(frame_m))
100 | print(np.mean(frame_m))
101 |
102 | import tensorflow as tf  # only needed by the two TF-based helpers below (build_AUC_loss / auc_loss)
103 | def build_AUC_loss(outputs, labels, gamma, power_p):
104 | posi_idx = tf.where(tf.equal(labels, 1.0))
105 | neg_idx = tf.where(tf.equal(labels, -1.0))
106 |     predictions = tf.nn.softmax(outputs)
107 |     posi_predict = tf.gather(predictions, posi_idx)
108 |     posi_size = tf.shape(posi_predict)[0]
109 |     neg_predict = tf.gather(predictions, neg_idx)
110 |     neg_size = tf.shape(neg_predict)[0]
111 | posi_neg_diff = tf.reshape(
112 | -(tf.matmul(posi_predict, tf.ones([1, neg_size])) -
113 | tf.matmul(tf.ones([posi_size, 1]), tf.reshape(neg_predict, [-1, neg_size])) - gamma),
114 | [-1, 1])
115 | posi_neg_diff = tf.where(tf.greater(posi_neg_diff, 0), posi_neg_diff, tf.zeros([posi_size * neg_size, 1]))
116 | posi_neg_diff = tf.pow(posi_neg_diff, power_p)
117 | loss_approx_auc = tf.reduce_mean(posi_neg_diff)
118 | return loss_approx_auc
119 |
120 |
121 | def auc_loss(y_pred, y_true, gamma, p=2):
122 | pos = tf.boolean_mask(y_pred, tf.cast(y_true, tf.bool))
123 | neg = tf.boolean_mask(y_pred, ~tf.cast(y_true, tf.bool))
124 | pos = tf.expand_dims(pos, 0)
125 | neg = tf.expand_dims(neg, 1)
126 | difference = tf.zeros_like(pos * neg) + pos - neg - gamma
127 | masked = tf.boolean_mask(difference, difference < 0.0)
128 | return tf.reduce_sum(tf.pow(-masked, p))
129 |
130 |
131 | def AUC_loss(y_pred, y_true, device, gamma, p=2):
132 | pred = torch.sigmoid(y_pred)
133 | pos = pred[torch.where(y_true == 0)]
134 | neg = pred[torch.where(y_true == 1)]
135 | pos = torch.unsqueeze(pos, 0)
136 | neg = torch.unsqueeze(neg, 1)
137 | diff = torch.zeros_like(pos * neg, device=device) + pos - neg - gamma
138 | masked = diff[torch.where(diff < 0.0)]
139 | return torch.mean(torch.pow(-masked, p))
140 |
141 |
142 | # def AUC_loss(outputs, labels, device, gamma, p=2):
143 | # predictions = torch.sigmoid(outputs)
144 | # pos_predict = predictions[torch.where(labels == 0)]
145 | # neg_predict = predictions[torch.where(labels == 1)]
146 | # pos_size = pos_predict.shape[0]
147 | # neg_size = neg_predict.shape[0]
148 | # # if pos_size == 0 or neg_size == 0:
149 | # # return 0
150 | # # else:
151 | # if pos_size != 0 and neg_size != 0:
152 | # pos_neg_diff = -(torch.matmul(pos_predict, torch.ones([1, neg_size], device=device)) -
153 | # torch.matmul(torch.ones([pos_size, 1], device=device),
154 | # torch.reshape(neg_predict, [-1, neg_size]))
155 | # - gamma)
156 | # pos_neg_diff = torch.reshape(pos_neg_diff, [-1, 1])
157 | # pos_neg_diff = torch.where(torch.gt(pos_neg_diff, 0), pos_neg_diff, torch.zeros([pos_size * neg_size, 1],
158 | # device=device))
159 | # elif neg_size == 0:
160 | # pos_neg_diff = -(pos_predict - gamma)
161 | # pos_neg_diff = torch.where(torch.gt(pos_neg_diff, 0), pos_neg_diff, torch.zeros([pos_size, 1], device=device))
162 | # else:
163 | # pos_neg_diff = -(-neg_predict - gamma)
164 | # pos_neg_diff = torch.where(torch.gt(pos_neg_diff, 0), pos_neg_diff, torch.zeros([neg_size, 1], device=device))
165 | #
166 | # pos_neg_diff = torch.pow(pos_neg_diff, p)
167 | #
168 | # loss_approx_auc = torch.mean(pos_neg_diff)
169 | # return loss_approx_auc
170 |
171 |
172 | # def AUC_loss(y_, y, device, gamma, p=2):
173 | # X = y_[torch.where(y == 0)]
174 | # Y = y_[torch.where(y == 1)]
175 | # loss = torch.zeros(1, requires_grad=True, device=device)
176 | # if X.shape[0] == 0:
177 | # Y = torch.max(Y, 1)[0]
178 | # for j in Y:
179 | # if -j < gamma:
180 | # loss = (-(- j - gamma)) ** p + loss
181 | # if Y.shape[0] == 0:
182 | # X = torch.max(X, 1)[0]
183 | # for i in X:
184 | # if i < gamma:
185 | # loss = (-(i - gamma)) ** p + loss
186 | # if X.shape[0] != 0 and Y.shape[0] != 0:
187 | # X = torch.max(X, 1)[0]
188 | # Y = torch.max(Y, 1)[0]
189 | # for i in X:
190 | # for j in Y:
191 | # if i - j < gamma:
192 | # loss = (-(i - j - gamma)) ** p + loss
193 | # return loss
194 |
195 |
196 | def merge_labels_to_ckpt(ck_path: str, train_file: str):
197 | """Merge labels to a checkpoint file.
198 |
199 | Args:
200 | ck_path(str): path to checkpoint file
201 | train_file(str): path to train set index file, e.g. train.csv
202 |
203 | Return:
204 | This function creates a {ck_path}_patched.pth file.
205 | """
206 | # load model
207 | print('Loading checkpoint')
208 | ckpt = torch.load(ck_path)
209 |
210 | # load train files
211 | print('Loading dataset')
212 | raw_data = pandas.read_csv(train_file)
213 | train_set = Dataset(raw_data.to_numpy())
214 |
215 | # patch file name
216 | print('Patching')
217 | patch_path = ck_path.replace('.pth', '') + '_patched.pth'
218 |
219 | ck_dict = {'label_map': train_set.labels}
220 | names = ['epoch', 'model_state_dict', 'optimizer_state_dict']
221 | for name in names:
222 | ck_dict[name] = ckpt[name]
223 |
224 | torch.save(ck_dict, patch_path)
225 | print('Patched checkpoint has been saved to {}'.format(patch_path))
226 |
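# --- Hedged usage sketch (editorial addition) ---
# Example call for merge_labels_to_ckpt above; both paths are hypothetical
# placeholders. The function writes '<checkpoint>_patched.pth' next to the
# original checkpoint, copying epoch, model and optimizer state and adding the
# dataset label map.
# merge_labels_to_ckpt('/path/to/checkpoint.pth', '/path/to/train.csv')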
227 |
228 | def tensor2im(input_image, imtype=np.uint8):
229 | """"将tensor的数据类型转成numpy类型,并反归一化.
230 |
231 | Parameters:
232 | input_image (tensor) -- 输入的图像tensor数组
233 | imtype (type) -- 转换后的numpy的数据类型
234 | """
235 | mean = [0.485, 0.456, 0.406] # 自己设置的
236 | std = [0.229, 0.224, 0.225] # 自己设置的
237 | if not isinstance(input_image, np.ndarray):
238 | if isinstance(input_image, torch.Tensor): # get the data from a variable
239 | image_tensor = input_image.data
240 | else:
241 | return input_image
242 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array
243 | if image_numpy.shape[0] == 1: # grayscale to RGB
244 | image_numpy = np.tile(image_numpy, (3, 1, 1))
245 | for i in range(len(mean)):
246 | image_numpy[i] = image_numpy[i] * std[i] + mean[i]
247 | image_numpy = image_numpy * 255
248 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) # post-processing: tranpose and scaling
249 | else: # if it is a numpy array, do nothing
250 | image_numpy = input_image
251 | return image_numpy.astype(imtype)
252 |
253 |
254 | def save_img(im, path):
255 | """im可是没经过任何处理的tensor类型的数据,将数据存储到path中
256 |
257 | Parameters:
258 | im (tensor) -- 输入的图像tensor数组
259 | path (str) -- 图像保存的路径
260 | size (int) -- 一行有size张图,最好是2的倍数
261 | """
262 | # im_grid = torchvision.utils.make_grid(im, size) #将batchsize的图合成一张图
263 | im_numpy = tensor2im(im) # 转成numpy类型并反归一化
264 | im_array = Image.fromarray(im_numpy)
265 | im_array.save(path)
266 |
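# --- Hedged usage sketch (editorial addition, not part of the original file) ---
# Example of writing a normalized CHW tensor to disk with save_img above; the
# tensor and output path are illustrative assumptions. tensor2im undoes the
# ImageNet mean/std normalization hard-coded above before saving.
def _save_img_example():
    fake_face = torch.rand(3, 224, 224)           # one normalized CHW image tensor
    save_img(fake_face, '/tmp/example_face.png')  # converted to uint8 HWC and saved as PNG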
267 |
268 | def new_path(path):
269 | if not os.path.exists(path):
270 | try:
271 | os.mkdir(path)
272 | except Exception:
273 | os.makedirs(path)
274 |
275 |
276 | def read_npy(src):
277 | arr = np.load(src)
278 | for i in arr:
279 | print(i)
280 |
281 |
282 | def parse_args():
283 | parser = argparse.ArgumentParser(usage='python3 tools.py -i path/to/train.csv -r path/to/checkpoint')
284 | parser.add_argument('-i', '--data_path', help='path to your dataset index file')
285 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint', default=None)
286 | args = parser.parse_args()
287 | return args
288 |
289 |
290 | if __name__ == '__main__':
291 | args = parse_args()
292 | # xcp = '/home/asus/Code/checkpoint/ff/xcept/nb-model_type-xception_ep-19.pth'
293 | # meso = ''
294 | # msi = '/home/asus/Code/checkpoint/ff/msin/weights.h5'
295 | # cap = '/home/asus/Code/checkpoint/ff/cap/capsule_18.pt'
296 | # model_pred('/home/asus/ffdf/test/1', 'xception', xcp)
297 | # model_pred('/home/asus/ffdf_40/test/1', 'cap', cap)
298 | # model_pred('/home/asus/ffdf/test/1', 'msi', msi)
299 | # model_pred('/home/asus/ffdf_40/test/0', 'msi', msi)
300 | read_npy('/Users/pu/Downloads/images/c23/0/tcap.txt.npy')
301 | read_npy('/Users/pu/Downloads/images/c23/0/tmsi.txt.npy')
302 | read_npy('/Users/pu/Downloads/images/c23/0/txcep.txt.npy')
303 | print('=================')
304 | read_npy('/Users/pu/Downloads/images/c23/1/tcap.txt.npy')
305 | read_npy('/Users/pu/Downloads/images/c23/1/tmsi.txt.npy')
306 | read_npy('/Users/pu/Downloads/images/c23/1/txcep.txt.npy')
307 | print('=================')
308 | read_npy('/Users/pu/Downloads/images/c40/0/tcap.txt.npy')
309 | read_npy('/Users/pu/Downloads/images/c40/0/tmsi.txt.npy')
310 | read_npy('/Users/pu/Downloads/images/c40/0/txcep.txt.npy')
311 | print('=================')
312 | read_npy('/Users/pu/Downloads/images/c40/1/tcap.txt.npy')
313 | read_npy('/Users/pu/Downloads/images/c40/1/tmsi.txt.npy')
314 | read_npy('/Users/pu/Downloads/images/c40/1/txcep.txt.npy')
315 |
--------------------------------------------------------------------------------
/utils/train_cpvr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | from tensorboardX import SummaryWriter
5 | import torch
6 | import torch.optim as optim
7 | import torch.nn as nn
8 |
9 | from dataset import Dataset
10 | from templates import get_templates
11 |
12 | MODEL_DIR = './models/'
13 | BACKBONE = 'xcp'
14 | MAPTYPE = 'reg'
15 | BATCH_SIZE = 15
16 | MAX_EPOCHS = 100
17 | STEPS_PER_EPOCH = 1000
18 | LEARNING_RATE = 0.0001
19 | WEIGHT_DECAY = 0.001
20 |
21 | CONFIGS = {
22 | 'xcp': {
23 | 'img_size': (299, 299),
24 | 'map_size': (19, 19),
25 | 'norms': [[0.5] * 3, [0.5] * 3]
26 | },
27 | 'vgg': {
28 | 'img_size': (299, 299),
29 | 'map_size': (19, 19),
30 | 'norms': [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
31 | }
32 | }
33 | CONFIG = CONFIGS[BACKBONE]
34 |
35 | if BACKBONE == 'xcp':
36 | from xception import Model
37 | elif BACKBONE == 'vgg':
38 | from vgg import Model
39 |
40 | torch.backends.cudnn.deterministic = True
41 | SEED = 1
42 | random.seed(SEED)
43 | torch.manual_seed(SEED)
44 | torch.cuda.manual_seed_all(SEED)
45 |
46 | DATA_TRAIN = Dataset('train', BATCH_SIZE, CONFIG['img_size'], CONFIG['map_size'], CONFIG['norms'], SEED)
47 |
48 | DATA_EVAL = Dataset('eval', BATCH_SIZE, CONFIG['img_size'], CONFIG['map_size'], CONFIG['norms'], SEED)
49 |
50 | TEMPLATES = None
51 | if MAPTYPE in ['tmp', 'pca_tmp']:
52 | TEMPLATES = get_templates()
53 |
54 | MODEL_NAME = '{0}_{1}'.format(BACKBONE, MAPTYPE)
55 | MODEL_DIR = MODEL_DIR + MODEL_NAME + '/'
56 |
57 | MODEL = Model(MAPTYPE, TEMPLATES, 2, False)
58 |
59 | OPTIM = optim.Adam(MODEL.model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
60 | MODEL.model.cuda()
61 | LOSS_CSE = nn.CrossEntropyLoss().cuda()
62 | LOSS_L1 = nn.L1Loss().cuda()
63 | MAXPOOL = nn.MaxPool2d(19).cuda()
64 |
65 |
66 | def calculate_losses(batch):
67 | img = batch['img']
68 | msk = batch['msk']
69 | lab = batch['lab']
70 | x, mask, vec = MODEL.model(img)
71 | loss_l1 = LOSS_L1(mask, msk)
72 | loss_cse = LOSS_CSE(x, lab)
73 | loss = loss_l1 + loss_cse
74 | pred = torch.max(x, dim=1)[1]
75 | acc = (pred == lab).float().mean()
76 | return {'loss': loss, 'loss_l1': loss_l1, 'loss_cse': loss_cse, 'acc': acc}
77 |
78 |
79 | def process_batch(batch, mode):
80 | if mode == 'train':
81 | MODEL.model.train()
82 | losses = calculate_losses(batch)
83 | OPTIM.zero_grad()
84 | losses['loss'].backward()
85 | OPTIM.step()
86 | elif mode == 'eval':
87 | MODEL.model.eval()
88 | with torch.no_grad():
89 | losses = calculate_losses(batch)
90 | return losses
91 |
92 |
93 | SUMMARY_WRITER = SummaryWriter(MODEL_DIR + 'logs/')
94 |
95 |
96 | def write_tfboard(item, itr, name):
97 | SUMMARY_WRITER.add_scalar('{0}'.format(name), item, itr)
98 |
99 |
100 | def run_step(e, s):
101 | batch = DATA_TRAIN.get_batch()
102 | losses = process_batch(batch, 'train')
103 |
104 | if s % 10 == 0:
105 | print('\r{0} - '.format(s) + ', '.join(
106 | ['{0}: {1:.3f}'.format(_, losses[_].cpu().detach().numpy()) for _ in losses]), end='')
107 | if s % 100 == 0:
108 | print('\n', end='')
109 | [write_tfboard(losses[_], e * STEPS_PER_EPOCH + s, _) for _ in losses]
110 |
111 |
112 | def run_epoch(e):
113 | print('Epoch: {0}'.format(e))
114 | for s in range(STEPS_PER_EPOCH):
115 | run_step(e, s)
116 | MODEL.save(e + 1, OPTIM, MODEL_DIR)
117 |
118 |
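# --- Hedged usage sketch (editorial addition, not part of the original script) ---
# DATA_EVAL is built above but never consumed in this file. Assuming it exposes
# the same get_batch() interface as DATA_TRAIN, a minimal evaluation pass could
# look like the sketch below; the number of eval steps (100) is arbitrary.
def run_eval(e):
    accs = []
    for _ in range(100):
        batch = DATA_EVAL.get_batch()
        losses = process_batch(batch, 'eval')
        accs.append(float(losses['acc'].cpu().numpy()))
    write_tfboard(float(np.mean(accs)), e, 'eval_acc')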
119 | LAST_EPOCH = 0
120 | for e in range(LAST_EPOCH, MAX_EPOCHS):
121 | run_epoch(e)
122 |
--------------------------------------------------------------------------------
/utils/val.json:
--------------------------------------------------------------------------------
1 | [
2 | [
3 | "720",
4 | "672"
5 | ],
6 | [
7 | "939",
8 | "115"
9 | ],
10 | [
11 | "284",
12 | "263"
13 | ],
14 | [
15 | "402",
16 | "453"
17 | ],
18 | [
19 | "820",
20 | "818"
21 | ],
22 | [
23 | "762",
24 | "832"
25 | ],
26 | [
27 | "834",
28 | "852"
29 | ],
30 | [
31 | "922",
32 | "898"
33 | ],
34 | [
35 | "104",
36 | "126"
37 | ],
38 | [
39 | "106",
40 | "198"
41 | ],
42 | [
43 | "159",
44 | "175"
45 | ],
46 | [
47 | "416",
48 | "342"
49 | ],
50 | [
51 | "857",
52 | "909"
53 | ],
54 | [
55 | "599",
56 | "585"
57 | ],
58 | [
59 | "443",
60 | "514"
61 | ],
62 | [
63 | "566",
64 | "617"
65 | ],
66 | [
67 | "472",
68 | "511"
69 | ],
70 | [
71 | "325",
72 | "492"
73 | ],
74 | [
75 | "816",
76 | "649"
77 | ],
78 | [
79 | "583",
80 | "558"
81 | ],
82 | [
83 | "933",
84 | "925"
85 | ],
86 | [
87 | "419",
88 | "824"
89 | ],
90 | [
91 | "465",
92 | "482"
93 | ],
94 | [
95 | "565",
96 | "589"
97 | ],
98 | [
99 | "261",
100 | "254"
101 | ],
102 | [
103 | "992",
104 | "980"
105 | ],
106 | [
107 | "157",
108 | "245"
109 | ],
110 | [
111 | "571",
112 | "746"
113 | ],
114 | [
115 | "947",
116 | "951"
117 | ],
118 | [
119 | "926",
120 | "900"
121 | ],
122 | [
123 | "493",
124 | "538"
125 | ],
126 | [
127 | "468",
128 | "470"
129 | ],
130 | [
131 | "915",
132 | "895"
133 | ],
134 | [
135 | "362",
136 | "354"
137 | ],
138 | [
139 | "440",
140 | "364"
141 | ],
142 | [
143 | "640",
144 | "638"
145 | ],
146 | [
147 | "827",
148 | "817"
149 | ],
150 | [
151 | "793",
152 | "768"
153 | ],
154 | [
155 | "837",
156 | "890"
157 | ],
158 | [
159 | "004",
160 | "982"
161 | ],
162 | [
163 | "192",
164 | "134"
165 | ],
166 | [
167 | "745",
168 | "777"
169 | ],
170 | [
171 | "299",
172 | "145"
173 | ],
174 | [
175 | "742",
176 | "775"
177 | ],
178 | [
179 | "586",
180 | "223"
181 | ],
182 | [
183 | "483",
184 | "370"
185 | ],
186 | [
187 | "779",
188 | "794"
189 | ],
190 | [
191 | "971",
192 | "564"
193 | ],
194 | [
195 | "273",
196 | "807"
197 | ],
198 | [
199 | "991",
200 | "064"
201 | ],
202 | [
203 | "664",
204 | "668"
205 | ],
206 | [
207 | "823",
208 | "584"
209 | ],
210 | [
211 | "656",
212 | "666"
213 | ],
214 | [
215 | "557",
216 | "560"
217 | ],
218 | [
219 | "471",
220 | "455"
221 | ],
222 | [
223 | "042",
224 | "084"
225 | ],
226 | [
227 | "979",
228 | "875"
229 | ],
230 | [
231 | "316",
232 | "369"
233 | ],
234 | [
235 | "091",
236 | "116"
237 | ],
238 | [
239 | "023",
240 | "923"
241 | ],
242 | [
243 | "702",
244 | "612"
245 | ],
246 | [
247 | "904",
248 | "046"
249 | ],
250 | [
251 | "647",
252 | "622"
253 | ],
254 | [
255 | "958",
256 | "956"
257 | ],
258 | [
259 | "606",
260 | "567"
261 | ],
262 | [
263 | "632",
264 | "548"
265 | ],
266 | [
267 | "927",
268 | "912"
269 | ],
270 | [
271 | "350",
272 | "349"
273 | ],
274 | [
275 | "595",
276 | "597"
277 | ],
278 | [
279 | "727",
280 | "729"
281 | ]
282 | ]
--------------------------------------------------------------------------------
/utils/xcp_reg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import os
6 | import sys
7 |
8 |
9 | class SeparableConv2d(nn.Module):
10 | def __init__(self, c_in, c_out, ks, stride=1, padding=0, dilation=1, bias=False):
11 | super(SeparableConv2d, self).__init__()
12 | self.c = nn.Conv2d(c_in, c_in, ks, stride, padding, dilation, groups=c_in, bias=bias)
13 | self.pointwise = nn.Conv2d(c_in, c_out, 1, 1, 0, 1, 1, bias=bias)
14 |
15 | def forward(self, x):
16 | x = self.c(x)
17 | x = self.pointwise(x)
18 | return x
19 |
20 |
21 | class Block(nn.Module):
22 | def __init__(self, c_in, c_out, reps, stride=1, start_with_relu=True, grow_first=True):
23 | super(Block, self).__init__()
24 |
25 | self.skip = None
26 | self.skip_bn = None
27 | if c_out != c_in or stride != 1:
28 | self.skip = nn.Conv2d(c_in, c_out, 1, stride=stride, bias=False)
29 | self.skip_bn = nn.BatchNorm2d(c_out)
30 |
31 | self.relu = nn.ReLU(inplace=True)
32 |
33 | rep = []
34 | c = c_in
35 | if grow_first:
36 | rep.append(self.relu)
37 | rep.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False))
38 | rep.append(nn.BatchNorm2d(c_out))
39 | c = c_out
40 |
41 | for i in range(reps - 1):
42 | rep.append(self.relu)
43 | rep.append(SeparableConv2d(c, c, 3, stride=1, padding=1, bias=False))
44 | rep.append(nn.BatchNorm2d(c))
45 |
46 | if not grow_first:
47 | rep.append(self.relu)
48 | rep.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False))
49 | rep.append(nn.BatchNorm2d(c_out))
50 |
51 | if not start_with_relu:
52 | rep = rep[1:]
53 | else:
54 | rep[0] = nn.ReLU(inplace=False)
55 |
56 | if stride != 1:
57 | rep.append(nn.MaxPool2d(3, stride, 1))
58 | self.rep = nn.Sequential(*rep)
59 |
60 | def forward(self, inp):
61 | x = self.rep(inp)
62 |
63 | if self.skip is not None:
64 | y = self.skip(inp)
65 | y = self.skip_bn(y)
66 | else:
67 | y = inp
68 |
69 | x += y
70 | return x
71 |
72 |
73 | class RegressionMap(nn.Module):
74 | def __init__(self, c_in):
75 | super(RegressionMap, self).__init__()
76 | self.c = SeparableConv2d(c_in, 1, 3, stride=1, padding=1, bias=False)
77 | self.s = nn.Sigmoid()
78 |
79 | def forward(self, x):
80 | mask = self.c(x)
81 | mask = self.s(mask)
82 | return mask, None
83 |
84 |
85 | class TemplateMap(nn.Module):
86 | def __init__(self, c_in, templates):
87 | super(TemplateMap, self).__init__()
88 | self.c = Block(c_in, 364, 2, 2, start_with_relu=True, grow_first=False)
89 | self.l = nn.Linear(364, 10)
90 | self.relu = nn.ReLU(inplace=True)
91 |
92 | self.templates = templates
93 |
94 | def forward(self, x):
95 | v = self.c(x)
96 | v = self.relu(v)
97 | v = F.adaptive_avg_pool2d(v, (1, 1))
98 | v = v.view(v.size(0), -1)
99 | v = self.l(v)
100 | mask = torch.mm(v, self.templates.reshape(10, 361))
101 | mask = mask.reshape(x.shape[0], 1, 19, 19)
102 |
103 | return mask, v
104 |
105 |
106 | class PCATemplateMap(nn.Module):
107 | def __init__(self, templates):
108 | super(PCATemplateMap, self).__init__()
109 | self.templates = templates
110 |
111 | def forward(self, x):
112 | fe = x.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
113 | fe = torch.transpose(fe, 1, 2)
114 | mu = torch.mean(fe, 2, keepdim=True)
115 | fea_diff = fe - mu
116 |
117 | cov_fea = torch.bmm(fea_diff, torch.transpose(fea_diff, 1, 2))
118 | B = self.templates.reshape(1, 10, 361).repeat(x.shape[0], 1, 1)
119 | D = torch.bmm(torch.bmm(B, cov_fea), torch.transpose(B, 1, 2))
120 | eigen_value, eigen_vector = D.symeig(eigenvectors=True)
121 | index = torch.tensor([9]).cuda()
122 | eigen = torch.index_select(eigen_vector, 2, index)
123 |
124 | v = eigen.squeeze(-1)
125 | mask = torch.mm(v, self.templates.reshape(10, 361))
126 | mask = mask.reshape(x.shape[0], 1, 19, 19)
127 | return mask, v
128 |
129 |
130 | class Xception(nn.Module):
131 | """
132 | Xception optimized for the ImageNet dataset, as specified in
133 | https://arxiv.org/pdf/1610.02357.pdf
134 | """
135 |
136 | def __init__(self, maptype, templates, num_classes=1000):
137 | super(Xception, self).__init__()
138 | self.num_classes = num_classes
139 |
140 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
141 | self.bn1 = nn.BatchNorm2d(32)
142 | self.relu = nn.ReLU(inplace=True)
143 |
144 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
145 | self.bn2 = nn.BatchNorm2d(64)
146 |
147 | self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
148 | self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
149 | self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
150 | self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
151 | self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
152 | self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
153 | self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
154 | self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
155 | self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
156 | self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
157 | self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
158 | self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
159 |
160 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
161 | self.bn3 = nn.BatchNorm2d(1536)
162 |
163 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
164 | self.bn4 = nn.BatchNorm2d(2048)
165 |
166 | self.last_linear = nn.Linear(2048, num_classes)
167 |
168 | if maptype == 'none':
169 | self.map = [1, None]
170 | elif maptype == 'reg':
171 | self.map = RegressionMap(728)
172 | elif maptype == 'tmp':
173 | self.map = TemplateMap(728, templates)
174 | elif maptype == 'pca_tmp':
175 | self.map = PCATemplateMap(templates)
176 | else:
177 | print('Unknown map type: `{0}`'.format(maptype))
178 | sys.exit()
179 |
180 | def features(self, input):
181 | x = self.conv1(input)
182 | x = self.bn1(x)
183 | x = self.relu(x)
184 |
185 | x = self.conv2(x)
186 | x = self.bn2(x)
187 | x = self.relu(x)
188 |
189 | x = self.block1(x)
190 | x = self.block2(x)
191 | x = self.block3(x)
192 | x = self.block4(x)
193 | x = self.block5(x)
194 | x = self.block6(x)
195 | x = self.block7(x)
196 | mask, vec = self.map(x)
197 | x = x * mask
198 | x = self.block8(x)
199 | x = self.block9(x)
200 | x = self.block10(x)
201 | x = self.block11(x)
202 | x = self.block12(x)
203 | x = self.conv3(x)
204 | x = self.bn3(x)
205 | x = self.relu(x)
206 |
207 | x = self.conv4(x)
208 | x = self.bn4(x)
209 | return x, mask, vec
210 |
211 | def logits(self, features):
212 | x = self.relu(features)
213 | x = F.adaptive_avg_pool2d(x, (1, 1))
214 | x = x.view(x.size(0), -1)
215 | x = self.last_linear(x)
216 | return x
217 |
218 | def forward(self, input):
219 | x, mask, vec = self.features(input)
220 | x = self.logits(x)
221 | return x, mask, vec
222 |
223 |
224 | def init_weights(m):
225 | classname = m.__class__.__name__
226 | if classname.find('SeparableConv2d') != -1:
227 | m.c.weight.data.normal_(0.0, 0.01)
228 | if m.c.bias is not None:
229 | m.c.bias.data.fill_(0)
230 | m.pointwise.weight.data.normal_(0.0, 0.01)
231 | if m.pointwise.bias is not None:
232 | m.pointwise.bias.data.fill_(0)
233 | elif classname.find('Conv') != -1 or classname.find('Linear') != -1:
234 | m.weight.data.normal_(0.0, 0.01)
235 | if m.bias is not None:
236 | m.bias.data.fill_(0)
237 | elif classname.find('BatchNorm') != -1:
238 | m.weight.data.normal_(1.0, 0.01)
239 | m.bias.data.fill_(0)
240 | elif classname.find('LSTM') != -1:
241 | for name, param in m.named_parameters():
242 | if name.find('weight') != -1:
243 | param.data.normal_(0.0, 0.01)
244 | elif name.find('bias') != -1:
245 | param.data.fill_(0)
246 |
247 |
248 | class Model:
249 | def __init__(self, maptype='None', templates=None, num_classes=2, load_pretrain=True):
250 | model = Xception(maptype, templates, num_classes=num_classes)
251 | if load_pretrain:
252 | state_dict = torch.load('./xception-b5690688.pth')
253 | for name, weights in state_dict.items():
254 | if 'pointwise' in name:
255 | state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
256 | del state_dict['fc.weight']
257 | del state_dict['fc.bias']
258 | model.load_state_dict(state_dict, False)
259 | else:
260 | model.apply(init_weights)
261 | self.model = model
262 |
263 | def save(self, epoch, optim, model_dir):
264 | state = {'net': self.model.state_dict(), 'optim': optim.state_dict()}
265 | torch.save(state, '{0}/{1:06d}.tar'.format(model_dir, epoch))
266 | print('Saved model `{0}`'.format(epoch))
267 |
268 | def load(self, epoch, model_dir):
269 | filename = '{0}{1:06d}.tar'.format(model_dir, epoch)
270 | print('Loading model from {0}'.format(filename))
271 | if os.path.exists(filename):
272 | state = torch.load(filename)
273 | self.model.load_state_dict(state['net'])
274 | else:
275 | print('Failed to load model from {0}'.format(filename))
276 |
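# --- Hedged usage sketch (editorial addition, not part of the original file) ---
# Minimal example of building the regression-map Xception above and running a
# forward pass on a dummy input. load_pretrain=False avoids the local
# './xception-b5690688.pth' dependency; the 299x299 input size follows the
# convention used in utils/train_cpvr.py.
def _xcp_reg_example():
    net = Model(maptype='reg', templates=None, num_classes=2, load_pretrain=False)
    dummy = torch.randn(1, 3, 299, 299)
    logits, mask, vec = net.model(dummy)  # mask is the 1x19x19 attention map; vec is None for 'reg'
    return logits.shape, mask.shape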
--------------------------------------------------------------------------------
/xception/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PWB97/Deepfake-detection/ce56e6c23ee12ea589a9df123604fb1f11e20246/xception/__init__.py
--------------------------------------------------------------------------------
/xception/models.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Author: Andreas Rössler
4 | """
5 | import os, sys
6 | # sys.path.append('../')
7 | import argparse
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from xception.xception import xception
13 | import math
14 | import torchvision
15 |
16 |
17 | def return_pytorch04_xception(pretrained=True):
18 | # Raises warning "src not broadcastable to dst", but that's fine
19 | model = xception(pretrained=False)
20 | if pretrained:
21 | # Load model in torch 0.4+
22 | model.fc = model.last_linear
23 | del model.last_linear
24 | # import pdb; pdb.set_trace()
25 | state_dict = torch.load(os.path.dirname(__file__) + '/xception.pth')
26 | # './trained_model/xception.pth')
27 | for name, weights in state_dict.items():
28 | if 'pointwise' in name:
29 | state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
30 | model.load_state_dict(state_dict)
31 | model.last_linear = model.fc
32 | del model.fc
33 | return model
34 |
35 |
36 | class TransferModel(nn.Module):
37 | """
38 | Simple transfer learning model that takes an imagenet pretrained model with
39 | a fc layer as base model and retrains a new fc layer for num_out_classes
40 | """
41 |
42 | def __init__(self, modelchoice, num_out_classes=2, dropout=0.0):
43 | super(TransferModel, self).__init__()
44 | self.modelchoice = modelchoice
45 | if modelchoice == 'xception':
46 | self.model = return_pytorch04_xception()
47 | # Replace fc
48 | num_ftrs = self.model.last_linear.in_features
49 | if not dropout:
50 | self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
51 | else:
52 | print('Using dropout', dropout)
53 | self.model.last_linear = nn.Sequential(
54 | nn.Dropout(p=dropout),
55 | nn.Linear(num_ftrs, num_out_classes)
56 | )
57 | elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
58 | if modelchoice == 'resnet50':
59 | self.model = torchvision.models.resnet50(pretrained=True)
60 | if modelchoice == 'resnet18':
61 | self.model = torchvision.models.resnet18(pretrained=True)
62 | # Replace fc
63 | num_ftrs = self.model.fc.in_features
64 | if not dropout:
65 | self.model.fc = nn.Linear(num_ftrs, num_out_classes)
66 | else:
67 | self.model.fc = nn.Sequential(
68 | nn.Dropout(p=dropout),
69 | nn.Linear(num_ftrs, num_out_classes)
70 | )
71 | else:
72 | raise Exception('Choose valid model, e.g. resnet50')
73 |
74 | def set_trainable_up_to(self, boolean, layername="Conv2d_4a_3x3"):
75 | """
76 | Freezes all layers below a specific layer. If `boolean` is True, the layers
77 | following `layername` are made trainable; otherwise only the fully connected final layer is.
78 | :param boolean:
79 | :param layername: depends on network, for inception e.g. Conv2d_4a_3x3
80 | :return:
81 | """
82 | # Stage-1: freeze all the layers
83 | if layername is None:
84 | for i, param in self.model.named_parameters():
85 | param.requires_grad = True
86 | return
87 | else:
88 | for i, param in self.model.named_parameters():
89 | param.requires_grad = False
90 | if boolean:
91 | # Make all layers following the layername layer trainable
92 | ct = []
93 | found = False
94 | for name, child in self.model.named_children():
95 | if layername in ct:
96 | found = True
97 | for params in child.parameters():
98 | params.requires_grad = True
99 | ct.append(name)
100 | if not found:
101 | raise Exception('Layer {0} not found, cannot finetune!'.format(
102 | layername))
103 | else:
104 | if self.modelchoice == 'xception':
105 | # Make fc trainable
106 | for param in self.model.last_linear.parameters():
107 | param.requires_grad = True
108 |
109 | else:
110 | # Make fc trainable
111 | for param in self.model.fc.parameters():
112 | param.requires_grad = True
113 |
114 | def forward(self, x):
115 | x = self.model(x)
116 | return x
117 |
118 |
119 | def model_selection(modelname, num_out_classes,
120 | dropout=None):
121 | """
122 | :param modelname:
123 | :return: model, image size, pretraining, input_list
124 | """
125 | if modelname == 'xception':
126 | return TransferModel(modelchoice='xception',
127 | num_out_classes=num_out_classes), 299, \
128 | True, ['image'], None
129 | elif modelname == 'resnet18':
130 | return TransferModel(modelchoice='resnet18', dropout=dropout,
131 | num_out_classes=num_out_classes), \
132 | 224, True, ['image'], None
133 | else:
134 | raise NotImplementedError(modelname)
135 |
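# --- Hedged usage sketch (editorial addition, not part of the original file) ---
# Example of freezing the backbone with set_trainable_up_to above. Using
# 'resnet18' avoids the local xception.pth dependency (the torchvision ImageNet
# weights are downloaded instead); the dropout value is an illustrative assumption.
def _finetune_example():
    model, image_size, *_ = model_selection('resnet18', num_out_classes=2, dropout=0.5)
    model.set_trainable_up_to(False)  # boolean=False: only model.fc stays trainable
    return model, image_size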
136 |
137 | if __name__ == '__main__':
138 | model, image_size, *_ = model_selection('resnet18', num_out_classes=2)
139 | print(model)
140 | model = model.cuda()
141 | from torchsummary import summary
142 |
143 | input_s = (3, image_size, image_size)
144 | print(summary(model, input_s))
145 |
--------------------------------------------------------------------------------
/xception/xception.py:
--------------------------------------------------------------------------------
1 | """
2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
3 |
4 | @author: tstandley
5 | Adapted by cadene
6 |
7 | Creates an Xception Model as defined in:
8 |
9 | Francois Chollet
10 | Xception: Deep Learning with Depthwise Separable Convolutions
11 | https://arxiv.org/pdf/1610.02357.pdf
12 |
13 | These weights were ported from the Keras implementation. They achieve the following performance on the ImageNet validation set:
14 |
15 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292
16 |
17 | REMEMBER to set your image size to 3x299x299 for both test and validation
18 |
19 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
20 | std=[0.5, 0.5, 0.5])
21 |
22 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
23 | """
24 | import math
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 | import torch.utils.model_zoo as model_zoo
29 | from torch.nn import init
30 |
31 | pretrained_settings = {
32 | 'xception': {
33 | 'imagenet': {
34 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
35 | 'input_space': 'RGB',
36 | 'input_size': [3, 299, 299],
37 | 'input_range': [0, 1],
38 | 'mean': [0.5, 0.5, 0.5],
39 | 'std': [0.5, 0.5, 0.5],
40 | 'num_classes': 1000,
41 | 'scale': 0.8975
42 | # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
43 | }
44 | }
45 | }
46 |
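# --- Hedged usage sketch (editorial addition, not part of the original file) ---
# Validation preprocessing implied by the docstring and pretrained_settings above:
# resize to 333, center crop to 299x299, normalize with mean=std=0.5. torchvision
# is assumed to be available, as elsewhere in this repository.
def _build_val_transform():
    from torchvision import transforms
    return transforms.Compose([
        transforms.Resize(333),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])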
47 |
48 | class SeparableConv2d(nn.Module):
49 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
50 | super(SeparableConv2d, self).__init__()
51 |
52 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels,
53 | bias=bias)
54 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
55 |
56 | def forward(self, x):
57 | x = self.conv1(x)
58 | x = self.pointwise(x)
59 | return x
60 |
61 |
62 | class Block(nn.Module):
63 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
64 | super(Block, self).__init__()
65 |
66 | if out_filters != in_filters or strides != 1:
67 | self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
68 | self.skipbn = nn.BatchNorm2d(out_filters)
69 | else:
70 | self.skip = None
71 |
72 | self.relu = nn.ReLU(inplace=True)
73 | rep = []
74 |
75 | filters = in_filters
76 | if grow_first:
77 | rep.append(self.relu)
78 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
79 | rep.append(nn.BatchNorm2d(out_filters))
80 | filters = out_filters
81 |
82 | for i in range(reps - 1):
83 | rep.append(self.relu)
84 | rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
85 | rep.append(nn.BatchNorm2d(filters))
86 |
87 | if not grow_first:
88 | rep.append(self.relu)
89 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
90 | rep.append(nn.BatchNorm2d(out_filters))
91 |
92 | if not start_with_relu:
93 | rep = rep[1:]
94 | else:
95 | rep[0] = nn.ReLU(inplace=False)
96 |
97 | if strides != 1:
98 | rep.append(nn.MaxPool2d(3, strides, 1))
99 | self.rep = nn.Sequential(*rep)
100 |
101 | def forward(self, inp):
102 | x = self.rep(inp)
103 |
104 | if self.skip is not None:
105 | skip = self.skip(inp)
106 | skip = self.skipbn(skip)
107 | else:
108 | skip = inp
109 |
110 | x += skip
111 | return x
112 |
113 |
114 | class Xception(nn.Module):
115 | """
116 | Xception optimized for the ImageNet dataset, as specified in
117 | https://arxiv.org/pdf/1610.02357.pdf
118 | """
119 |
120 | def __init__(self, num_classes=1000):
121 | """ Constructor
122 | Args:
123 | num_classes: number of classes
124 | """
125 | super(Xception, self).__init__()
126 | self.num_classes = num_classes
127 |
128 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
129 | self.bn1 = nn.BatchNorm2d(32)
130 | self.relu = nn.ReLU(inplace=True)
131 |
132 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
133 | self.bn2 = nn.BatchNorm2d(64)
134 | # do relu here
135 |
136 | self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
137 | self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
138 | self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
139 |
140 | self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
141 | self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
142 | self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
143 | self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
144 |
145 | self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
146 | self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
147 | self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
148 | self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
149 |
150 | self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
151 |
152 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
153 | self.bn3 = nn.BatchNorm2d(1536)
154 |
155 | # do relu here
156 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
157 | self.bn4 = nn.BatchNorm2d(2048)
158 |
159 | self.fc = nn.Linear(2048, num_classes)
160 |
161 | # #------- init weights --------
162 | # for m in self.modules():
163 | # if isinstance(m, nn.Conv2d):
164 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
165 | # m.weight.data.normal_(0, math.sqrt(2. / n))
166 | # elif isinstance(m, nn.BatchNorm2d):
167 | # m.weight.data.fill_(1)
168 | # m.bias.data.zero_()
169 | # #-----------------------------
170 |
171 | def features(self, input):
172 | x = self.conv1(input)
173 | x = self.bn1(x)
174 | x = self.relu(x)
175 |
176 | x = self.conv2(x)
177 | x = self.bn2(x)
178 | x = self.relu(x)
179 |
180 | x = self.block1(x)
181 | x = self.block2(x)
182 | x = self.block3(x)
183 | x = self.block4(x)
184 | x = self.block5(x)
185 | x = self.block6(x)
186 | x = self.block7(x)
187 | x = self.block8(x)
188 | x = self.block9(x)
189 | x = self.block10(x)
190 | x = self.block11(x)
191 | x = self.block12(x)
192 |
193 | x = self.conv3(x)
194 | x = self.bn3(x)
195 | x = self.relu(x)
196 |
197 | x = self.conv4(x)
198 | x = self.bn4(x)
199 | return x
200 |
201 | def logits(self, features):
202 | x = self.relu(features)
203 |
204 | x = F.adaptive_avg_pool2d(x, (1, 1))
205 | x = x.view(x.size(0), -1)
206 | x = self.last_linear(x)
207 | return x
208 |
209 | def forward(self, input):
210 | x = self.features(input)
211 | # eric
212 | # print('1', x.shape)
213 | x = self.logits(x)
214 | # eric
215 | # print('2', x.shape)
216 | return x
217 |
218 |
219 | def xception(num_classes=1000, pretrained='imagenet'):
220 | model = Xception(num_classes=num_classes)
221 | if pretrained:
222 | settings = pretrained_settings['xception'][pretrained]
223 | assert num_classes == settings['num_classes'], \
224 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
225 |
226 | model = Xception(num_classes=num_classes)
227 | model.load_state_dict(model_zoo.load_url(settings['url']))
228 |
229 | model.input_space = settings['input_space']
230 | model.input_size = settings['input_size']
231 | model.input_range = settings['input_range']
232 | model.mean = settings['mean']
233 | model.std = settings['std']
234 |
235 | # TODO: ugly
236 | model.last_linear = model.fc
237 | del model.fc
238 | return model
239 |
--------------------------------------------------------------------------------