├── M2TR_model.py
├── README.md
├── data_process.py
├── model.png
├── my_dataset.py
├── test.py
├── train.py
└── utils.py
/M2TR_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from efficientnet_pytorch import EfficientNet
5 | import torch.nn.functional as F
6 | 
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 | 
9 | class Extra_fq(nn.Module):
10 |     """
11 |     img_size : the size of the training image
12 |     out_c : the number of channels of the feature map extracted by this block (it must match the spatial-domain feature map)
13 |     """
14 |     def __init__(self, img_size, out_c):
15 |         super(Extra_fq, self).__init__()
16 |         self.img_size = img_size
17 |         #self.batch_size = batch_size
18 | 
19 |         def get_DCTf(img_size):
20 |             timg = torch.zeros((img_size, img_size)).to(device)
21 | 
22 |             M, N = torch.from_numpy(np.array(img_size)), torch.from_numpy(np.array(img_size))
23 |             timg[0, :] = 1 * torch.sqrt(1 / N)
24 |             for i in range(1, M):
25 |                 for j in range(N):
26 |                     timg[i, j] = torch.cos(np.pi * i * (2 * j + 1) / (2 * N)) * torch.sqrt(2 / N)
27 |             return timg
28 | 
29 | 
30 | 
31 |         def create_f(img_size):
32 |             resolution = (img_size, img_size)
33 |             low_f = torch.zeros(resolution).to(device)
34 |             high_f = torch.ones(resolution).to(device)
35 |             mid_f = torch.ones(resolution).to(device)
36 |             resolution = np.array(resolution)
37 |             t_1 = resolution // 16
38 |             t_2 = resolution // 8
39 |             for i in range(t_1[0]):
40 |                 for j in range(t_1[1] - i):
41 |                     low_f[i, j] = 1
42 |             for i in range(t_2[0]):
43 |                 for j in range(t_2[1]-i):
44 |                     high_f[i, j] = 0
45 |             mid_f = mid_f - low_f
46 |             mid_f = mid_f - high_f
47 |             return low_f, mid_f, high_f
48 | 
49 |         self.dct_filter = get_DCTf(img_size=img_size)
50 | 
51 |         self.low_f, self.mid_f, self.high_f = create_f(img_size=img_size)
52 | 
53 | 
54 |         self.block = nn.Sequential(
55 |             nn.Conv2d(3, int(out_c / 2), kernel_size=3, stride=1, padding=1),
56 |             nn.BatchNorm2d(int(out_c / 2)),
57 |             nn.ReLU(inplace=True),
58 |             nn.MaxPool2d(kernel_size=2, stride=2),
59 | 
60 |             nn.Conv2d(int(out_c / 2), out_c, kernel_size=3, stride=1, padding=1),
61 |             nn.BatchNorm2d(out_c),
62 |             nn.ReLU(inplace=True),
63 |             nn.MaxPool2d(kernel_size=2, stride=2)
64 | 
65 | 
66 | 
67 |         )
68 | 
69 |     def DCT(self, img):
70 |         # 2-D DCT via the transform matrix: T @ X @ T^T
71 |         dst = self.dct_filter @ img
72 |         dst = dst @ self.dct_filter.permute(1, 0)
73 |         return dst
74 | 
75 |     def IDCT(self, img):
76 |         # inverse 2-D DCT: T^T @ X @ T
77 |         dst = self.dct_filter.permute(1, 0) @ img
78 |         dst = dst @ self.dct_filter
79 |         return dst
80 | 
81 |     def forward(self, img):
82 |         r, g, b = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :]
83 |         dct_1, dct_2, dct_3 = self.DCT(r), self.DCT(g), self.DCT(b)
84 | 
85 |         fl = [self.low_f, self.mid_f, self.high_f]
86 |         re = []
87 |         for i in range(3):
88 |             t_1 = dct_1 * fl[i]
89 |             t_2 = dct_2 * fl[i]
90 |             t_3 = dct_3 * fl[i]
91 |             re.append(self.IDCT(t_1 + t_2 + t_3))
92 |         out = torch.cat((re[0].unsqueeze(1), re[1].unsqueeze(1), re[2].unsqueeze(1)), dim=1)
93 | 
94 |         out = self.block(out)
95 |         return out
96 | 
97 | 
98 | class PatchEmbed(nn.Module):
99 |     """
100 |     img_size : the size of the feature map extracted by the spatial-domain branch
101 |     patch_size : the size of the patch used for embedding
102 |     in_c : the number of channels of the feature map extracted by the spatial-domain branch
103 |     embed_dim : the embedding dimension
104 |     """
105 |     def __init__(self, img_size=56, patch_size=56, in_c=32, embed_dim=768, norm_layer=None):
106 |         super().__init__()
107 |         img_size = (img_size, img_size)
108 |
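        # PatchEmbed turns the [B, in_c, H, W] feature map into a [B, num_patches, embed_dim] token sequence:
        # the projection conv below uses kernel_size = stride = patch_size, so each output position corresponds to
        # one non-overlapping patch, and forward() then flattens the spatial grid into the sequence dimension.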
patch_size = (patch_size, patch_size) 109 | self.img_size = img_size 110 | self.patch_size = patch_size 111 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 112 | self.num_patches = self.grid_size[0] * self.grid_size[1] 113 | 114 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 115 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 116 | 117 | def forward(self, x): 118 | b, c, h, w = x.shape 119 | x = self.proj(x).flatten(2).transpose(1, 2) 120 | x = self.norm(x) 121 | return x # [b, p_num, embed_dim] 122 | 123 | 124 | class ATT(nn.Module): 125 | """ 126 | in_c : the channel of the feature map extracted by spatial domain 127 | patch_size : the size of the patch using on embeding 128 | dim : the embeding dimension 129 | attn_drop_ratio : the dropout rate of the attention 130 | proj_drop_ratio : the dropout rate of the projection 131 | """ 132 | def __init__(self, 133 | 134 | in_c, 135 | patch_size, 136 | dim, 137 | num_heads=8, 138 | qkv_bias=False, 139 | attn_drop_ratio=0., 140 | proj_drop_ratio=0.): 141 | super(ATT, self).__init__() 142 | self.num_heads = num_heads 143 | self.patch_size = patch_size 144 | # self.in_c = in_c 145 | #head_dim = dim // num_heads 146 | self.scale = (patch_size * patch_size * in_c) ** -0.5 147 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 148 | self.attn_drop = nn.Dropout(attn_drop_ratio) 149 | self.proj = nn.Linear(dim, patch_size * patch_size) 150 | self.proj_drop = nn.Dropout(proj_drop_ratio) 151 | 152 | def forward(self, x): 153 | B, N, C = x.shape 154 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 155 | q, k, v = qkv[0], qkv[1], qkv[2] 156 | 157 | attn = (q @ k.transpose(-2, -1)) * self.scale 158 | attn = attn.softmax(dim=-1) 159 | attn = self.attn_drop(attn) 160 | 161 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 162 | x = self.proj(x) 163 | x = self.proj_drop(x) 164 | x = x.reshape((B, N, self.patch_size, self.patch_size)) 165 | return x 166 | 167 | 168 | class MSMHTR(nn.Module): 169 | """ 170 | img_size : the size of feature map extracted by spatial domain 171 | in_c : the channel of feature map extracted by spatial domain 172 | n_dim : the embed dimension of the patch embed 173 | drop_ratio : the dropout rate 174 | """ 175 | def __init__(self, 176 | img_size, 177 | in_c, 178 | n_dim, 179 | drop_ratio): 180 | super(MSMHTR, self).__init__() 181 | self.patch_size = img_size 182 | self.in_c = in_c 183 | 184 | num_patches = int((img_size // self.patch_size) ** 2) 185 | # self.cls1 = nn.Parameter(torch.zeros(1, 1, n_dim)) 186 | self.pos1 = nn.Parameter(torch.zeros(1, num_patches, n_dim)) 187 | self.eb1 = PatchEmbed(patch_size=self.patch_size, embed_dim=n_dim, in_c=self.in_c) 188 | self.att1 = ATT(in_c=in_c, dim=n_dim, patch_size=self.patch_size) 189 | self.patch_size = int(self.patch_size / 2) 190 | 191 | num_patches = int((img_size // self.patch_size) ** 2) 192 | # self.cls2 = nn.Parameter(torch.zeros(1, 1, n_dim)) 193 | self.pos2 = nn.Parameter(torch.zeros(1, num_patches, n_dim)) 194 | self.eb2 = PatchEmbed(patch_size=self.patch_size, embed_dim=n_dim, in_c=self.in_c) 195 | self.att2 = ATT(in_c=in_c, dim=n_dim, patch_size=self.patch_size) 196 | self.patch_size = int(self.patch_size / 2) 197 | 198 | num_patches = int((img_size // self.patch_size) ** 2) 199 | # self.cls3 = nn.Parameter(torch.zeros(1, 1, n_dim)) 200 | self.pos3 = nn.Parameter(torch.zeros(1, num_patches, n_dim)) 201 | self.eb3 = 
PatchEmbed(patch_size=self.patch_size, embed_dim=n_dim, in_c=self.in_c) 202 | self.att3 = ATT(in_c=in_c, dim=n_dim, patch_size=self.patch_size) 203 | self.patch_size = int(self.patch_size / 2) 204 | 205 | num_patches = int((img_size // self.patch_size) ** 2) 206 | # self.cls4 = nn.Parameter(torch.zeros(1, 1, n_dim)) 207 | self.pos4 = nn.Parameter(torch.zeros(1, num_patches, n_dim)) 208 | self.eb4 = PatchEmbed(patch_size=self.patch_size, embed_dim=n_dim, in_c=self.in_c) 209 | self.att4 = ATT(in_c=in_c, dim=n_dim, patch_size=self.patch_size) 210 | 211 | self.pos_drop = nn.Dropout(p=drop_ratio) 212 | 213 | def forward(self, x): 214 | eb1 = self.eb1(x) 215 | input1 = self.pos_drop(eb1 + self.pos1) 216 | att1 = self.att1(input1) 217 | 218 | eb2 = self.eb2(x) 219 | input2 = self.pos_drop(eb2 + self.pos2) 220 | att2 = self.att2(input2) 221 | att2 = att2.reshape(att1.shape) 222 | 223 | eb3 = self.eb3(x) 224 | input3 = self.pos_drop(eb3 + self.pos3) 225 | att3 = self.att3(input3) 226 | att3 = att3.reshape(att1.shape) 227 | 228 | eb4 = self.eb4(x) 229 | input4 = self.pos_drop(eb4 + self.pos4) 230 | att4 = self.att4(input4) 231 | att4 = att4.reshape(att1.shape) 232 | 233 | return att1 + att2 + att3 + att4 234 | 235 | 236 | class CMF(nn.Module): 237 | """ 238 | in_c : the channel of the feature map stacked by spatial domain, frequency domain and MSMHTR 239 | img_size : the size of the feature map stacked by spatial domain, frequency domain and MSMHTR 240 | """ 241 | 242 | def __init__(self, 243 | in_c, 244 | img_size=56): 245 | super(CMF, self).__init__() 246 | self.convq = nn.Conv2d(in_c, in_c, kernel_size=1, bias=False) 247 | self.convk = nn.Conv2d(in_c, in_c, kernel_size=1, bias=False) 248 | self.convv = nn.Conv2d(in_c, in_c, kernel_size=1, bias=False) 249 | self.scale = (img_size * img_size * in_c) ** -0.5 250 | 251 | self.conv1 = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=1) 252 | 253 | def forward(self, x_s, x_fq, x_mt): 254 | q = self.convq(x_s) 255 | k = self.convk(x_fq) 256 | v = self.convv(x_fq) 257 | fuse = (q @ k.transpose(-2, -1)) * self.scale 258 | fuse = fuse.softmax(dim=-1) 259 | fuse = fuse @ v 260 | 261 | f_cmf = self.conv1(x_s + x_mt + fuse) 262 | 263 | return f_cmf 264 | 265 | 266 | class M2TR(nn.Module): 267 | """ 268 | img_size : the size of your train img 269 | n_dim : the dimension of patch embeding 270 | drop_ratio : the dropout rate 271 | """ 272 | def __init__(self, img_size, n_dim, drop_ratio): 273 | super(M2TR, self).__init__() 274 | self.model = EfficientNet.from_name('efficientnet-b4') 275 | state_dict = torch.load(r'C:\Users\satomi ishihara\za\desktop\fakeface\efficientnet-b4.pth') 276 | self.model.load_state_dict(state_dict) 277 | self.backbone1 = nn.Sequential( 278 | nn.PixelShuffle(2), 279 | nn.PixelShuffle(2), 280 | nn.PixelShuffle(2) 281 | ) 282 | 283 | self.ex_fq = Extra_fq(img_size=img_size, out_c=28) 284 | self.mt = MSMHTR(img_size=int(img_size / 4), in_c=28, n_dim=n_dim, drop_ratio=drop_ratio) 285 | 286 | self.cmf = CMF(in_c=28, img_size=img_size / 4) 287 | 288 | self.backbone2 = nn.Sequential( 289 | nn.Conv2d(28, 32, kernel_size=3, stride=1, padding=1), 290 | nn.BatchNorm2d(32), 291 | nn.ReLU(inplace=True), 292 | nn.MaxPool2d(kernel_size=2, stride=2), 293 | 294 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 295 | nn.BatchNorm2d(64), 296 | nn.ReLU(inplace=True), 297 | nn.MaxPool2d(kernel_size=2, stride=2), 298 | 299 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 300 | nn.BatchNorm2d(128), 301 | nn.ReLU(inplace=True), 302 | 
            nn.MaxPool2d(kernel_size=2, stride=2),
303 | 
304 |         )
305 |         self.fc = nn.Sequential(
306 |             nn.Linear(int(((img_size / 32) ** 2) * 128), 512),
307 |             nn.Linear(512, 2),
308 |             nn.ReLU(inplace=True)
309 |         )
310 | 
311 |     def feature_forward(self, x):
312 |         return self.backbone1(self.model.extract_features(x))
313 | 
314 |     def forward(self, x):
315 |         #x_ = self.model.extract_features(x)
316 |         f_s = self.feature_forward(x)
317 |         f_fq = self.ex_fq(x)
318 |         f_mt = self.mt(f_s)
319 |         f_cmf = self.cmf(f_s, f_fq, f_mt)
320 | 
321 |         out = self.backbone2(f_cmf)
322 |         out = torch.flatten(out, 1)
323 |         out = self.fc(out)
324 |         return out.softmax(dim=-1) # , f_s
325 | 
326 | class FocalLoss(nn.Module):
327 |     def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
328 |         super(FocalLoss, self).__init__()
329 |         self.alpha = alpha
330 |         self.gamma = gamma
331 |         self.logits = logits
332 |         self.reduce = reduce
333 | 
334 |     def forward(self, inputs, targets):
335 |         if self.logits:
336 |             BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
337 |         else:
338 |             BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
339 |         pt = torch.exp(-BCE_loss)
340 |         F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
341 | 
342 |         if self.reduce:
343 |             return torch.mean(F_loss)
344 |         else:
345 |             return F_loss
346 | 
347 | def create_FocalLoss(alpha, gamma, logits=False, reduce=True):
348 |     return FocalLoss(alpha=alpha, gamma=gamma, logits=logits, reduce=reduce)
349 | 
350 | 
351 | def create_model(img_size=224, n_dim=768, drop_ratio=0.1):
352 |     # modify here to change your img_size, n_dim (embedding dimension), and drop_ratio
353 |     return M2TR(img_size=img_size, n_dim=n_dim, drop_ratio=drop_ratio)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # M2TR: Multi-modal Multi-scale Transformers for Deepfake Detection
2 | 
3 | https://arxiv.org/pdf/2104.09770.pdf
4 | 
5 | This is my first attempt at re-implementing the model from the paper. If anything is wrong or unclear, please let me know.
6 | 
7 | email: 729946969@qq.com
8 | 
9 | zhihu homepage: https://www.zhihu.com/people/ishihara-32
10 | 
11 | microblog homepage: https://weibo.com/u/6322632992
12 | 
13 | ![1](model.png)
14 | 
15 | I don't have the paper authors' dataset, so I haven't implemented the Decoder part.
16 | 
17 | ## Requirements
18 | 
19 | Please run this command to install the requirements:
20 | 
21 | ```shell
22 | pip install -r requirement.txt
23 | ```
24 | 
25 | ## Using your own dataset
26 | 
27 | If you want to use your own dataset, modify this argument in train.py to point to your own data path:
28 | 
29 | ```python
30 | parser.add_argument('--data-path', type=str,
31 |                     default=r"C:\Users\satomi ishihara\za\desktop\fakeface\train_fake")
32 | ```
33 | 
34 | The default validation rate is 0.2, which means the training data is split 4:1 into training and validation sets. If you want to change this rate, modify it here in utils.py:
35 | 
36 | ```python
37 | def read_split_data(root: str, val_rate: float = 0.2)
38 | ```
39 | 
40 | You can change any parameter in train.py, such as the number of epochs, the learning rate, and so on. Once everything is ready, run train.py to start training your own model.
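
## Testing the model

After training, test.py runs inference and writes the per-image predictions to a csv file. As a quick reference, here is a minimal single-image inference sketch; the checkpoint path `./weights/model-9.pth` and the image `example.jpg` are only placeholders, and the 224x224 input size and normalization follow the validation transform in train.py:

```python
import torch
from PIL import Image
from torchvision import transforms
from M2TR_model import create_model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = create_model().to(device)  # img_size=224, n_dim=768, drop_ratio=0.1 by default
model.load_state_dict(torch.load("./weights/model-9.pth", map_location=device))  # placeholder checkpoint
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

img = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)  # placeholder image
with torch.no_grad():
    prob = model(img)                 # forward() already applies softmax over the two classes
    pred = prob.argmax(dim=1).item()  # class index follows the 0/1 folder names of the dataset
print(prob, pred)
```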
41 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import pandas as pd 5 | 6 | img_path = 'fackface_detect_1/深度伪造人脸检测数据集/image/train' 7 | label_path = pd.read_csv('fackface_detect_1/深度伪造人脸检测数据集/train.labels.csv') 8 | target_train_path = 'fakeface/train' 9 | target_test_path = 'fakeface/test' 10 | if not os.path.exists(target_train_path + '/0'): 11 | os.makedirs(target_train_path + '/0') 12 | if not os.path.exists(target_train_path + '/1'): 13 | os.makedirs(target_train_path + '/1') 14 | if not os.path.exists(target_test_path + '/0'): 15 | os.makedirs(target_test_path + '/0') 16 | if not os.path.exists(target_test_path + '/1'): 17 | os.makedirs(target_test_path + '/1') 18 | label_path = np.array(label_path) 19 | true_num, false_num = 0, 0 20 | for item in label_path: 21 | img_name, label = item[0][:-2], item[0][-1] 22 | if label == '0': 23 | true_num += 1 24 | else: 25 | false_num += 1 26 | test_rate = 0.1 27 | true_num_train = 0 28 | false_num_train = 0 29 | for item in label_path: 30 | img_name, label = item[0][:-2], item[0][-1] 31 | if label == '0': 32 | if true_num_train <= int(true_num * (1 - test_rate)): 33 | true_num_train += 1 34 | shutil.copy(img_path + '/' + img_name, target_train_path + '/0/' + img_name) 35 | else: 36 | shutil.copy(img_path + '/' + img_name, target_test_path + '/0/' + img_name) 37 | else: 38 | if false_num_train <= int(false_num * (1 - test_rate)): 39 | false_num_train += 1 40 | shutil.copy(img_path + '/' + img_name, target_train_path + '/1/' + img_name) 41 | else: 42 | shutil.copy(img_path + '/' + img_name, target_test_path + '/1/' + img_name) 43 | print(img_name) 44 | 45 | 46 | -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satomiishihara/M2TR-pytorch/89071eb10d9cbfe3810f8943a7abd78b8b76b0ec/model.png -------------------------------------------------------------------------------- /my_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class MyDataSet(Dataset): 7 | 8 | 9 | def __init__(self, images_path: list, images_class: list, transform=None): 10 | self.images_path = images_path 11 | self.images_class = images_class 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.images_path) 16 | 17 | def __getitem__(self, item): 18 | img = Image.open(self.images_path[item]) 19 | 20 | if img.mode != 'RGB': 21 | raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) 22 | label = self.images_class[item] 23 | 24 | if self.transform is not None: 25 | img = self.transform(img) 26 | 27 | return img, label 28 | 29 | @staticmethod 30 | def collate_fn(batch): 31 | 32 | images, labels = tuple(zip(*batch)) 33 | 34 | images = torch.stack(images, dim=0) 35 | labels = torch.as_tensor(labels) 36 | return images, labels 37 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | from M2TR_model import create_model 3 | 4 | import torch 5 | import os 6 | from PIL import Image 7 | from torchvision.transforms import ToTensor 8 | 
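# The three entry points below (only test1 is called at the bottom of this file):
#   test1()     - runs the model over every image in `datapath` and collects "<filename>\t<predicted class>" rows for a csv file
#   test2(root) - measures per-batch accuracy on an ImageFolder-style directory with a plain DataLoader
#   test3(root) - measures accuracy through MyDataSet + utils.evaluate, mirroring the validation path of train.py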
import re 9 | 10 | import pandas as pd 11 | import torchvision.datasets as dset 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | import tqdm 15 | from my_dataset import MyDataSet 16 | from utils import evaluate 17 | 18 | 19 | to_tensor = ToTensor() 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | loadpath = 'model-7.pth' #modify here to change your model data 22 | datapath = r'C:\Users\satomi ishihara\za\desktop\fakeface\深度伪造人脸检测数据集\image\test' 23 | testpath = r'C:\Users\satomi ishihara\za\desktop\fakeface\test_fake' 24 | savecsvpath = 'test.csv' 25 | Resize = transforms.Resize(size=(224, 224)) 26 | 27 | def test1(): 28 | model = create_model().to(device) 29 | model.load_state_dict(torch.load(loadpath)) 30 | model.eval() 31 | testlist = os.listdir(datapath) 32 | testlist.sort(key = lambda x:int(re.match('\D+(\d+)\.jpg',x).group(1))) 33 | csv_list = [] 34 | #print(testlist) 35 | for i in range(len(testlist)): 36 | testimg = Image.open(datapath +'/' + testlist[i]) 37 | testimg = to_tensor(testimg) 38 | testimg = testimg.unsqueeze(0) 39 | testimg = Resize(testimg) 40 | pred = model(testimg.to(device)) 41 | pred_classes = torch.max(pred, dim=1)[1] 42 | 43 | t = testlist[i] + '\t%d' % pred_classes 44 | print(t) 45 | csv_list.append(t) 46 | 47 | return csv_list 48 | 49 | 50 | def test2(root): 51 | model = create_model().to(device) 52 | model.load_state_dict(torch.load(loadpath)) 53 | model.eval() 54 | 55 | dataset = dset.ImageFolder(root, 56 | transform=transforms.Compose([transforms.Resize((224, 224)), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.5, 0.5, 0.5), 59 | (0.5, 0.5, 0.5)), 60 | ])) 61 | dataloader = DataLoader(dataset, 62 | shuffle=False, 63 | batch_size=8, 64 | num_workers=0) 65 | 66 | 67 | accu_num = torch.zeros(1).to(device) 68 | sample_num = 0 69 | #dataloader = tqdm(dataloader) 70 | 71 | for idx, (img, label) in enumerate(dataloader): 72 | img = img.to(device) 73 | label = label.to(device) 74 | sample_num += img.shape[0] 75 | pred = model(img) 76 | pred_classes = torch.max(pred, dim=1)[1] 77 | accu = torch.eq(pred_classes, label).sum()/img.shape[0] 78 | accu_num += accu 79 | print('step:%d, accu:%f'%(idx, accu)) 80 | #dataloader.desc = "acc: {:.3f}".format(accu_num.item() / sample_num) 81 | 82 | 83 | print(accu_num) 84 | 85 | def test3(root): 86 | model = create_model().to(device) 87 | model.load_state_dict(torch.load(loadpath)) 88 | model.eval() 89 | classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] 90 | 91 | classes.sort() 92 | 93 | class_indices = dict((k, v) for v, k in enumerate(classes)) 94 | val_images_path = [] 95 | val_images_label = [] 96 | supported = [".jpg", ".JPG", ".png", ".PNG"] 97 | for cla in classes: 98 | cla_path = os.path.join(root, cla) 99 | 100 | images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) 101 | if os.path.splitext(i)[-1] in supported] 102 | 103 | image_class = class_indices[cla] 104 | for img_path in images: 105 | val_images_path.append(img_path) 106 | val_images_label.append(image_class) 107 | 108 | val_dataset = MyDataSet(images_path=val_images_path, 109 | images_class=val_images_label, 110 | transform=transforms.Compose([transforms.Resize(224), 111 | #transforms.CenterCrop(224), 112 | transforms.ToTensor(), 113 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])) 114 | val_loader = torch.utils.data.DataLoader(val_dataset, 115 | batch_size=8, 116 | shuffle=False, 117 | pin_memory=True, 118 | 
num_workers=0, 119 | collate_fn=val_dataset.collate_fn) 120 | val_loss, val_acc = evaluate(model=model, 121 | data_loader=val_loader, 122 | device=device, 123 | epoch=1) 124 | 125 | 126 | 127 | csvlist = test1() 128 | csvlist = pd.DataFrame(data=csvlist) 129 | csvlist.to_csv('submion.csv', encoding='gbk',index=False,header=None) 130 | #test3(root=testpath) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torchvision import transforms 10 | 11 | 12 | from my_dataset import MyDataSet 13 | from M2TR_model import create_model 14 | from utils import read_split_data, train_one_epoch, evaluate 15 | 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | def main(args): 19 | 20 | 21 | if os.path.exists("./weights") is False: 22 | os.makedirs("./weights") 23 | 24 | tb_writer = SummaryWriter() 25 | 26 | train_images_path, train_images_label, _, _ = read_split_data(root='fakeface/train') 27 | val_images_path, val_images_label, _, _ = read_split_data(root='fakeface/test') 28 | 29 | data_transform = { 30 | "train": transforms.Compose([transforms.RandomResizedCrop(224), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 34 | "val": transforms.Compose([transforms.Resize(224), 35 | #transforms.CenterCrop(224), 36 | transforms.ToTensor(), 37 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])} 38 | 39 | 40 | train_dataset = MyDataSet(images_path=train_images_path, 41 | images_class=train_images_label, 42 | transform=data_transform["train"]) 43 | 44 | 45 | val_dataset = MyDataSet(images_path=val_images_path, 46 | images_class=val_images_label, 47 | transform=data_transform["val"]) 48 | 49 | batch_size = args.batch_size 50 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 51 | print('Using {} dataloader workers every process'.format(nw)) 52 | train_loader = torch.utils.data.DataLoader(train_dataset, 53 | batch_size=batch_size, 54 | shuffle=True, 55 | pin_memory=True, 56 | num_workers=nw, 57 | collate_fn=train_dataset.collate_fn) 58 | 59 | val_loader = torch.utils.data.DataLoader(val_dataset, 60 | batch_size=batch_size, 61 | shuffle=False, 62 | pin_memory=True, 63 | num_workers=nw, 64 | collate_fn=val_dataset.collate_fn) 65 | 66 | 67 | model = create_model().to(device) 68 | #loss = create_FocalLoss(alpha=0.75, gamma=2).to(device) 69 | 70 | lr = 0.0005 71 | pg = [p for p in model.parameters() if p.requires_grad] 72 | optimizer = optim.Adam(pg, lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=5E-5) 73 | #optimizer = optim.SGD(pg, lr=lr, momentum=0.9, weight_decay=5E-5) 74 | # Scheduler https://arxiv.org/pdf/1812.01187.pdf 75 | lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine 76 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) 77 | 78 | for epoch in range(args.epochs): 79 | # train 80 | train_loss, train_acc, lr = train_one_epoch(model=model, 81 | optimizer=optimizer, 82 | data_loader=train_loader, 83 | device=device, 84 | epoch=epoch, 85 | lr=lr) 86 | 87 | scheduler.step() 88 | 89 | # validate 90 | val_loss, val_acc = evaluate(model=model, 91 | 
data_loader=val_loader, 92 | device=device, 93 | epoch=epoch) 94 | 95 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 96 | tb_writer.add_scalar(tags[0], train_loss, epoch) 97 | tb_writer.add_scalar(tags[1], train_acc, epoch) 98 | tb_writer.add_scalar(tags[2], val_loss, epoch) 99 | tb_writer.add_scalar(tags[3], val_acc, epoch) 100 | tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 101 | 102 | torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch)) 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument('--num_classes', type=int, default=2) 108 | parser.add_argument('--epochs', type=int, default=10) 109 | parser.add_argument('--batch-size', type=int, default=8) 110 | #parser.add_argument('--lr', type=float, default=0.005) 111 | parser.add_argument('--lrf', type=float, default=0.01) 112 | 113 | 114 | 115 | # parser.add_argument('--data-path', type=str, 116 | # default=r"fakeface/train") 117 | # parser.add_argument('--data-test-path', type=str, 118 | # default=r"fakeface/test") 119 | parser.add_argument('--model-name', default='', help='create model name') 120 | 121 | 122 | 123 | 124 | 125 | opt = parser.parse_args() 126 | 127 | main(opt) 128 | 129 | 130 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | 5 | import random 6 | #from M2TR_model import create_FocalLoss 7 | import torch 8 | from tqdm import tqdm 9 | 10 | 11 | 12 | 13 | def read_split_data(root: str, val_rate: float = 0.2): #modify the val_rate to change the rate of validation 14 | random.seed(0) 15 | assert os.path.exists(root), "dataset root: {} does not exist.".format(root) 16 | 17 | 18 | classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] 19 | 20 | classes.sort() 21 | 22 | class_indices = dict((k, v) for v, k in enumerate(classes)) 23 | json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) 24 | with open('class_indices.json', 'w') as json_file: 25 | json_file.write(json_str) 26 | 27 | train_images_path = [] 28 | train_images_label = [] 29 | val_images_path = [] 30 | val_images_label = [] 31 | every_class_num = [] 32 | supported = [".jpg", ".JPG", ".png", ".PNG"] 33 | 34 | for cla in classes: 35 | cla_path = os.path.join(root, cla) 36 | 37 | images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) 38 | if os.path.splitext(i)[-1] in supported] 39 | 40 | image_class = class_indices[cla] 41 | 42 | every_class_num.append(len(images)) 43 | 44 | val_path = random.sample(images, k=int(len(images) * val_rate)) 45 | 46 | for img_path in images: 47 | if img_path in val_path: 48 | val_images_path.append(img_path) 49 | val_images_label.append(image_class) 50 | else: 51 | train_images_path.append(img_path) 52 | train_images_label.append(image_class) 53 | 54 | print("{} images were found in the dataset.".format(sum(every_class_num))) 55 | print("{} images for training.".format(len(train_images_path))) 56 | print("{} images for validation.".format(len(val_images_path))) 57 | 58 | 59 | return train_images_path, train_images_label, val_images_path, val_images_label 60 | 61 | 62 | 63 | 64 | def center_loss(x, label, center): 65 | b, c, h, w = x.size() 66 | losspos = 0 67 | posnum = 0 68 | 69 | lossneg = 0 70 | negnum = 0 71 | for i in range(b): 72 | if label[i] == 0: 73 | losspos = 
losspos + torch.sqrt(torch.sum((x[i] - center)**2)) 74 | posnum = posnum + 1 75 | elif label[i] == 1: 76 | lossneg = lossneg + torch.sqrt(torch.sum((x[i] - center)**2)) 77 | negnum = negnum + 1 78 | if posnum == 0: 79 | loss = -lossneg/negnum 80 | elif negnum == 0: 81 | loss = losspos/posnum 82 | else: 83 | loss = losspos/posnum - lossneg/negnum 84 | return loss 85 | 86 | 87 | #train 88 | def train_one_epoch(model, optimizer, data_loader, device, epoch, lr): 89 | model.train() 90 | loss_function = torch.nn.CrossEntropyLoss() 91 | accu_loss = torch.zeros(1).to(device) 92 | accu_num = torch.zeros(1).to(device) 93 | optimizer.zero_grad() 94 | 95 | sample_num = 0 96 | data_loader = tqdm(data_loader) 97 | 98 | if epoch%1 == 0: 99 | lr = lr * (0.5) ** (epoch/5) 100 | 101 | for step, data in enumerate(data_loader): 102 | images, labels = data 103 | 104 | sample_num += images.shape[0] 105 | 106 | pred = model(images.to(device)) 107 | 108 | 109 | 110 | pred_classes = torch.max(pred, dim=1)[1] 111 | 112 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 113 | 114 | 115 | 116 | 117 | #labels_t = torch.sparse.torch.eye(2).index_select(0, labels) 118 | loss = loss_function(pred, labels.to(device)) 119 | 120 | 121 | loss.backward() 122 | accu_loss += loss.detach() 123 | 124 | data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 125 | accu_loss.item() / (step + 1), 126 | accu_num.item() / sample_num) 127 | 128 | if not torch.isfinite(loss): 129 | print('WARNING: non-finite loss, ending training ', loss) 130 | sys.exit(1) 131 | 132 | optimizer.step() 133 | optimizer.zero_grad() 134 | 135 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num, lr 136 | 137 | #validation 138 | @torch.no_grad() 139 | def evaluate(model, data_loader, device, epoch): 140 | loss_function = torch.nn.CrossEntropyLoss() 141 | 142 | model.eval() 143 | 144 | accu_num = torch.zeros(1).to(device) 145 | accu_loss = torch.zeros(1).to(device) 146 | 147 | sample_num = 0 148 | data_loader = tqdm(data_loader) 149 | for step, data in enumerate(data_loader): 150 | images, labels = data 151 | 152 | sample_num += images.shape[0] 153 | 154 | pred = model(images.to(device)) 155 | pred_classes = torch.max(pred, dim=1)[1] 156 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 157 | #labels = torch.sparse.torch.eye(2).index_select(0, labels) 158 | loss = loss_function(pred, labels.to(device)) 159 | accu_loss += loss 160 | 161 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 162 | accu_loss.item() / (step + 1), 163 | accu_num.item() / sample_num) 164 | 165 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 166 | --------------------------------------------------------------------------------
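
A note on the loss: train_one_epoch above uses CrossEntropyLoss, while the FocalLoss defined in M2TR_model.py is only referenced in commented-out lines (in train.py and at the top of this file). Below is a minimal sketch of how it could be dropped into the batch loop of train_one_epoch; the one-hot float targets are an assumption needed because F.binary_cross_entropy expects per-class probabilities, and the alpha/gamma values mirror the commented-out call in train.py:

```python
import torch.nn.functional as F
from M2TR_model import create_FocalLoss

loss_function = create_FocalLoss(alpha=0.75, gamma=2)

# inside the batch loop of train_one_epoch, replacing the CrossEntropyLoss call:
pred = model(images.to(device))                                # [B, 2] softmax probabilities (logits=False)
targets = F.one_hot(labels.to(device), num_classes=2).float()  # [B, 2] one-hot float targets
loss = loss_function(pred, targets)
loss.backward()
```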