├── README.md ├── network ├── __pycache__ │ ├── classifier.cpython-35.pyc │ ├── classifier.cpython-36.pyc │ ├── transform.cpython-35.pyc │ └── transform.cpython-36.pyc ├── classifier.py └── transform.py ├── output └── Mesonet │ └── best.pkl ├── test.py ├── train_Meso.py └── train_MesoInception.py /README.md: -------------------------------------------------------------------------------- 1 | # MesoNet-Pytorch 2 | ------------------------------------------- 3 | This is a personal reimplementation of MesoNet [1] using PyTorch. If you make use of this work, please cite the paper accordingly. 4 | 5 | For the original Keras version of this work, please see: [DariusAf/MesoNet](https://github.com/DariusAf/MesoNet) 6 | 7 | ## Install & Requirements 8 | The code has been tested with PyTorch 1.3.1, torchvision 0.4.2 and Python 3.6.9; please refer to `requirements.txt` for more details. 9 | 10 | **To install the Python packages** 11 | 12 | `python -m pip install -r requirements.txt` 13 | 14 | ## Usage 15 | **To train the normal MesoNet** 16 | 17 | `python train_Meso.py -n 'Mesonet' -tp './data/train' -vp './data/val' -bz 64 -e 100 -mn 'meso4.pkl'` 18 | 19 | **To train the MesoInceptionNet** 20 | 21 | `python train_MesoInception.py -n 'MesoInception' -tp './data/train' -vp './data/val' -bz 64 -e 100 -mn 'mesoinception.pkl'` 22 | 23 | 24 | To continue training from a pretrained model, add `--continue_train True -mp ./pretrained_models/model.pkl` 25 | 26 | **To test the trained model** 27 | 28 | `python test.py -bz 64 -tp './data/test' -mp './output/Mesonet/best.pkl'` 29 | 30 | ## License 31 | The provided implementation is strictly for academic purposes only. Should you be interested in using our technology for any commercial use, please feel free to contact us. 32 | 33 | ## Reference 34 | [1] Afchar, D., Nozick, V., Yamagishi, J., & Echizen, I. (2018, September). MesoNet: a Compact Facial Video Forgery Detection Network.
# In IEEE Workshop on Information Forensics and Security, WIFS 2018.
# --------------------------------------------------------------------------------
# /network/__pycache__/classifier.cpython-35.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/HongguLiu/MesoNet-Pytorch/205f8c3b172e66064f3832a906c64b40676365e9/network/__pycache__/classifier.cpython-35.pyc
# --------------------------------------------------------------------------------
# /network/__pycache__/classifier.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/HongguLiu/MesoNet-Pytorch/205f8c3b172e66064f3832a906c64b40676365e9/network/__pycache__/classifier.cpython-36.pyc
# --------------------------------------------------------------------------------
# /network/__pycache__/transform.cpython-35.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/HongguLiu/MesoNet-Pytorch/205f8c3b172e66064f3832a906c64b40676365e9/network/__pycache__/transform.cpython-35.pyc
# --------------------------------------------------------------------------------
# /network/__pycache__/transform.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/HongguLiu/MesoNet-Pytorch/205f8c3b172e66064f3832a906c64b40676365e9/network/__pycache__/transform.cpython-36.pyc
# --------------------------------------------------------------------------------
# /network/classifier.py:
# --------------------------------------------------------------------------------
import os
import argparse


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision


class Meso4(nn.Module):
    """
    PyTorch implementation of Meso4.

    Author: Honggu Liu
    Date: July 4, 2019

    Four convolution stages (conv -> ReLU -> BatchNorm -> max-pool) followed
    by a dropout / fully connected head that returns raw (unnormalised)
    logits.

    ``fc1`` expects a flattened 16*8*8 feature map, i.e. the four pooling
    stages must reduce the input to 8x8. The shape comments in ``forward``
    assume 3-channel 256x256 inputs, the size produced by
    ``Resize((256, 256))`` in network/transform.py.

    Args:
        num_classes: number of output logits (default 2).
    """
    def __init__(self, num_classes=2):
        super(Meso4, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1, bias=False)
        # NOTE(review): bn1 is applied after both conv1 and conv2, and bn2
        # after both conv3 and conv4, so each BatchNorm's running statistics
        # mix two different layers. The Keras reference presumably uses one
        # BatchNorm per conv -- confirm. Not changed here because adding
        # layers changes the state_dict keys and would break loading existing
        # checkpoints such as output/Mesonet/best.pkl.
        self.bn1 = nn.BatchNorm2d(8)
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(8, 8, 5, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(8, 16, 5, padding=2, bias=False)
        self.conv4 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4))
        # Element-wise dropout: it is applied to flattened (batch, features)
        # tensors, for which Dropout2d is deprecated and already behaves as
        # plain dropout. Dropout has no parameters, so state_dict is unchanged.
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(16*8*8, 16)
        self.fc2 = nn.Linear(16, num_classes)

    def forward(self, input):
        """Return logits of shape (batch, num_classes) for an image batch.

        Shape comments below are per sample (C, H, W) for 256x256 input.
        """
        x = self.conv1(input) #(8, 256, 256)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling1(x) #(8, 128, 128)

        x = self.conv2(x) #(8, 128, 128)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling1(x) #(8, 64, 64)

        x = self.conv3(x) #(16, 64, 64)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.maxpooling1(x) #(16, 32, 32)

        x = self.conv4(x) #(16, 32, 32)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.maxpooling2(x) #(16, 8, 8)

        x = x.view(x.size(0), -1) #(Batch, 16*8*8)
        x = self.dropout(x)
        x = self.fc1(x) #(Batch, 16)
        x = self.leakyrelu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x


class MesoInception4(nn.Module):
    """
    PyTorch implementation of MesoInception4.

    Author: Honggu Liu
    Date: July 7, 2019

    Two inception blocks (parallel 1x1 / 3x3 / dilated 3x3 branches,
    concatenated, BatchNorm, 2x2 max-pool), then two conv stages and the same
    dropout / fully connected head as Meso4. ``fc1`` expects a 16*8*8
    feature map (256x256 input per the shape comments in ``forward``).

    Args:
        num_classes: number of output logits (default 2).
    """
    def __init__(self, num_classes=2):
        super(MesoInception4, self).__init__()
        self.num_classes = num_classes
        # NOTE(review): the inception branch convolutions below are not
        # followed by any activation; the Keras reference appears to apply
        # ReLU to each branch conv -- confirm before changing, since it alters
        # what any trained checkpoint computes.
        #InceptionLayer1: 1 + 4 + 4 + 2 = 11 output channels
        self.Incption1_conv1 = nn.Conv2d(3, 1, 1, padding=0, bias=False)
        self.Incption1_conv2_1 = nn.Conv2d(3, 4, 1, padding=0, bias=False)
        self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.Incption1_conv3_1 = nn.Conv2d(3, 4, 1, padding=0, bias=False)
        self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
        self.Incption1_conv4_1 = nn.Conv2d(3, 2, 1, padding=0, bias=False)
        self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
        self.Incption1_bn = nn.BatchNorm2d(11)

        #InceptionLayer2: 2 + 4 + 4 + 2 = 12 output channels
        self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
        self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
        self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
        self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
        self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
        self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
        self.Incption2_bn = nn.BatchNorm2d(12)

        #Normal Layer
        self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.1)
        # NOTE(review): bn1 is shared by conv1 and conv2 (see forward); kept
        # as-is so state_dict keys stay compatible with existing checkpoints.
        self.bn1 = nn.BatchNorm2d(16)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
        self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4))

        # Element-wise dropout on flattened (batch, features) tensors; see Meso4.
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(16*8*8, 16)
        self.fc2 = nn.Linear(16, num_classes)

    #InceptionLayer
    def InceptionLayer1(self, input):
        """First inception block: 3 -> 11 channels, then BatchNorm and 2x2 max-pool.

        Branches: 1x1; 1x1 -> 3x3; 1x1 -> 3x3 (dilation 2); 1x1 -> 3x3
        (dilation 3). Padding keeps the spatial size equal across branches so
        they can be concatenated along the channel axis.
        """
        x1 = self.Incption1_conv1(input)
        x2 = self.Incption1_conv2_1(input)
        x2 = self.Incption1_conv2_2(x2)
        x3 = self.Incption1_conv3_1(input)
        x3 = self.Incption1_conv3_2(x3)
        x4 = self.Incption1_conv4_1(input)
        x4 = self.Incption1_conv4_2(x4)
        y = torch.cat((x1, x2, x3, x4), 1)
        y = self.Incption1_bn(y)
        y = self.maxpooling1(y)

        return y

    def InceptionLayer2(self, input):
        """Second inception block: 11 -> 12 channels, then BatchNorm and 2x2 max-pool.

        Same branch layout as InceptionLayer1.
        """
        x1 = self.Incption2_conv1(input)
        x2 = self.Incption2_conv2_1(input)
        x2 = self.Incption2_conv2_2(x2)
        x3 = self.Incption2_conv3_1(input)
        x3 = self.Incption2_conv3_2(x3)
        x4 = self.Incption2_conv4_1(input)
        x4 = self.Incption2_conv4_2(x4)
        y = torch.cat((x1, x2, x3, x4), 1)
        y = self.Incption2_bn(y)
        y = self.maxpooling1(y)

        return y

    def forward(self, input):
        """Return logits of shape (batch, num_classes) for an image batch."""
        x = self.InceptionLayer1(input) #(Batch, 11, 128, 128)
        x = self.InceptionLayer2(x) #(Batch, 12, 64, 64)

        x = self.conv1(x) #(Batch, 16, 64 ,64)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling1(x) #(Batch, 16, 32, 32)

        x = self.conv2(x) #(Batch, 16, 32, 32)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling2(x) #(Batch, 16, 8, 8)

        x = x.view(x.size(0), -1) #(Batch, 16*8*8)
        x = self.dropout(x)
        x = self.fc1(x) #(Batch, 16)
        x = self.leakyrelu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# --------------------------------------------------------------------------------
# /network/transform.py:
# --------------------------------------------------------------------------------
"""
Preprocessing pipelines for MesoNet.

Author: Honggu Liu

Every split resizes to 256x256 (the input size the classifiers' fully
connected heads are sized for), converts the image to a tensor and
normalises each of the 3 channels with mean 0.5 and std 0.5.
"""
from torchvision import transforms

mesonet_data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ]),
    'val': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
    'test': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
}
# --------------------------------------------------------------------------------
# /output/Mesonet/best.pkl:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/HongguLiu/MesoNet-Pytorch/205f8c3b172e66064f3832a906c64b40676365e9/output/Mesonet/best.pkl
# --------------------------------------------------------------------------------
# /test.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
import argparse
import os
import cv2
from torchvision import datasets, models, transforms
from network.classifier import *
from network.transform import mesonet_data_transforms


def main():
    """Evaluate a trained Meso4 checkpoint on an ImageFolder test set.

    Arguments come from the module-level ``parse`` parser built under
    ``if __name__ == '__main__'``. Prints the accuracy of every batch and the
    overall test accuracy. Needs a CUDA device (tensors are moved with
    ``.cuda()``).
    """
    args = parse.parse_args()
    test_path = args.test_path
    batch_size = args.batch_size
    model_path = args.model_path
    torch.backends.cudnn.benchmark = True

    test_dataset = torchvision.datasets.ImageFolder(test_path, transform=mesonet_data_transforms['test'])
    # Evaluation does not need shuffling; a fixed order keeps the per-batch log reproducible.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)
    test_dataset_size = len(test_dataset)

    model = Meso4()
    # Load on the CPU first so a checkpoint saved from any GPU can be restored.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model = model.cuda()
    model.eval()

    corrects = 0
    with torch.no_grad():
        for (image, labels) in test_loader:
            image = image.cuda()
            labels = labels.cuda()
            outputs = model(image)
            _, preds = torch.max(outputs, 1)
            batch_corrects = torch.sum(preds == labels).item()
            corrects += batch_corrects
            # Divide by the real batch length: the last batch may be smaller than batch_size.
            print('Iteration Acc {:.4f}'.format(batch_corrects / labels.size(0)))
    acc = corrects / test_dataset_size
    print('Test Acc: {:.4f}'.format(acc))


if __name__ == '__main__':
    parse = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parse.add_argument('--batch_size', '-bz', type=int, default=32)
    parse.add_argument('--test_path', '-tp', type=str, default='./deepfake_database/test')
    parse.add_argument('--model_path', '-mp', type=str, default='./output/Mesonet/best.pkl')
    main()
# --------------------------------------------------------------------------------
# /train_Meso.py:
# --------------------------------------------------------------------------------
import copy
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
import argparse
import os
import cv2
from torchvision import datasets, models, transforms
from network.classifier import *
from network.transform import mesonet_data_transforms


def _str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    argparse's ``type=bool`` cannot be used for flags like ``--continue_train``
    because ``bool('False')`` is True (every non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Train Meso4 on ImageFolder train/val splits.

    Arguments come from the module-level ``parse`` parser. Every 10 epochs
    (epochs 0, 10, 20, ...) the current weights are saved as
    ``<epoch>_<model_name>``; at the end the weights with the best validation
    accuracy are saved as ``best.pkl``. Both go under ``./output/<name>``.
    Needs a CUDA device.
    """
    args = parse.parse_args()
    name = args.name
    train_path = args.train_path
    val_path = args.val_path
    continue_train = args.continue_train
    epoches = args.epoches
    batch_size = args.batch_size
    model_name = args.model_name
    model_path = args.model_path
    output_path = os.path.join('./output', name)
    # makedirs also creates ./output itself when it does not exist yet.
    os.makedirs(output_path, exist_ok=True)
    torch.backends.cudnn.benchmark = True

    # Create train and val dataloaders
    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=mesonet_data_transforms['train'])
    val_dataset = torchvision.datasets.ImageFolder(val_path, transform=mesonet_data_transforms['val'])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)
    train_dataset_size = len(train_dataset)
    val_dataset_size = len(val_dataset)

    # Create the model
    model = Meso4()
    if continue_train:
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    #Train the model using multiple GPUs
    #model = nn.DataParallel(model)

    # state_dict() tensors share storage with the live parameters/buffers, so
    # the snapshot must be deep-copied or it keeps changing during training.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    iteration = 0
    for epoch in range(epoches):
        print('Epoch {}/{}'.format(epoch+1, epoches))
        print('-'*10)

        model.train()
        train_loss = 0.0
        train_corrects = 0
        for (image, labels) in train_loader:
            image = image.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(image)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_len = labels.size(0)
            # CrossEntropyLoss (reduction='mean') already averages over the batch.
            iter_loss = loss.item()
            iter_corrects = torch.sum(preds == labels).item()
            # Weight by batch length so dividing by the dataset size gives a
            # per-sample mean even when the last batch is smaller.
            train_loss += iter_loss * batch_len
            train_corrects += iter_corrects
            iteration += 1
            if not (iteration % 20):
                print('iteration {} train loss: {:.4f} Acc: {:.4f}'.format(iteration, iter_loss, iter_corrects / batch_len))
        epoch_loss = train_loss / train_dataset_size
        epoch_acc = train_corrects / train_dataset_size
        print('epoch train loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

        model.eval()
        val_loss = 0.0
        val_corrects = 0
        with torch.no_grad():
            for (image, labels) in val_loader:
                image = image.cuda()
                labels = labels.cuda()
                outputs = model(image)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                val_corrects += torch.sum(preds == labels).item()
        epoch_loss = val_loss / val_dataset_size
        epoch_acc = val_corrects / val_dataset_size
        print('epoch val loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
        scheduler.step()
        if not (epoch % 10):
            #Save the model trained with multiple gpu
            #torch.save(model.module.state_dict(), os.path.join(output_path, str(epoch) + '_' + model_name))
            torch.save(model.state_dict(), os.path.join(output_path, str(epoch) + '_' + model_name))
    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    #torch.save(model.module.state_dict(), os.path.join(output_path, "best.pkl"))
    torch.save(model.state_dict(), os.path.join(output_path, "best.pkl"))


if __name__ == '__main__':
    parse = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parse.add_argument('--name', '-n', type=str, default='Mesonet')
    parse.add_argument('--train_path', '-tp', type=str, default='./deepfake_database/train')
    parse.add_argument('--val_path', '-vp', type=str, default='./deepfake_database/val')
    parse.add_argument('--batch_size', '-bz', type=int, default=64)
    parse.add_argument('--epoches', '-e', type=int, default=50)
    parse.add_argument('--model_name', '-mn', type=str, default='meso4.pkl')
    parse.add_argument('--continue_train', type=_str2bool, default=False)
    parse.add_argument('--model_path', '-mp', type=str, default='./output/Mesonet/best.pkl')
    main()
# --------------------------------------------------------------------------------
# /train_MesoInception.py:
# --------------------------------------------------------------------------------
import copy
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
import argparse
import os
import cv2
from torchvision import datasets, models, transforms
from network.classifier import *
from network.transform import mesonet_data_transforms


def _str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    argparse's ``type=bool`` cannot be used for flags like ``--continue_train``
    because ``bool('False')`` is True (every non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Train MesoInception4 on ImageFolder train/val splits.

    Arguments come from the module-level ``parse`` parser. Every 10 epochs
    (epochs 0, 10, 20, ...) the current weights are saved as
    ``<epoch>_<model_name>``; at the end the weights with the best validation
    accuracy are saved as ``best.pkl``. Both go under ``./output/<name>``.
    Needs a CUDA device.
    """
    args = parse.parse_args()
    name = args.name
    train_path = args.train_path
    val_path = args.val_path
    continue_train = args.continue_train
    epoches = args.epoches
    batch_size = args.batch_size
    model_name = args.model_name
    model_path = args.model_path
    output_path = os.path.join('./output', name)
    # makedirs also creates ./output itself when it does not exist yet.
    os.makedirs(output_path, exist_ok=True)
    torch.backends.cudnn.benchmark = True

    # Create train and val dataloaders
    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=mesonet_data_transforms['train'])
    val_dataset = torchvision.datasets.ImageFolder(val_path, transform=mesonet_data_transforms['val'])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)
    train_dataset_size = len(train_dataset)
    val_dataset_size = len(val_dataset)

    # Create the model
    model = MesoInception4()
    if continue_train:
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    #Train the model using multiple GPUs
    #model = nn.DataParallel(model)

    # state_dict() tensors share storage with the live parameters/buffers, so
    # the snapshot must be deep-copied or it keeps changing during training.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    iteration = 0
    for epoch in range(epoches):
        print('Epoch {}/{}'.format(epoch+1, epoches))
        print('-'*10)

        model.train()
        train_loss = 0.0
        train_corrects = 0
        for (image, labels) in train_loader:
            image = image.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(image)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_len = labels.size(0)
            # CrossEntropyLoss (reduction='mean') already averages over the batch.
            iter_loss = loss.item()
            iter_corrects = torch.sum(preds == labels).item()
            # Weight by batch length so dividing by the dataset size gives a
            # per-sample mean even when the last batch is smaller.
            train_loss += iter_loss * batch_len
            train_corrects += iter_corrects
            iteration += 1
            if not (iteration % 20):
                print('iteration {} train loss: {:.4f} Acc: {:.4f}'.format(iteration, iter_loss, iter_corrects / batch_len))
        epoch_loss = train_loss / train_dataset_size
        epoch_acc = train_corrects / train_dataset_size
        print('epoch train loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

        model.eval()
        val_loss = 0.0
        val_corrects = 0
        with torch.no_grad():
            for (image, labels) in val_loader:
                image = image.cuda()
                labels = labels.cuda()
                outputs = model(image)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                val_corrects += torch.sum(preds == labels).item()
        epoch_loss = val_loss / val_dataset_size
        epoch_acc = val_corrects / val_dataset_size
        print('epoch val loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
        scheduler.step()
        if not (epoch % 10):
            #torch.save(model.module.state_dict(), os.path.join(output_path, str(epoch) + '_' + model_name))
            torch.save(model.state_dict(), os.path.join(output_path, str(epoch) + '_' + model_name))
    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    #torch.save(model.module.state_dict(), os.path.join(output_path, "best.pkl"))
    torch.save(model.state_dict(), os.path.join(output_path, "best.pkl"))


if __name__ == '__main__':
    parse = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Defaults point at MesoInception-specific locations so a default run does
    # not overwrite (or try to resume from) the Meso4 checkpoints in ./output/Mesonet.
    parse.add_argument('--name', '-n', type=str, default='MesoInception')
    parse.add_argument('--train_path', '-tp', type=str, default='./deepfake_database/train')
    parse.add_argument('--val_path', '-vp', type=str, default='./deepfake_database/val')
    parse.add_argument('--batch_size', '-bz', type=int, default=64)
    parse.add_argument('--epoches', '-e', type=int, default=50)
    parse.add_argument('--model_name', '-mn', type=str, default='mesoinception.pkl')
    parse.add_argument('--continue_train', type=_str2bool, default=False)
    parse.add_argument('--model_path', '-mp', type=str, default='./output/MesoInception/best.pkl')
    main()
# --------------------------------------------------------------------------------