├── README.md
├── dataset
│   ├── classes.txt
│   ├── fake
│   │   └── example_fake.png
│   └── real
│       └── example_real.png
├── detection.py
├── helper_codes
│   ├── transform.py
│   └── xception.py
├── image_prediction.py
├── main.py
├── models
│   └── x-model23.p
├── requirements.txt
└── train_dataset.py

/README.md:
--------------------------------------------------------------------------------
# deepfake-detection-with-xception

Steps:
- Grab the required packages from `requirements.txt` using pip (`pip install -r requirements.txt`)

Prepare and train the dataset:
- We have used the Kaggle Deepfake Detection Challenge dataset, link: https://www.kaggle.com/c/deepfake-detection-challenge/data

- Download the dataset and extract faces from the videos in the `train_sample_videos` folder. Put them in the corresponding folders inside `dataset`. Classes are predefined already. A minimal extraction sketch is shown below.
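A sketch of the face-extraction step (illustrative, not part of this repo), using the dlib and OpenCV packages already listed in `requirements.txt`; the sampling rate and file naming are placeholders:

```python
# extract_faces.py -- illustrative helper, not included in the repo
import os
import cv2
import dlib


def extract_faces(video_path, out_dir, label, every_n=30):
    """Save one face crop per `every_n` frames into out_dir/label/."""
    detector = dlib.get_frontal_face_detector()
    os.makedirs(os.path.join(out_dir, label), exist_ok=True)
    reader = cv2.VideoCapture(video_path)
    frame_num, saved = 0, 0
    while reader.isOpened():
        ok, image = reader.read()
        if not ok:
            break
        frame_num += 1
        if frame_num % every_n:
            continue
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        for face in detector(gray, 1):
            crop = image[max(face.top(), 0):face.bottom(),
                         max(face.left(), 0):face.right()]
            if crop.size == 0:
                continue
            name = '%s_%d_%d.png' % (os.path.basename(video_path), frame_num, saved)
            cv2.imwrite(os.path.join(out_dir, label, name), crop)
            saved += 1
    reader.release()
    return saved


# e.g. extract_faces('train_sample_videos/xyz.mp4', 'dataset', 'fake')
```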
- Run `train_dataset.py` to train and generate models.

```bash
python3 train_dataset.py dataset/ classes.txt result/
```

Predefined settings:
```
Epoch: 10 / 30 [First/Final stage]
Learning rate: 5e-3 / 5e-4
Batch size: 32 / 16
```
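These defaults correspond to the argparse options in `train_dataset.py` and can be overridden on the command line, e.g.:

```bash
python3 train_dataset.py dataset/ classes.txt result/ \
    --epochs_pre 5 --epochs_fine 20 \
    --batch_size_pre 32 --batch_size_fine 16 \
    --lr_pre 5e-3 --lr_fine 5e-4 --split 0.8
```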
- Then examine the graphs, take the best model, and run `main.py` to detect fake videos. It can take a video file or a youtube-dl-supported video link as an input. Note that we've tested online links only with YouTube, so your results may vary.
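For example (option names as defined in `main.py`; the file names are illustrative):

```bash
# Local file with the bundled model:
python3 main.py -i input_video.mp4 -o result_videos/

# Online video, analyzing every 10th frame on the GPU:
python3 main.py -i "https://www.youtube.com/watch?v=..." --fast --cuda
```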

Note:

There's also a basic image predictor, which takes a LOT less time compared to a video. Use `image_prediction.py` with a Keras model trained by `train_dataset.py`:

```bash
python3 image_prediction.py path_to_model.h5 classes.txt input_image.jpg
```


Credits:
```
Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf
```
--------------------------------------------------------------------------------
/dataset/classes.txt:
--------------------------------------------------------------------------------
fake
real
--------------------------------------------------------------------------------
/dataset/fake/example_fake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/i3p9/deepfake-detection-with-xception/810313758d166105d18d3ec13b5c88673014a332/dataset/fake/example_fake.png
--------------------------------------------------------------------------------
/dataset/real/example_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/i3p9/deepfake-detection-with-xception/810313758d166105d18d3ec13b5c88673014a332/dataset/real/example_real.png
--------------------------------------------------------------------------------
/detection.py:
--------------------------------------------------------------------------------
import os
from os.path import join

import cv2
import dlib
import torch
import torch.nn as nn
from PIL import Image as pil_image
from tqdm import tqdm

from helper_codes.transform import transform_xception


def get_boundingbox(face, width, height, scale=1.3, minsize=None):
    """Return a square, scaled bounding box (x1, y1, size) clamped to the image."""
    x1 = face.left()
    y1 = face.top()
    x2 = face.right()
    y2 = face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize:
        if size_bb < minsize:
            size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

    # Check for out of bounds, x-y top left corner
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    # Clamp the box size to the image borders
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)

    return x1, y1, size_bb


def preprocess_image(image, cuda=False):
    # Revert from BGR to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Preprocess using the same transform used during training
    preprocess = transform_xception['test']
    preprocessed_image = preprocess(pil_image.fromarray(image))
    # Add first dimension as the network expects a batch
    preprocessed_image = preprocessed_image.unsqueeze(0)
    if cuda:
        preprocessed_image = preprocessed_image.cuda()
    return preprocessed_image


def predict_with_model(image, model, post_function=nn.Softmax(dim=1),
                       cuda=False):
    # Preprocess
    preprocessed_image = preprocess_image(image, cuda)

    # Model prediction
    output = model(preprocessed_image)
    output = post_function(output)

    # Cast to the predicted class index
    _, prediction = torch.max(output, 1)
    prediction = float(prediction.cpu().numpy())

    return int(prediction), output


def test_full_image_network(video_path, model_path, output_path, fast,
                            start_frame=0, end_frame=None, cuda=False):
    print('Starting: {}'.format(video_path))

    # Read and write
    reader = cv2.VideoCapture(video_path)

    video_fn = video_path.split('/')[-1].split('.')[0] + '.avi'
    os.makedirs(output_path, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    fps = reader.get(cv2.CAP_PROP_FPS)
    num_frames = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    writer = None

    # Face detector
    face_detector = dlib.get_frontal_face_detector()

    # Load model
    if model_path is not None:
        model = torch.load(model_path, map_location="cuda" if torch.cuda.is_available() else "cpu")
        print('Model found in {}'.format(model_path))
    else:
        raise ValueError('No model file provided (see --model_path)')
    if cuda:
        model = model.cuda()

    # Text variables
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 2
    font_scale = 1

    # ff: frames classified as fake; ffn: frames with a detected face
    ff = 0
    ffn = 0

    # Frame numbers and length of output video
    frame_num = 0
    assert start_frame < num_frames - 1
    end_frame = end_frame if end_frame else num_frames
    pbar = tqdm(total=end_frame - start_frame)

    while reader.isOpened():
        _, image = reader.read()
        if image is None:
            break
        frame_num += 1
        pbar.update(1)

        if frame_num < start_frame:
            continue
        # In fast mode, only analyze every 10th frame
        if fast and frame_num % 10 != 0:
            continue

        height, width = image.shape[:2]

        if writer is None:
            writer = cv2.VideoWriter(join(output_path, video_fn), fourcc, fps,
                                     (width, height))

        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        faces = face_detector(gray, 1)
        if len(faces):
            # If multiple faces, take the first one
            face = faces[0]

            x, y, size = get_boundingbox(face, width, height)
            cropped_face = image[y:y + size, x:x + size]

            # Prediction using our model
            prediction, output = predict_with_model(cropped_face, model,
                                                    cuda=cuda)

            if prediction == 1:
                ff += 1
            ffn += 1
            x = face.left()
            y = face.top()
            w = face.right() - x
            h = face.bottom() - y
            label = 'fake' if prediction == 1 else 'real'
            color = (0, 255, 0) if prediction == 0 else (0, 0, 255)
            output_list = ['{0:.2f}'.format(float(x)) for x in
                           output.detach().cpu().numpy()[0]]
            cv2.putText(image, str(output_list) + '=>' + label, (x, y + h + 30),
                        font_face, font_scale,
                        color, thickness, 2)
            # draw box over face
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)

        if frame_num >= end_frame:
            break

        writer.write(image)

    pbar.close()
    # Percentage of face frames classified as fake
    p = ff / float(ffn) * 100 if ffn else 0.0

    if writer is not None:
        out = {}
        writer.release()
        out["score"] = p
        out["file"] = video_fn
        return out
    else:
        print('Input video file was empty')
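# Example (illustrative): calling the detector directly from Python, assuming
# the bundled model at ./models/x-model23.p and a local clip input.mp4:
#
#   from detection import test_full_image_network
#   result = test_full_image_network('input.mp4', './models/x-model23.p',
#                                    'output/', fast=False)
#   print(result)  # e.g. {'score': 73.2, 'file': 'input.avi'}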
--------------------------------------------------------------------------------
/helper_codes/transform.py:
--------------------------------------------------------------------------------
from torchvision import transforms

transform_xception = {
    'train': transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
    'val': transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
    'test': transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
}
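# Note: detection.py uses transform_xception['test'] at inference time. The
# Normalize([0.5] * 3, [0.5] * 3) step maps inputs from [0, 1] to [-1, 1],
# which is what the pretrained Xception weights expect (see
# helper_codes/xception.py). The 'train' pipeline applies no augmentation;
# random flips could be added, e.g. (untested sketch):
#
#   transforms.Compose([
#       transforms.Resize((299, 299)),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#       transforms.Normalize([0.5] * 3, [0.5] * 3)
#   ])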
--------------------------------------------------------------------------------
/helper_codes/xception.py:
--------------------------------------------------------------------------------
"""
Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)

@author: tstandley
Adapted by cadene

Creates an Xception Model as defined in:

Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf

These weights were ported from the Keras implementation and achieve the
following performance on the ImageNet validation set:

Loss: 0.9173, Prec@1: 78.892, Prec@5: 94.292

REMEMBER to set your image size to 3x299x299 for both test and validation:

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])

The resize parameter of the validation transform should be 333, and make sure
to center crop at 299x299.
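For example (an illustrative sketch; helper_codes/transform.py in this repo
resizes directly to 299x299 instead):

    preprocess = transforms.Compose([
        transforms.Resize(333),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        normalize,
    ])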
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init

pretrained_settings = {
    'xception': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
        }
    }
}


class SeparableConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super(SeparableConv2d,self).__init__()

        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
        self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)

    def forward(self,x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None

        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x


class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf
    """
    def __init__(self, num_classes=1000):
        """ Constructor
        Args:
            num_classes: number of classes
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        #do relu here

        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        #do relu here
        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

        # #------- init weights --------
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        #         m.weight.data.normal_(0, math.sqrt(2. / n))
        #     elif isinstance(m, nn.BatchNorm2d):
        #         m.weight.data.fill_(1)
        #         m.bias.data.zero_()
        # #-----------------------------

    def features(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def logits(self, features):
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        # NOTE: last_linear is set by the xception() factory below (it aliases
        # self.fc); an Xception constructed directly has no last_linear.
        x = self.last_linear(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x


def xception(num_classes=1000, pretrained='imagenet'):
    model = Xception(num_classes=num_classes)
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        model.load_state_dict(model_zoo.load_url(settings['url']))

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']

    # TODO: ugly
    model.last_linear = model.fc
    del model.fc
    return model
--------------------------------------------------------------------------------
/image_prediction.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
from keras.applications.xception import preprocess_input
from keras.preprocessing import image
from keras.models import load_model

parser = argparse.ArgumentParser()
parser.add_argument('model')
parser.add_argument('classes')
parser.add_argument('image')
parser.add_argument('--top_n', type=int, default=10)


def main(args):

    model = load_model(args.model)

    classes = []
    with open(args.classes, 'r') as f:
        classes = list(map(lambda x: x.strip(), f.readlines()))

    img = image.load_img(args.image, target_size=(299, 299))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    # predict
    pred = model.predict(x)[0]
    result = [(classes[i], float(pred[i]) * 100.0) for i in range(len(pred))]
    result.sort(reverse=True, key=lambda x: x[1])
    # never ask for more entries than there are classes
    for i in range(min(args.top_n, len(result))):
        (class_name, prob) = result[i]
        print("Top %d =" % (i + 1))
        print("Class: %s" % (class_name))
        print("Probability: %.2f%%" % (prob))


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
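# Example invocation (illustrative; the values below are made up):
#
#   python3 image_prediction.py result/model_fine_final.h5 classes.txt face.png
#
# prints one block per class, e.g.:
#
#   Top 1 =
#   Class: fake
#   Probability: 97.31%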
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import re
import argparse
import youtube_dl
import sys
import warnings
from signal import signal, SIGINT, SIG_DFL

from detection import test_full_image_network

def warn(*args, **kwargs):
    pass

# silence library warnings
warnings.warn = warn

def banner():
    print("[Fake Video Detector]")


def main():
    signal(SIGINT, SIG_DFL)

    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--model_path', '-mi', dest='model', type=str, default='./models/x-model23.p')
    p.add_argument('--output_path', '-o', dest='videoOut', type=str, default='.')
    p.add_argument('--start_frame', type=int, default=0)
    p.add_argument('--end_frame', type=int, default=None)
    p.add_argument('--cuda', action='store_true')
    p.add_argument('--fast', action='store_true')
    requiredNamed = p.add_argument_group('required arguments')
    requiredNamed.add_argument('--video_path', '-i', dest='videoIn', type=str, required=True)
    args = p.parse_args()

    video_path = args.videoIn

    prediction = None

    if video_path.endswith('.mp4'):  # take a direct video file
        prediction = test_full_image_network(args.videoIn, args.model, args.videoOut, args.fast,
                                             start_frame=args.start_frame,
                                             end_frame=args.end_frame, cuda=args.cuda)
    else:  # download the video from a youtube-dl supported website
        video_url = re.findall('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\), ]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', args.videoIn)
        if video_url:
            default_path = 'video.mp4'
            filename = ""
            ydl_opts = {'outtmpl': default_path}
            with youtube_dl.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(video_url[0], download=True)
                filename = ydl.prepare_filename(info)
            prediction = test_full_image_network(filename, args.model, args.videoOut, args.fast,
                                                 start_frame=args.start_frame,
                                                 end_frame=args.end_frame, cuda=args.cuda)
            os.remove(filename)
        else:
            print("Invalid input format")
            sys.exit(-1)

    if prediction is None:
        print("No prediction could be made")
        sys.exit(-1)

    print("Prediction of it being fake: " + str(prediction["score"]))
    print("Output video in: " + prediction["file"])

if __name__ == '__main__':
    banner()
    main()
--------------------------------------------------------------------------------
/models/x-model23.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/i3p9/deepfake-detection-with-xception/810313758d166105d18d3ec13b5c88673014a332/models/x-model23.p
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.4.0
astor==0.7.1
certifi==2018.11.29
cffi==1.12.1
Click==7.0
cmake==3.12.0
dlib==19.15.0
face-recognition==1.2.3
face-recognition-models==0.3.0
ffmpy==0.2.2
ffmpeg
gast==0.2.0
grpcio==1.14.1
h5py==2.10.0
icc-rt==2019.0
intel-numpy==1.15.1
intel-openmp==2019.0
Keras==2.2.0
Keras-Applications==1.0.2
Keras-Preprocessing==1.0.1
Markdown==2.6.11
mkl==2019.0
mkl-fft==1.0.6
mkl-random==1.0.1.1
munch==2.3.2
numpy==1.17.4
olefile==0.46
opencv-python==3.4.1.15
pathlib==1.0.1
Pillow==7.2.0
pretrainedmodels==0.7.4
protobuf==3.6.1
pycparser==2.19
PyYAML==5.1.1
scandir==1.7
scipy==1.3.0
six==1.12.0
tbb==2019.0
tbb4py==2019.0
termcolor==1.1.0
tensorflow-gpu==1.15.0
torch==1.0.1.post2
torchvision==0.2.1
tqdm==4.25.0
youtube-dl
--------------------------------------------------------------------------------
/train_dataset.py:
--------------------------------------------------------------------------------
import math
import os
import argparse
import imghdr
import pickle as pkl
import numpy as np
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from keras.applications.xception import Xception, preprocess_input
from keras.optimizers import Adam
from keras.preprocessing import image
from keras.losses import categorical_crossentropy
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint

current_directory = os.path.dirname(os.path.abspath(__file__))
parser = argparse.ArgumentParser()
parser.add_argument('dataset_root')
parser.add_argument('classes')
parser.add_argument('result_root')
parser.add_argument('--epochs_pre', type=int, default=10)
parser.add_argument('--epochs_fine', type=int, default=30)
parser.add_argument('--batch_size_pre', type=int, default=32)
parser.add_argument('--batch_size_fine', type=int, default=16)
parser.add_argument('--lr_pre', type=float, default=5e-3)
parser.add_argument('--lr_fine', type=float, default=5e-4)
parser.add_argument('--snapshot_period_pre', type=int, default=1)
parser.add_argument('--snapshot_period_fine', type=int, default=1)
parser.add_argument('--split', type=float, default=0.8)


def generate_from_paths_and_labels(
        input_paths, labels, batch_size, input_size=(299, 299)):
    # endless generator: reshuffle every epoch, then yield preprocessed batches
    num_samples = len(input_paths)
    while True:
        perm = np.random.permutation(num_samples)
        input_paths = input_paths[perm]
        labels = labels[perm]
        for i in range(0, num_samples, batch_size):
            inputs = list(map(
                lambda x: image.load_img(x, target_size=input_size),
                input_paths[i:i + batch_size]
            ))
            inputs = np.array(list(map(
                lambda x: image.img_to_array(x),
                inputs
            )))
            inputs = preprocess_input(inputs)
            yield (inputs, labels[i:i + batch_size])


def main(args):

    epochs = args.epochs_pre + args.epochs_fine
    args.dataset_root = os.path.expanduser(args.dataset_root)
    args.result_root = os.path.expanduser(args.result_root)
    args.classes = os.path.expanduser(args.classes)

    # load class names
    with open(args.classes, 'r') as f:
        classes = f.readlines()
        classes = list(map(lambda x: x.strip(), classes))
    num_classes = len(classes)

    # make input_paths and labels
    input_paths, labels = [], []
    for class_name in os.listdir(args.dataset_root):
        class_root = os.path.join(args.dataset_root, class_name)
        class_id = classes.index(class_name)
        for path in os.listdir(class_root):
            path = os.path.join(class_root, path)
            if imghdr.what(path) is None:
                # this is not an image file
                continue
            input_paths.append(path)
            labels.append(class_id)

    # convert to one-hot-vector format
    labels = to_categorical(labels, num_classes=num_classes)

    # convert to numpy array
    input_paths = np.array(input_paths)

    # shuffle dataset
    perm = np.random.permutation(len(input_paths))
    labels = labels[perm]
    input_paths = input_paths[perm]

    # split dataset for training and validation
    border = int(len(input_paths) * args.split)
    train_labels = labels[:border]
    val_labels = labels[border:]
    train_input_paths = input_paths[:border]
    val_input_paths = input_paths[border:]
    print("Training on %d images and labels" % (len(train_input_paths)))
    print("Validation on %d images and labels" % (len(val_input_paths)))

    if not os.path.exists(args.result_root):
        os.makedirs(args.result_root)

    # Build a custom Xception
    # from pre-trained Xception model
    # the default input shape is (299, 299, 3)
    base_model = Xception(
        include_top=False,
        weights='imagenet',
        input_shape=(299, 299, 3))

    # create a custom top classifier
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.inputs, outputs=predictions)

    # Train only the top classifier
    # freeze the body layers
    for layer in base_model.layers:
        layer.trainable = False

    # compile model
    model.compile(
        loss=categorical_crossentropy,
        optimizer=Adam(lr=args.lr_pre),
        metrics=['accuracy']
    )

    # train
    hist_pre = model.fit_generator(
        generator=generate_from_paths_and_labels(
            input_paths=train_input_paths,
            labels=train_labels,
            batch_size=args.batch_size_pre
        ),
        steps_per_epoch=math.ceil(
            len(train_input_paths) / args.batch_size_pre),
        epochs=args.epochs_pre,
        validation_data=generate_from_paths_and_labels(
            input_paths=val_input_paths,
            labels=val_labels,
            batch_size=args.batch_size_pre
        ),
        validation_steps=math.ceil(
            len(val_input_paths) / args.batch_size_pre),
        verbose=1,
        callbacks=[
            ModelCheckpoint(
                filepath=os.path.join(
                    args.result_root,
                    'model_pre_ep{epoch}_valloss{val_loss:.3f}.h5'),
                period=args.snapshot_period_pre,
            ),
        ],
    )
    model.save(os.path.join(args.result_root, 'model_pre_final.h5'))

    # Train the whole model
    for layer in model.layers:
        layer.trainable = True  # all layers are set as trainable

    # recompile
    model.compile(
        optimizer=Adam(lr=args.lr_fine),
        loss=categorical_crossentropy,
        metrics=['accuracy'])

    # train
    hist_fine = model.fit_generator(
        generator=generate_from_paths_and_labels(
            input_paths=train_input_paths,
            labels=train_labels,
            batch_size=args.batch_size_fine
        ),
        steps_per_epoch=math.ceil(
            len(train_input_paths) / args.batch_size_fine),
        epochs=args.epochs_fine,
        validation_data=generate_from_paths_and_labels(
            input_paths=val_input_paths,
            labels=val_labels,
            batch_size=args.batch_size_fine
        ),
        validation_steps=math.ceil(
            len(val_input_paths) / args.batch_size_fine),
        verbose=1,
        callbacks=[
            ModelCheckpoint(
                filepath=os.path.join(
                    args.result_root,
                    'model_fine_ep{epoch}_valloss{val_loss:.3f}.h5'),
                period=args.snapshot_period_fine,
            ),
        ],
    )
    model.save(os.path.join(args.result_root, 'model_fine_final.h5'))

    # Create result graphs
    # (Keras < 2.3 logs accuracy under 'acc'/'val_acc'; newer versions use
    # 'accuracy'/'val_accuracy' -- support both)
    def metric(hist, key):
        return hist.history.get(key) or hist.history[key.replace('accuracy', 'acc')]

    acc = metric(hist_pre, 'accuracy')
    val_acc = metric(hist_pre, 'val_accuracy')
    loss = hist_pre.history['loss']
    val_loss = hist_pre.history['val_loss']
    acc.extend(metric(hist_fine, 'accuracy'))
    val_acc.extend(metric(hist_fine, 'val_accuracy'))
    loss.extend(hist_fine.history['loss'])
    val_loss.extend(hist_fine.history['val_loss'])

    # save graph image
    plt.plot(range(epochs), acc, marker='.', label='accuracy')
    plt.plot(range(epochs), val_acc, marker='.', label='val_accuracy')
    plt.legend(loc='best')
    plt.grid()
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.savefig(os.path.join(args.result_root, 'accuracy.png'))
    plt.clf()

    plt.plot(range(epochs), loss, marker='.', label='loss')
    plt.plot(range(epochs), val_loss, marker='.', label='val_loss')
    plt.legend(loc='best')
    plt.grid()
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.savefig(os.path.join(args.result_root, 'loss.png'))
    plt.clf()

    # save plot data
    plot = {
        'accuracy': acc,
        'val_accuracy': val_acc,
        'loss': loss,
        'val_loss': val_loss,
    }
    with open(os.path.join(args.result_root, 'plot.dump'), 'wb') as f:
        pkl.dump(plot, f)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------