├── LICENSE ├── MindSpore ├── config.py ├── requirements.txt ├── train_ffpp.py ├── training │ ├── dataset.py │ ├── evaluate.py │ ├── losses.py │ ├── network.py │ └── r3d.py └── utils │ ├── graph_conv.py │ └── util.py ├── Pytorch ├── config.py ├── requirements.txt ├── test.py ├── train_ffpp.py ├── training │ ├── dataset.py │ ├── evaluate.py │ ├── network.py │ └── r3d.py └── utils │ ├── graph_conv.py │ └── util.py └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MindSpore/config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | def __init__(self, datalabel, recipes=[], **params): 3 | self.datalabel = datalabel 4 | self.workers = 8 5 | self.start_epoch = 0 6 | self.epochs = 100 7 | self.max_epoch = 196 8 | self.mask_strategy = 'min' 9 | self.mask_rate = 0.5 10 | 11 | self.lr = 2e-4 12 | self.weight_decay = 1e-6 13 | self.bn_weight_decay = 0.0 14 | self.scheduler_step = 5 15 | self.scheduler_gamma = 0.5 16 | self.momentum = 0.9 17 | self.adam_betas = (0.9, 0.999) 18 | self.optimizer = 'adamw' 19 | 20 | self.lr_policy = 'cosine' 21 | self.warmup_epochs = 30 22 | self.warmup_lr = 1e-6 23 | 24 | self.batch_size = 16 25 | self.resize=112 26 | self.a_dim = 8 27 | self.h_dim = 64 28 | self.embed_dim = 768 29 | 30 | self.dropout_rate = 0.1 31 | 32 | for i in recipes: 33 | self.recipe(i) 34 | 35 | for i in params: 36 | self.__setattr__(i, params[i]) 37 | 38 | self.train_dataset = dict(datalabel=self.datalabel, resize=self.resize, augment=self.augment) 39 | self.val_dataset = dict(resize=self.resize, augment='augment_test') 40 | 41 | def recipe(self, name): 42 | if 'vit' in name: 43 | self.batch_size=8 44 | if 'ff' in name: 45 | if 'ff-5' in name: 46 | self.num_classes=5 47 | self.augment='augment0' 48 | if 'celeb' in name: 49 | self.augment='augment0' 50 | if 'dfdc' in name: 51 | self.augment='augment0' 52 | if 'uadfv' in name: 53 | self.augment='augment0' 54 | if 'xception' in name: 55 | self.net='xception' 56 | self.batch_size=64 57 | 58 | if 'efficient' in name: 59 | self.net=name 60 | self.batch_size=10 61 | scale=int(name.split('b')[-1]) 62 | sizes=[224, 240, 260, 300, 380, 456, 528, 600, 672] 63 | self.resize=sizes[scale] 64 | if 'r3d' in name: 65 | self.batch_size=40 66 | if 'resnet' in name: 67 | self.batch_size=48 68 | if 'addnet' in name: 69 | self.batch_size=60 70 | if 'resnet18' in name: 71 | self.batch_size=2048 72 | if 'no-aug' in name: 73 | self.augment='augment_test' 74 | -------------------------------------------------------------------------------- /MindSpore/requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.0 2 | einops==0.6.0 3 | mindspore==1.10.0 4 | mindspore_gpu==1.10.1 5 | numpy==1.22.3 6 | scikit_learn==1.2.2 7 | scipy==1.7.3 -------------------------------------------------------------------------------- /MindSpore/train_ffpp.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from os.path import join 4 | import argparse 5 | 6 | import logging 7 | import numpy as np 8 | 9 | import mindspore 10 | import mindspore.dataset as ds 11 | from mindspore import nn, ops 12 | from mindspore.communication import init, get_rank, get_group_size 13 | from mindspore.train import Model 14 | 15 | from training.losses import CustomLossCell 16 | 17 | from training.dataset import * 18 | from training.network import define_network 19 | 20 | from utils.util import * 21 | from config import Config 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--checkpoint', type=str, default='checkpoint/') 26 | parser.add_argument('--results', type=str, default='results/') 27 | parser.add_argument('--resume', type=str, default=None) 28 | 29 | parser.add_argument('--lambda_bce', type=float, default=1., help="1") 30 | 31 | parser.add_argument('--print_iter', 
type=int, default=100, help='print frequency') 32 | parser.add_argument('--save_epoch', default=1, type=int) 33 | 34 | parser.add_argument('--amp', action='store_true', help='if True, use fp16.') 35 | parser.add_argument('--local_rank', type=int, default=0) 36 | parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:23504') 37 | parser.add_argument('--world_size', type=int, default=1) 38 | parser.add_argument('--gpu', type=int, default=None, help='if DDP, set None.') 39 | parser.add_argument('--multiprocessing_distributed', type=bool, default=False, 40 | help='Use multi-processing distributed training to launch ' 41 | 'N processes per node, which has N GPUs. This is the ' 42 | 'fastest way to use multiprocessing for either single node or ' 43 | 'multi node data parallel training') 44 | # +---------------------------------------------------------+ 45 | 46 | def setup_seed(seed): 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | 50 | # +---------------------------------------------------------+ 51 | """ 52 | Setup Configuration 53 | """ 54 | config = Config(datalabel='ff-all', recipes=['ff-all','r3d'], ckpt='CKPT_FILE_NAME.pth.tar', best_ckpt='BEST_CKPT_FILE_NAME.pth.tar', epochs=100) 55 | best_auc1 = 0 56 | 57 | def main(): 58 | args = parser.parse_args() 59 | setup_seed(1000) 60 | os.makedirs(args.checkpoint, exist_ok=True) 61 | os.makedirs(args.results, exist_ok=True) 62 | 63 | if args.dist_url == "env://" and args.world_size == -1: 64 | args.world_size = int(os.environ["WORLD_SIZE"]) 65 | 66 | if args.multiprocessing_distributed: 67 | print("args.multiprocessing_distributed==True") 68 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 69 | 70 | 71 | if args.multiprocessing_distributed: 72 | print('multiprocessing_distributed') 73 | init() 74 | ngpus_per_node = mindspore.communication.get_group_size() 75 | args.world_size = ngpus_per_node * args.world_size 76 | mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.HYBRID_PARALLEL) 77 | else: 78 | print('No Distributed Data Processing. 
GPU => ', args.gpu) 79 | main_worker(args.gpu, 1, args) 80 | 81 | def main_worker(gpu, ngpus_per_node, args): 82 | global best_auc1 83 | if args.local_rank == 0: 84 | logging.basicConfig(filename=join(args.results, 'train.log'), filemode='w', format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s', level=logging.INFO) 85 | args.gpu = gpu 86 | if args.gpu is not None: 87 | print('Use GPU: {} for training.'.format(args.gpu)) 88 | if args.distributed: 89 | print("distributed == True") 90 | if args.dist_url == "env://" and args.local_rank == -1: 91 | args.local_rank = int(os.environ["RANK"]) 92 | if args.multiprocessing_distributed: 93 | # For multiprocessing distributed training, rank needs to be the 94 | # global rank among all the processes 95 | args.local_rank = args.local_rank * ngpus_per_node + gpu 96 | mindspore.communication.init(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.local_rank) 97 | 98 | 99 | if args.distributed: 100 | if args.gpu is not None: 101 | args.batch_size = int(config.batch_size / ngpus_per_node) 102 | args.workers = int((config.workers + ngpus_per_node -1) / ngpus_per_node) 103 | mindspore.set_context(args.gpu) 104 | print('GPU:', args.gpu) 105 | 106 | else: 107 | ''' 108 | DistributedDataParallel will divide and allocate batch_size to all 109 | available GPUs if device_ids are not set 110 | ''' 111 | mindspore.set_auto_parallel_context(dataset_strategy="full_batch") 112 | print('Allocate batch-size to all available GPUs') 113 | 114 | elif args.gpu is not None: 115 | # Single GPU. 116 | mindspore.set_context(device_id=args.gpu, device_target="GPU") 117 | 118 | print("Use Single GPU", args.gpu) 119 | else: 120 | print(">>> Use CPU") 121 | mindspore.set_context(device_target="CPU") 122 | 123 | model = define_network(f_dim=512, h_didm=128, a_dim=12, config=config) 124 | 125 | if args.resume: 126 | if args.local_rank % ngpus_per_node == 0: 127 | print('=> Use previously trained model from {}'.format(args.resume)) 128 | 129 | config.start_epoch, best_auc1 = load_model(model, join(args.resume, f'{config.best_ckpt}')) 130 | opt = mindspore.nn.AdamWeightDecay(model.trainable_params(), learning_rate=config.lr) 131 | milestone = [item for item in range(0, config.epochs, config.scheduler_step)] 132 | learning_rates = [] 133 | init_gamma = 1 134 | for _ in range(config.epochs // config.scheduler_step): 135 | learning_rates.append(init_gamma) 136 | init_gamma *= 0.5 137 | scheduler = mindspore.nn.piecewise_constant_lr(milestone[1:], learning_rates[1:]) 138 | 139 | criterion = mindspore.nn.CrossEntropyLoss() 140 | 141 | if config.train_dataset['datalabel'] == 'celeb': 142 | dataset_generator = DFDataset(phase='train', tag="", codec='', **config.train_dataset) 143 | dataset = ds.GeneratorDataset(dataset_generator, ["img", "label"], shuffle=False) 144 | train_data, val_data = dataset.split([0.8, 0.2]) 145 | 146 | else: 147 | train_data = DFDataset(phase='train', tag="", codec='c23', **config.train_dataset) 148 | val_data = DFDataset(phase='val', datalabel="ff-all", tag="", codec='c23', **config.val_dataset) 149 | 150 | if args.distributed: 151 | train_sampler = ds.DistributedSampler(train_data, shuffle=True) 152 | val_sampler = ds.DistributedSampler(val_data, shuffle=True) 153 | else: 154 | train_sampler = val_sampler = None 155 | 156 | train_dataloader = ds.GeneratorDataset(train_data, ["img", "label"], shuffle=True, sampler=train_sampler) 157 | train_dataloader = train_dataloader.batch(batch_size=config.batch_size) 158 | 
train_dataloader = train_dataloader.map(operations=mindspore.dataset.transforms.TypeCast(mindspore.int32), input_columns="label") 159 | 160 | val_dataloader = ds.GeneratorDataset(val_data, ["img", "label"], shuffle=(val_sampler is None), sampler=val_sampler) 161 | val_dataloader = val_dataloader.batch(batch_size=4) 162 | val_dataloader = val_dataloader.map(operations=mindspore.dataset.transforms.TypeCast(mindspore.int32), input_columns="label") 163 | # model = Model(model, criterion, opt) 164 | 165 | for epoch in range(config.start_epoch, config.epochs): 166 | train(train_dataloader, model, opt, criterion, epoch, args) 167 | val_auc = valid(val_dataloader, model, criterion, ngpus_per_node, epoch, args) 168 | scheduler.step() 169 | 170 | is_best = val_auc > best_auc1 171 | best_auc1 = max(val_auc, best_auc1) 172 | 173 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.local_rank % ngpus_per_node==0): 174 | save_checkpoint({ 175 | 'epoch': epoch + 1, 176 | 'state_dict': model.state_dict(), 177 | 'best_auc1': best_auc1 178 | }, 179 | is_best, args.checkpoint, filename=config.ckpt, best=config.best_ckpt) 180 | 181 | def train(train_dataloader, model, optimizer, criterion, epoch, args): 182 | # loss_net = CustomLossCell(model, criterion) 183 | def forward_fn(data, label): 184 | logits = model(data) 185 | loss = criterion(logits, label) 186 | return loss, logits 187 | 188 | grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) 189 | 190 | def train_step(data, label): 191 | (loss, _), grads = grad_fn(data, label) 192 | loss = ops.depend(loss, optimizer(grads)) 193 | return loss 194 | 195 | model.set_train() 196 | size = train_dataloader.get_dataset_size() 197 | for index, (data, label) in enumerate(train_dataloader.create_tuple_iterator()): 198 | loss = train_step(data, label) 199 | if index % 100 == 0: 200 | loss, current = loss.asnumpy(), index 201 | print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]") 202 | return loss 203 | 204 | def valid(val_dataloader, model, criterion, ngpus_per_node, epoch, args): 205 | model = Model(model, loss_fn=criterion, optimizer=None, metrics={'acc'}) 206 | acc = model.eval(val_dataloader, dataset_sink_mode=False) 207 | 208 | return acc 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /MindSpore/training/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | 6 | import random 7 | import sklearn 8 | import mindspore 9 | from mindspore.dataset import GeneratorDataset as Dataset 10 | from albumentations import Compose, Resize, HorizontalFlip, Normalize 11 | # from mindspore.dataset.transforms import Compose 12 | # from mindspore.dataset import vision 13 | import numpy as np 14 | # Global 15 | seq_len = 100 16 | 17 | def augmentation(type='augment0', resize=112): 18 | augment0 = Compose([Resize(resize, resize),HorizontalFlip(),Normalize(mean=(0.43216,0.394666,0.37645),std=(0.22803,0.22145,0.216989))]) 19 | augment_test = Compose([Resize(resize, resize), Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) 20 | augment_set = {'augment0': augment0, 'augment_test': augment_test} 21 | return augment_set[type] 22 | 23 | class DFDataset: 24 | def __init__(self, phase, datalabel, resize, tag, codec, augment='augment0'): 25 | assert phase in ['train', 'val', 'test'] 26 | self.datalabel = datalabel 
27 | self.resize = resize 28 | self.phase = phase 29 | self.epoch = 0 30 | self.len = 0 31 | self.fake = [] 32 | self.real = [] 33 | self.dataset = [] 34 | self.aug = augmentation(augment, resize) 35 | 36 | if phase == 'train': 37 | print("------Train Set------") 38 | elif phase == 'val': 39 | print("------Validation Set------") 40 | elif phase == 'test': 41 | print("------Test Set------") 42 | else: 43 | print("Error: The phase is None") 44 | if 'ff-all' in self.datalabel: 45 | if tag == "": 46 | print("Load ff-all") 47 | for subtag in ['deepfakes', 'face2face', 'faceswap', 'neural_textures', 'original', 'FaceShifter']: 48 | subdataset = FF_dataset(subtag, codec, phase) 49 | self.dataset += subdataset 50 | if len(subdataset) > 0: 51 | print(f'load {subtag}-{codec} len: {len(subdataset)}') 52 | else: 53 | for subtag in ['deepfakes', 'face2face', 'faceswap', 'neural_textures', 'original']: 54 | if tag != subtag: 55 | subdataset = FF_dataset(subtag, codec, phase) 56 | self.dataset += subdataset 57 | if len(subdataset) > 0: 58 | print(f'load {subtag}-{codec} len: {len(subdataset)}') 59 | if phase != 'test': 60 | self.dataset = make_balance(self.dataset) 61 | 62 | elif 'ff' in self.datalabel: 63 | self.dataset = FF_dataset(tag, codec, phase) 64 | self.dataset += FF_dataset("original", codec, phase) 65 | print(f'load {tag}-{codec} len: {len(self.dataset)}') 66 | 67 | elif 'celeb' in self.datalabel: 68 | self.dataset = CelebDF(phase) 69 | print(f'load {self.datalabel} len: {len(self.dataset)}') 70 | elif 'dfdc' in self.datalabel: 71 | self.dataset = DFDC(phase) 72 | print(f'load {self.datalabel} len: {len(self.dataset)}') 73 | 74 | else: 75 | raise(Exception(f'Error: Dataset {self.datalabel} does not exist!')) 76 | self.len = len(self.dataset) 77 | 78 | def __getitem__(self, index): 79 | fpath_list, label = self.dataset[index] 80 | 81 | # Total number of sampled frames. 
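# Note (added comment): each dataset item is a list of up to seq_len (=100) frame paths.
# The loop below reads and augments the available frames into `buffer`; the while loop
# afterwards cyclically repeats frames from the start of the clip until the buffer holds
# exactly seq_len frames, and ToTensor finally transposes it to (C, T, H, W) float32.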
82 | len_list = len(fpath_list) 83 | frame_N = len_list 84 | 85 | buffer = np.empty(shape=(seq_len, self.resize, self.resize, 3), dtype=np.float64) 86 | idx = 0 87 | for idx, i in enumerate(range(frame_N)): 88 | fpath = fpath_list[i] 89 | img = cv2.imread(fpath) 90 | img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 91 | img = self.aug(image=img)["image"] 92 | buffer[idx] = img 93 | idx += 1; cur_idx = 0 94 | while idx < seq_len: 95 | buffer[idx] = buffer[cur_idx % frame_N] 96 | cur_idx += 1; idx += 1 97 | buffer = self.ToTensor(buffer) 98 | return buffer, label 99 | # return {'img': buffer, 'label': label} 100 | 101 | def __len__(self): 102 | return self.len 103 | 104 | def ToTensor(self, ndarray): 105 | tensor = ndarray.transpose((3, 0, 1, 2)) 106 | return mindspore.Tensor.from_numpy(tensor).astype(mindspore.float32) 107 | 108 | def make_balance(data): 109 | tr = list(filter(lambda x:x[1]==0, data)) 110 | tf = list(filter(lambda x:x[1]==1, data)) 111 | if len(tr) > len(tf): 112 | tr, tf = tf, tr 113 | rate = len(tf) // len(tr) 114 | res = len(tf) - rate * len(tr) 115 | tr = tr * rate + random.sample(tr,res) 116 | return tr + tf 117 | 118 | ## Face Forensics++ 119 | ffpp_raw_train = "/data2/ziming.yang/datasets/ffpp/ffpp_raw_train.json" 120 | ffpp_raw_val = "/data2/ziming.yang/datasets/ffpp/ffpp_raw_val.json" 121 | ffpp_raw_test = "/data2/ziming.yang/datasets/ffpp/ffpp_raw_test.json" 122 | ffpp_c23_train = "/data2/ziming.yang/datasets/ffpp/ffpp_c23_train.json" 123 | ffpp_c23_val = "/data2/ziming.yang/datasets/ffpp/ffpp_c23_val.json" 124 | ffpp_c23_test = "/data2/ziming.yang/datasets/ffpp/ffpp_c23_test.json" 125 | ffpp_c40_train = "/data2/ziming.yang/datasets/ffpp/ffpp_c40_train.json" 126 | ffpp_c40_val = "/data2/ziming.yang/datasets/ffpp/ffpp_c40_val.json" 127 | ffpp_c40_test = "/data2/ziming.yang/datasets/ffpp/ffpp_c40_test.json" 128 | 129 | ffpp_proto = {"raw": {'train': ffpp_raw_train, 'val': ffpp_raw_val, 'test': ffpp_raw_test}, 130 | "c23": {'train': ffpp_c23_train, 'val': ffpp_c23_val, 'test': ffpp_c23_test}, 131 | "c40": {'train': ffpp_c40_train, 'val': ffpp_c40_val, 'test': ffpp_c40_test} 132 | } 133 | def FF_dataset(tag, codec, phase='train'): 134 | assert(tag in ['deepfakes', 'face2face', 'faceswap', 'neural_textures', 'original', 'FaceShifter']) 135 | assert(codec in ['raw','c23','c40','all']) 136 | assert(phase in ['train','val','test','all']) 137 | if phase=="all": 138 | return FF_dataset(tag, codec, 'train') + FF_dataset(tag, codec, 'val') + FF_dataset(tag, codec, 'test') 139 | if codec=="all": 140 | return FF_dataset(tag,'raw', phase) + FF_dataset(tag,'c23',phase) + FF_dataset(tag,'c40',phase) 141 | 142 | _dataset = [] 143 | path = ffpp_proto[codec][phase] 144 | with open(path, 'r') as f: 145 | data_dict = json.load(f) 146 | for k,v in data_dict.items(): 147 | if tag not in k: 148 | continue 149 | video_label = v['label'] 150 | file_list = v['list'] 151 | 152 | for i in range(0, len(file_list), seq_len): 153 | if i+seq_len>=len(file_list): 154 | _dataset.append([file_list[-seq_len:], video_label]) 155 | else: 156 | _dataset.append([file_list[i:i+ seq_len], video_label]) 157 | return _dataset 158 | 159 | ## Celeb-DF v2 160 | celebtrain = "/data2/ziming.yang/datasets/Celeb-DF/celeb_train.json" 161 | celebtest = "/data2/ziming.yang/datasets/Celeb-DF/celeb_test.json" 162 | celeb_proto = {'train': celebtrain, 'test': celebtest} 163 | def CelebDF(phase='train'): 164 | assert(phase in ['train', 'test', 'all']) 165 | if phase=='all': 166 | return CelebDF('train') + CelebDF('tests') 167 | 
_dataset = [] 168 | path = celeb_proto[phase] 169 | with open(path, 'r') as f: 170 | data_dict = json.load(f) 171 | num_frames = 100 if phase!='train' else None 172 | for k,v in data_dict.items(): 173 | if len(v['list']) != 30: ## Dataset analysis 174 | continue 175 | video_label = v['label'] 176 | file_list = v['list'][:num_frames] 177 | 178 | len_list = len(file_list) 179 | for i in range(0, len_list, seq_len): 180 | if i+seq_len >= len_list: 181 | _dataset.append([file_list[-seq_len:], video_label]) 182 | else: 183 | _dataset.append([file_list[i:i+seq_len], video_label]) 184 | 185 | return _dataset 186 | 187 | ## Deepfakes Detection Challenge 188 | dfdctrain = "/data2/ziming.yang/datasets/DFDC/dfdc_train.json" 189 | dfdcval = "/data2/ziming.yang/datasets/DFDC/dfdc_val.json" 190 | dfdctest = "/data2/ziming.yang/datasets/DFDC/dfdc_test.json" 191 | dfdc_proto = {'train': dfdctrain, 'val': dfdcval, 'test': dfdctest} 192 | def DFDC(phase='train'): 193 | assert(phase in ['train', 'val', 'test', 'all']) 194 | if phase=='all': 195 | return DFDC('train') + DFDC('val') + DFDC('tests') 196 | _dataset = [] 197 | path = dfdc_proto[phase] 198 | num_frames = 100 if phase!='train' else None 199 | with open(path, 'r') as f: 200 | data_dict = json.load(f) 201 | for k,v in data_dict.items(): 202 | if len(v['list']) < 100: 203 | continue 204 | 205 | video_label = v['label'] 206 | 207 | file_list = v['list'][:num_frames] 208 | 209 | len_list = len(file_list) 210 | 211 | for i in range(0, len_list, seq_len): 212 | if i+seq_len >= len_list: 213 | _dataset.append([file_list[-seq_len:], video_label]) 214 | else: 215 | _dataset.append([file_list[i:i+ seq_len], video_label]) 216 | return _dataset -------------------------------------------------------------------------------- /MindSpore/training/evaluate.py: -------------------------------------------------------------------------------- 1 | import mindspore 2 | from mindspore import ops 3 | from scipy.optimize import brentq 4 | from sklearn import metrics 5 | from sklearn.metrics import roc_curve 6 | from scipy.interpolate import interp1d 7 | 8 | def accuracy(output, target, topk=(1,)): 9 | """Computes the accuracy over the k top predictions for the specified values of k""" 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | res = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 20 | res.append(correct_k.mul_(100.0 / batch_size)) 21 | return res 22 | 23 | def calculate_eer(y_true, y_score): 24 | fpr, tpr, thresholds = roc_curve(y_true, y_score) 25 | eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 26 | thresh = interp1d(fpr, thresholds)(eer) 27 | return eer, thresh 28 | 29 | def compute_video_level_auc(video_to_logits, video_to_labels): 30 | """ " 31 | Compute video-level area under ROC curve. Averages the logits across the video for non-overlapping clips. 
32 | 33 | Parameters 34 | ---------- 35 | video_to_logits : dict 36 | Maps video ids to list of logit values 37 | video_to_labels : dict 38 | Maps video ids to label 39 | """ 40 | output_batch = ops.stack( 41 | [ops.stack(video_to_logits[video_id]).mean(axis=0) for video_id in video_to_logits.keys()] 42 | ) 43 | output_labels = ops.stack([video_to_labels[video_id] for video_id in video_to_logits.keys()]) 44 | 45 | fpr, tpr, _ = metrics.roc_curve(output_labels.asnumpy(), output_batch.asnumpy()) 46 | return metrics.auc(fpr, tpr) 47 | 48 | def compute_video_level_acc(video_to_logits, video_to_labels): 49 | output_batch = ops.stack( 50 | [ops.stack(video_to_logits[video_id]).mean(axis=0) for video_id in video_to_logits.keys()] 51 | ) 52 | prediction = (output_batch >= 0.5).astype(mindspore.int32) 53 | output_labels = ops.stack([video_to_labels[video_id] for video_id in video_to_logits.keys()]) 54 | acc = metrics.accuracy_score(output_labels.asnumpy(), prediction.asnumpy()) 55 | return acc 56 | 57 | def compute_video_level_prf(video_to_logits, video_to_labels): 58 | output_batch = ops.stack( 59 | [ops.stack(video_to_logits[video_id]).mean(axis=0) for video_id in video_to_logits.keys()] 60 | ) 61 | prediction = (output_batch >= 0.5).astype(mindspore.int32) 62 | output_labels = ops.stack([video_to_labels[video_id] for video_id in video_to_logits.keys()]) 63 | pre, rec, f1, support = metrics.precision_recall_fscore_support(output_labels.asnumpy(), prediction.asnumpy(), average='binary') 64 | return pre, rec, f1, support -------------------------------------------------------------------------------- /MindSpore/training/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Loss Functions for MindSpore platform.
3 | """ 4 | import mindspore 5 | from mindspore import nn, ops 6 | 7 | def loss_tc(T_attention): 8 | batch_size, T, A, H, W = T_attention.shape 9 | loss_tc = 0 10 | for t in range(T-1): 11 | mapi = T_attention[:, t, :, :, :] 12 | mapj = T_attention[:, t+1, :, :, :] 13 | loss_tc += ops.dist(mapi, mapj,p=1) 14 | loss_tc = loss_tc / T / batch_size 15 | return loss_tc 16 | 17 | def loss_od(T_att_features): 18 | eps=1e-8 19 | # [B, T, A, H, W] 20 | # B, T, A, C = T_att_features.shape 21 | matrix_A = T_att_features 22 | a_n = matrix_A.norm(dim=2).unsqueeze(2) 23 | # a_n [B, N, 1] 24 | # Normalize 25 | a_norm = matrix_A / ops.max(a_n, eps * ops.ones_like(a_n)) 26 | # patch-wise absolute value of cosine similarity 27 | sim_matrix = ops.einsum('abc,acd->abd', a_norm, a_norm.swapaxes(1,2)) 28 | loss_rc = sim_matrix.mean() 29 | return loss_rc 30 | 31 | class CustomLossCell(nn.Cell): 32 | def __init__(self, backbone, loss_fn): 33 | super(CustomLossCell, self).__init__(auto_prefix=False) 34 | self._backbone = backbone 35 | self.loss_ce = loss_fn 36 | 37 | def construct(self, data, label): 38 | logits, T_attention, T_att_features = self._backbone(data) 39 | loss_ce = self.loss_ce(logits, label) 40 | loss_temporal_cons = loss_tc(T_attention) 41 | loss_orthogon_dive = loss_od(T_att_features) 42 | return loss_ce+loss_orthogon_dive+loss_temporal_cons -------------------------------------------------------------------------------- /MindSpore/training/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mindspore 3 | from mindspore import nn, ops 4 | from mindspore.common.initializer import initializer, TruncatedNormal, Zero 5 | from training.r3d import mc3_18 6 | from utils.graph_conv import * 7 | from utils.util import pretrained_model 8 | from einops import rearrange 9 | 10 | def define_network(f_dim, h_didm, a_dim, config): 11 | model = Model(f_dim, h_didm, a_dim, config) 12 | return model 13 | 14 | 15 | class Model(nn.Cell): 16 | def __init__(self, f_dim, h_dim, a_dim, config): 17 | super(Model, self).__init__() 18 | self.f_dim = f_dim 19 | self.h_dim = h_dim 20 | self.a_dim = a_dim 21 | self.mask_strategy = config.mask_strategy 22 | self.masked_drop = config.mask_rate 23 | 24 | self.edges = mindspore.Parameter(random_edges(a_dim)) 25 | 26 | self.enc = Encoder() 27 | 28 | self.atn = AttentionMap(in_channels=f_dim, out_channels=a_dim) 29 | self.atp = AttentionPooling() # MaxAttentionPooling() 30 | 31 | self.cell = TGCN(f_dim, h_dim) 32 | 33 | self.flatten = nn.Flatten() 34 | self.bn = nn.BatchNorm1d(a_dim * h_dim) 35 | self.lrelu = nn.ReLU() 36 | self.fc = nn.Dense(a_dim * h_dim, 2) 37 | 38 | self.init_weights() 39 | 40 | def init_weights(self): 41 | self.fc.weight.set_data(initializer(TruncatedNormal(sigma=0.02), self.fc.weight.data.shape)) 42 | self.fc.bias.set_data(initializer(Zero(), self.fc.bias.data.shape)) 43 | 44 | 45 | def construct(self, clip): 46 | adj = self.edges 47 | if self.mask_strategy!='none': # self.edges.requires_grad and 48 | adj = MaskConnection(adj, self.mask_strategy, self.masked_drop) 49 | adj = calculate_laplacian_with_self_loop(adj) 50 | 51 | batch_size = clip.shape[0] 52 | h1 = ops.zeros((batch_size, self.a_dim, self.h_dim), clip.dtype) 53 | h2 = ops.zeros((batch_size, self.a_dim, self.h_dim), clip.dtype) 54 | 55 | out = None 56 | 57 | snippets = ops.split(clip, output_num=5, axis=2) 58 | for index, input in enumerate(snippets): 59 | ## Features 60 | feature_map = self.enc(input) 61 | 62 | attention_maps = 
self.atn(feature_map) 63 | feature_matrix = self.atp(feature_map, attention_maps) 64 | 65 | h1, h2 = self.cell(feature_matrix, h1, h2, adj) 66 | 67 | # if index > 0: 68 | # print("index > 0") 69 | # T_attention_maps = ops.cat((T_attention_maps, attention_maps_.unsqueeze(1)), axis=1) 70 | # T_att_features = ops.cat((T_att_features, attention_maps_.unsqueeze(1).flatten(-2)), axis=1) 71 | 72 | # else: 73 | # print("index==0") 74 | # T_attention_maps = attention_maps_.unsqueeze(1) 75 | # T_att_features = attention_maps_.unsqueeze(1).flatten(-2) 76 | 77 | out = h2 78 | # T_att_features = T_att_features.flatten(0,1) 79 | 80 | x = self.flatten(out) 81 | 82 | x = self.bn(x) 83 | x = self.lrelu(x) 84 | 85 | logits = self.fc(x) 86 | 87 | return logits#, T_attention_maps, T_att_features 88 | 89 | def loss_tc(self, T_attention): 90 | batch_size, T, A, H, W = T_attention.shape 91 | loss_tc = 0 92 | for t in range(T-1): 93 | mapi = T_attention[:, t, :, :, :] 94 | mapj = T_attention[:, t+1, :, :, :] 95 | loss_tc += ops.dist(mapi, mapj,p=1) 96 | loss_tc = loss_tc / T / batch_size 97 | return loss_tc 98 | 99 | def loss_od(self, T_att_features): 100 | eps=1e-8 101 | # [B, T, A, H, W] 102 | # B, T, A, C = T_att_features.shape 103 | matrix_A = T_att_features 104 | a_n = matrix_A.norm(dim=2).unsqueeze(2) 105 | # a_n [B, N, 1] 106 | # Normalize 107 | a_norm = matrix_A / ops.max(a_n, eps * ops.ones_like(a_n)) 108 | # patch-wise absolute value of cosine similarity 109 | sim_matrix = ops.einsum('abc,acd->abd', a_norm, a_norm.swapaxes(1,2)) 110 | loss_rc = sim_matrix.mean() 111 | return loss_rc 112 | 113 | def l2_reg(self): 114 | reg_loss = 0.0 115 | for param in self.cell.parameters(): 116 | reg_loss += ops.sum(param ** 2) / 2 117 | reg_loss = 1.5e-3 * reg_loss 118 | return reg_loss 119 | 120 | 121 | class Encoder(nn.Cell): 122 | def __init__(self): 123 | super(Encoder, self).__init__() 124 | base = mc3_18(pretrained=False) 125 | 126 | self.base = base 127 | 128 | def construct(self, clip): 129 | return self.base(clip) 130 | 131 | class AttentionMap(nn.Cell): 132 | def __init__(self, in_channels, out_channels): 133 | super(AttentionMap, self).__init__() 134 | 135 | self.num_attentions = out_channels 136 | 137 | self.conv_extract = nn.Conv3d(in_channels, in_channels, kernel_size=(7,3,3), stride=(3, 1, 1), pad_mode="pad", padding=1) 138 | self.bn1 = nn.BatchNorm3d(in_channels) 139 | self.conv2 = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1)) 140 | self.bn2 = nn.BatchNorm3d(out_channels) 141 | 142 | def construct(self, x): 143 | if self.num_attentions==0: 144 | return ops.ones((x.shape[0],1,1,1)) 145 | x = self.conv_extract(x) 146 | x = self.bn1(x) 147 | x = ops.relu(x) 148 | x = self.conv2(x) 149 | x = self.bn2(x) 150 | 151 | x = ops.adaptive_avg_pool3d(x, (1, None, None)) 152 | x = x.squeeze(2) 153 | gelu = ops.GeLU() 154 | x = gelu(x) 155 | return x 156 | 157 | class AttentionPooling(nn.Cell): 158 | def __init__(self): 159 | super().__init__() 160 | def construct(self, features, attentions,norm=2): 161 | H, W = features.shape[-2:] 162 | B, M, AH, AW = attentions.shape 163 | if AH != H or AW != W: 164 | attentions = ops.interpolate(attentions,size=(H,W), mode='bilinear', align_corners=True) 165 | if norm==1: 166 | attentions=attentions+1e-8 167 | if len(features.shape)==4: 168 | einsum1 = ops.Einsum('imjk,injk->imn') 169 | feature_matrix=einsum1((attentions, features)) 170 | else: 171 | einsum2 = ops.Einsum('imjk,indjk->imn') 172 | feature_matrix=einsum2((attentions, features)) 173 | 
if norm==1: 174 | w=ops.sum(attentions,dim=(2,3)).unsqueeze(-1) 175 | feature_matrix/=w 176 | if norm==2: 177 | l2_normalize = ops.L2Normalize(axis=-1) 178 | feature_matrix = l2_normalize(feature_matrix) 179 | if norm==3: 180 | w=ops.sum(attentions,dim=(2,3)).unsqueeze(-1)+1e-8 181 | feature_matrix/=w 182 | return feature_matrix 183 | 184 | class MaxAttentionPooling(nn.Cell): 185 | def __init__(self): 186 | super().__init__() 187 | def construct(self, features, attentions,norm=2): 188 | H, W = features.shape[-2:] 189 | 190 | B, M, AH, AW = attentions.shape 191 | if AH != H or AW != W: 192 | attentions=ops.interpolate(attentions,size=(H,W), mode='bilinear', align_corners=True) 193 | if norm==1: 194 | attentions=attentions+1e-8 195 | if len(features.shape)==4: 196 | einsum1 = ops.Einsum('imjk,injk->imnjk') 197 | feature_matrix=einsum1((attentions, features)) 198 | else: 199 | einsum2 = ops.Einsum('imjk,indjk->imnjk') 200 | feature_matrix=einsum2((attentions, features)) 201 | if norm==1: 202 | w=ops.sum(attentions,dim=(2,3)).unsqueeze(-1) 203 | feature_matrix/=w 204 | if norm==2: 205 | feature_matrix = ops.max_pool3d(feature_matrix, [1,2,2], [1,2,2]) 206 | feature_matrix = ops.max_pool3d(feature_matrix, [1,2,2], [1,2,2]) 207 | feature_matrix = feature_matrix.squeeze() 208 | l2_normalize = ops.L2Normalize(axis=-1) 209 | feature_matrix = l2_normalize(feature_matrix) 210 | if norm==3: 211 | w=ops.sum(attentions,dim=(2,3)).unsqueeze(-1)+1e-8 212 | feature_matrix/=w 213 | return feature_matrix 214 | 215 | class TGCN(nn.Cell): 216 | def __init__(self, in_dim, h_dim): 217 | super(TGCN, self).__init__() 218 | self.in_dim = in_dim 219 | self.h_dim = h_dim 220 | self.tgcn_layer1 = TGCNCell(self.in_dim, self.h_dim) 221 | self.tgcn_layer2 = TGCNCell(self.h_dim, self.h_dim) 222 | 223 | def construct(self, inputs, H1, H2, adj): 224 | batch_size, a_dim, c = inputs.shape 225 | assert self.in_dim == c 226 | H1 = self.tgcn_layer1(inputs, H1, adj) 227 | H2 = self.tgcn_layer2(H1, H2, adj) 228 | return H1, H2 229 | 230 | class TGCNCell(nn.Cell): 231 | def __init__(self, in_dim, h_dim): 232 | super(TGCNCell, self).__init__() 233 | self.in_dim = in_dim 234 | self.h_dim = h_dim 235 | 236 | self.gconv1 = GConv(in_dim + h_dim, h_dim * 2, bias=True) 237 | self.gconv2 = GConv(in_dim + h_dim, h_dim, bias=False) 238 | 239 | def construct(self, inputs, hidden_state, adj): 240 | concatenation = ops.sigmoid(self.gconv1(inputs, hidden_state, adj)) 241 | 242 | r, u = ops.split(concatenation, output_num=2, axis=2) 243 | 244 | c = ops.tanh(self.gconv2(inputs, r * hidden_state, adj)) 245 | 246 | new_hidden_state = u * hidden_state + (1.0 - u) * c 247 | 248 | return new_hidden_state 249 | 250 | class GConv(nn.Cell): 251 | """ 252 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 253 | """ 254 | def __init__(self, in_features, out_features, bias=True): 255 | super(GConv, self).__init__() 256 | self.in_features = in_features 257 | self.out_features = out_features 258 | 259 | self.weight = mindspore.Parameter(mindspore.Tensor(np.random.rand(in_features, out_features), mindspore.float32)) 260 | if bias: 261 | self.bias = mindspore.Parameter(mindspore.Tensor(np.random.rand(out_features), mindspore.float32)) 262 | else: 263 | self.bias = None 264 | 265 | self.reset_parameters() 266 | 267 | def reset_parameters(self): 268 | stdv = 1. 
/ math.sqrt(self.weight.shape[1]) 269 | self.weight.set_data(initializer('normal', self.weight.data.shape)) 270 | if self.bias is not None: 271 | self.bias.set_data(ops.uniform(self.bias.data.shape, mindspore.Tensor(-stdv, mindspore.float32), mindspore.Tensor(stdv, mindspore.float32))) 272 | 273 | def construct(self, inputs, hidden_state, laplacian): 274 | # [x, h] (batch_size, a_dim, input_dim + h_dim) 275 | concat_op = ops.Concat(axis=2) 276 | concatenation = concat_op((inputs, hidden_state)) 277 | batch_size, a_dim, ih_dim = concatenation.shape 278 | # [x, h] (a_dim, batch_size * (input_dim + h_dim)) 279 | concatenation = concatenation.swapaxes(0,1) 280 | concatenation = concatenation.reshape((a_dim, batch_size*ih_dim)) 281 | # concatenation = rearrange(concatenation, 'b a c -> a (b c)') 282 | # A[x, h] (a_dim, batch_size * (input_dim + h_dim)) 283 | a_times_concat = ops.matmul(laplacian, concatenation) 284 | # A[x, h] (batch_size, a_dim, input_dim + h_dim) 285 | a = a_times_concat.shape[0] 286 | a_times_concat = a_times_concat.reshape((a, batch_size, ih_dim)).swapaxes(1,0) 287 | # a_times_concat = rearrange(a_times_concat, 'a (b c) -> b a c', b=batch_size, c=ih_dim) 288 | 289 | output = ops.matmul(a_times_concat, self.weight) 290 | if self.bias is not None: 291 | return output + self.bias 292 | else: 293 | return output 294 | 295 | def __repr__(self): 296 | return self.__class__.__name__ + ' (' \ 297 | + str(self.in_features) + ' -> ' \ 298 | + str(self.out_features) + ')' 299 | 300 | 301 | def MaskConnection(adj, mask_strategy, masked_rate): 302 | if masked_rate > 0.: 303 | drop_matrix = None 304 | if mask_strategy == 'min': 305 | min_ = adj.min() 306 | max_ = adj.max() 307 | q = min_ + (max_ - min_) * masked_rate 308 | drop_matrix = ops.gt(adj, q) 309 | elif mask_strategy == 'random': 310 | q_matrix = ops.dropout(adj, masked_rate) 311 | drop_matrix = ops.gt(q_matrix, 0) 312 | else: 313 | print(f"No such strategy {mask_strategy}.") 314 | adj = ops.mul(drop_matrix, adj) 315 | return adj 316 | -------------------------------------------------------------------------------- /MindSpore/training/r3d.py: -------------------------------------------------------------------------------- 1 | import mindspore 2 | from mindspore import nn 3 | from mindspore.common.initializer import initializer, Zero, HeNormal, One 4 | # from torch.nn import Conv3d 5 | class Conv3DSimple(nn.Conv3d): 6 | def __init__(self, 7 | in_planes, 8 | out_planes, 9 | midplanes=None, 10 | stride=1, 11 | padding=1): 12 | 13 | super(Conv3DSimple, self).__init__( 14 | in_channels=in_planes, 15 | out_channels=out_planes, 16 | kernel_size=(3, 3, 3), 17 | stride=stride, 18 | pad_mode="pad", 19 | padding=padding) 20 | 21 | @staticmethod 22 | def get_downsample_stride(stride): 23 | return (stride, stride, stride) 24 | 25 | 26 | class Conv2Plus1D(nn.SequentialCell): 27 | 28 | def __init__(self, 29 | in_planes, 30 | out_planes, 31 | midplanes, 32 | stride=1, 33 | padding=1): 34 | super(Conv2Plus1D, self).__init__( 35 | nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), 36 | stride=(1, stride, stride), pad_mode="pad", padding=(0, padding, padding)), 37 | nn.BatchNorm3d(midplanes), 38 | nn.ReLU(), 39 | nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), 40 | stride=(stride, 1, 1), pad_mode="pad", padding=(padding, 0, 0))) 41 | 42 | @staticmethod 43 | def get_downsample_stride(stride): 44 | return (stride, stride, stride) 45 | 46 | 47 | class Conv3DNoTemporal(nn.Conv3d): 48 | 49 | def __init__(self, 50 | in_planes, 51 | 
out_planes, 52 | midplanes=None, 53 | stride=1, 54 | padding=1): 55 | 56 | super(Conv3DNoTemporal, self).__init__( 57 | in_channels=in_planes, 58 | out_channels=out_planes, 59 | kernel_size=(1, 3, 3), 60 | stride=(1, stride, stride), 61 | pad_mode="pad", 62 | padding=(0,0, padding, padding, padding, padding)) 63 | 64 | @staticmethod 65 | def get_downsample_stride(stride): 66 | return (1, stride, stride) 67 | 68 | 69 | class BasicBlock(nn.Cell): 70 | 71 | expansion = 1 72 | 73 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 74 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 75 | 76 | super(BasicBlock, self).__init__() 77 | self.conv1 = nn.SequentialCell( 78 | conv_builder(inplanes, planes, midplanes, stride), 79 | nn.BatchNorm3d(planes), 80 | nn.ReLU() 81 | ) 82 | self.conv2 = nn.SequentialCell( 83 | conv_builder(planes, planes, midplanes), 84 | nn.BatchNorm3d(planes) 85 | ) 86 | self.relu = nn.ReLU() 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def construct(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | out = self.conv2(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class Bottleneck(nn.Cell): 106 | expansion = 4 107 | 108 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 109 | 110 | super(Bottleneck, self).__init__() 111 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 112 | 113 | # 1x1x1 114 | self.conv1 = nn.SequentialCell( 115 | nn.Conv3d(inplanes, planes, kernel_size=1), 116 | nn.BatchNorm3d(planes), 117 | nn.ReLU() 118 | ) 119 | # Second kernel 120 | self.conv2 = nn.SequentialCell( 121 | conv_builder(planes, planes, midplanes, stride), 122 | nn.BatchNorm3d(planes), 123 | nn.ReLU() 124 | ) 125 | 126 | # 1x1x1 127 | self.conv3 = nn.SequentialCell( 128 | nn.Conv3d(planes, planes * self.expansion, kernel_size=1), 129 | nn.BatchNorm3d(planes * self.expansion) 130 | ) 131 | self.relu = nn.ReLU() 132 | self.downsample = downsample 133 | self.stride = stride 134 | 135 | def construct(self, x): 136 | residual = x 137 | 138 | out = self.conv1(x) 139 | out = self.conv2(out) 140 | out = self.conv3(out) 141 | 142 | if self.downsample is not None: 143 | residual = self.downsample(x) 144 | 145 | out += residual 146 | out = self.relu(out) 147 | 148 | return out 149 | 150 | 151 | class BasicStem(nn.SequentialCell): 152 | """The default conv-batchnorm-relu stem 153 | """ 154 | def __init__(self): 155 | super(BasicStem, self).__init__( 156 | nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), pad_mode="pad", 157 | padding=(1,1,3, 3, 3, 3)), 158 | nn.BatchNorm3d(64), 159 | nn.ReLU() 160 | # nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 161 | ) 162 | 163 | 164 | class R2Plus1dStem(nn.SequentialCell): 165 | """R(2+1)D stem is different than the default one as it uses separated 3D convolution 166 | """ 167 | def __init__(self): 168 | super(R2Plus1dStem, self).__init__( 169 | nn.Conv3d(3, 45, kernel_size=(1, 7, 7), 170 | stride=(1, 2, 2), pad_mode="pad", padding=(0, 3, 3)), 171 | nn.BatchNorm3d(45), 172 | nn.ReLU(), 173 | nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 174 | stride=(1, 1, 1), pad_mode="pad", padding=(1, 0, 0)), 175 | nn.BatchNorm3d(64), 176 | nn.ReLU()) 177 | 178 | 179 | class VideoResNet(nn.Cell): 180 | 181 | def __init__(self, block, conv_makers, layers, 182 | stem, num_classes=400, 183 | 
zero_init_residual=False): 184 | """Generic resnet video generator. 185 | 186 | Args: 187 | block (nn.Module): resnet building block 188 | conv_makers (list(functions)): generator function for each layer 189 | layers (List[int]): number of blocks per layer 190 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 191 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 192 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 193 | """ 194 | super(VideoResNet, self).__init__() 195 | self.inplanes = 64 196 | 197 | self.stem = stem() 198 | 199 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 200 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 201 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 202 | self.layer = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) 203 | 204 | # self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 205 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 206 | 207 | # init weights 208 | self._initialize_weights() 209 | 210 | if zero_init_residual: 211 | for m in self.cells(): 212 | if isinstance(m, Bottleneck): 213 | m.bn3.weight.set_data(initializer(Zero())) 214 | 215 | 216 | def construct(self, x): 217 | x = self.stem(x) 218 | 219 | x = self.layer1(x) 220 | x = self.layer2(x) 221 | x = self.layer3(x) 222 | x = self.layer(x) 223 | 224 | # x = self.avgpool(x) 225 | # # Flatten the layer to fc 226 | # x = x.flatten(1) 227 | # x = self.fc(x) 228 | 229 | return x 230 | 231 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1): 232 | downsample = None 233 | 234 | if stride != 1 or self.inplanes != planes * block.expansion: 235 | ds_stride = conv_builder.get_downsample_stride(stride) 236 | downsample = nn.SequentialCell( 237 | nn.Conv3d(self.inplanes, planes * block.expansion, 238 | kernel_size=1, stride=ds_stride), 239 | nn.BatchNorm3d(planes * block.expansion) 240 | ) 241 | layers = [] 242 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) 243 | 244 | self.inplanes = planes * block.expansion 245 | for i in range(1, blocks): 246 | layers.append(block(self.inplanes, planes, conv_builder)) 247 | 248 | return nn.SequentialCell(*layers) 249 | 250 | def _initialize_weights(self): 251 | for m in self.cells(): 252 | if isinstance(m, nn.Conv3d): 253 | m.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'))) 254 | 255 | elif isinstance(m, nn.BatchNorm3d): 256 | m.weight.set_data(initializer(One())) 257 | m.bias.set_data(initializer(Zero())) 258 | 259 | 260 | def _video_resnet(arch, pretrained=False, progress=True, **kwargs): 261 | model = VideoResNet(**kwargs) 262 | 263 | return model 264 | 265 | def r3d_50(pretrained=False, progress=True, **kwargs): 266 | return _video_resnet('r3d_50', 267 | pretrained, progress, 268 | block=Bottleneck, 269 | conv_makers=[Conv3DSimple] * 4, 270 | layers=[3, 4, 6, 3], 271 | stem=BasicStem, ** kwargs) 272 | 273 | def r3d_18(pretrained=False, progress=True, **kwargs): 274 | """Construct 18 layer Resnet3D model as in 275 | https://arxiv.org/abs/1711.11248 276 | 277 | Args: 278 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 279 | progress (bool): If True, displays a progress bar of the download to stderr 280 | 281 | Returns: 282 | nn.Module: R3D-18 network 283 | """ 284 | 285 | return _video_resnet('r3d_18', 286 | pretrained, 
progress, 287 | block=BasicBlock, 288 | conv_makers=[Conv3DSimple] * 4, 289 | layers=[2, 2, 2, 2], 290 | stem=BasicStem, **kwargs) 291 | 292 | 293 | def mc3_18(pretrained=False, progress=True, **kwargs): 294 | """Constructor for 18 layer Mixed Convolution network as in 295 | https://arxiv.org/abs/1711.11248 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | 301 | Returns: 302 | nn.Module: MC3 Network definition 303 | """ 304 | return _video_resnet('mc3_18', 305 | pretrained, progress, 306 | block=BasicBlock, 307 | conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, 308 | layers=[2, 2, 2, 2], 309 | stem=BasicStem, **kwargs) 310 | 311 | 312 | def rmc3_18(pretrained=False, progress=True, **kwargs): 313 | """Constructor for 18 layer reversed Mixed Convolution network as in 314 | https://arxiv.org/abs/1711.11248 315 | 316 | Args: 317 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 318 | progress (bool): If True, displays a progress bar of the download to stderr 319 | 320 | Returns: 321 | nn.Module: MC3 Network definition 322 | """ 323 | return _video_resnet('mc3_18', 324 | pretrained, progress, 325 | block=BasicBlock, 326 | conv_makers=[Conv3DNoTemporal] + [Conv3DSimple] * 3, 327 | layers=[2, 2, 2, 2], 328 | stem=BasicStem, **kwargs) 329 | 330 | 331 | def r2plus1d_18(pretrained=False, progress=True, **kwargs): 332 | """Constructor for the 18 layer deep R(2+1)D network as in 333 | https://arxiv.org/abs/1711.11248 334 | 335 | Args: 336 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 337 | progress (bool): If True, displays a progress bar of the download to stderr 338 | 339 | Returns: 340 | nn.Module: R(2+1)D-18 network 341 | """ 342 | return _video_resnet('r2plus1d_18', 343 | pretrained, progress, 344 | block=BasicBlock, 345 | conv_makers=[Conv2Plus1D] * 4, 346 | layers=[2, 2, 2, 2], 347 | stem=R2Plus1dStem, **kwargs) 348 | -------------------------------------------------------------------------------- /MindSpore/utils/graph_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mindspore 3 | from mindspore import ops 4 | 5 | def calculate_laplacian_with_self_loop(matrix): 6 | matrix = matrix + ops.eye(ops.shape(matrix)[0], ops.shape(matrix)[1], mindspore.int32) 7 | row_sum = matrix.sum(1) 8 | 9 | d_inv_sqrt = ops.pow(row_sum, -0.5).flatten() 10 | is_inf = ops.IsInf() 11 | 12 | d_inv_sqrt[is_inf(d_inv_sqrt)] = 0.0 13 | 14 | d_mat_inv_sqrt = ops.diag(d_inv_sqrt) # GPU 15 | # d_inv_sqrt = mindspore.numpy.asarray(d_inv_sqrt) 16 | # d_mat_inv_sqrt = mindspore.numpy.diag(d_inv_sqrt) # CPU 17 | 18 | # normalized_laplacian = ( 19 | # matrix.matmul(d_mat_inv_sqrt).swapaxes(0, 1).matmul(d_mat_inv_sqrt) 20 | # ) 21 | normalized_laplacian = ops.matmul(matrix, d_mat_inv_sqrt).swapaxes(0, 1) 22 | normalized_laplacian = ops.matmul(normalized_laplacian, d_mat_inv_sqrt) 23 | 24 | return normalized_laplacian 25 | 26 | def random_edges(dim): 27 | matrix = np.random.rand(dim, dim) 28 | matrix = mindspore.Tensor(matrix,dtype=mindspore.float32) 29 | 30 | # greater = ops.Greater() 31 | # matrix = matrix > 0.5 32 | q_matrix = ops.gt(matrix, 0.5) 33 | zero_like = ops.ZerosLike() 34 | zeros_matrix = zero_like(matrix) 35 | 36 | matrix = ops.masked_fill(zeros_matrix, q_matrix, 1) 37 | 38 | # matrix = matrix.int() 39 | 40 | return matrix 
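A minimal usage sketch (not part of the repository files) of the two helpers above, assuming the MindSpore/ directory is on the Python path: random_edges builds a dense random 0/1 adjacency over the a_dim attention vertices, and calculate_laplacian_with_self_loop normalizes it in the D^{-1/2}(A + I)D^{-1/2} style used by GCNs, which is the adjacency that network.Model feeds into each TGCN step.

    from utils.graph_conv import random_edges, calculate_laplacian_with_self_loop

    adj = random_edges(8)                               # (8, 8) float32 matrix of 0/1 entries
    laplacian = calculate_laplacian_with_self_loop(adj) # normalized adjacency with self-loops
    print(laplacian.shape)                              # (8, 8)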
-------------------------------------------------------------------------------- /MindSpore/utils/util.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import mindspore 3 | import shutil 4 | 5 | # import matplotlib 6 | # matplotlib.use('agg') 7 | 8 | 9 | def set_lr(opt, new_lr): 10 | for param_group in opt.param_groups: 11 | param_group["lr"] = new_lr 12 | 13 | def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar', best='best.pth.tar'): 14 | mindspore.save_checkpoint(state, join(checkpoint, filename)) 15 | if is_best: 16 | shutil.copyfile(join(checkpoint, filename), join(checkpoint, best)) 17 | 18 | def load_model(model, pretrained): 19 | weights = mindspore.load_checkpoint(pretrained) 20 | epoch = weights['epoch'] 21 | best_auc1 = weights['best_auc1'] 22 | pretrained_dict = weights["state_dict"] 23 | model_dict = model.parameters_dict() 24 | # 1. filter out unnecessary keys 25 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 26 | # 2. overwrite entries in the existing state dict 27 | model_dict.update(pretrained_dict) 28 | # 3. load the new state dict 29 | mindspore.load_param_into_net(model, model_dict) 30 | del weights 31 | return epoch, best_auc1 32 | 33 | def pretrained_model(model, pretrained): 34 | pretrained_dict = mindspore.load_checkpoint(pretrained) 35 | model_dict = model.parameters_dict() 36 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 37 | model_dict.update(pretrained_dict) 38 | mindspore.load_param_into_net(model, model_dict) 39 | 40 | 41 | class AverageMeter(object): 42 | """Computes and stores the average and current value""" 43 | def __init__(self, name, fmt=':f'): 44 | self.name = name 45 | self.fmt = fmt 46 | self.reset() 47 | 48 | def reset(self): 49 | self.val = 0 50 | self.avg = 0 51 | self.sum = 0 52 | self.count = 0 53 | 54 | def update(self, val, n=1): 55 | self.val = val 56 | self.sum += val * n 57 | self.count += n 58 | self.avg = self.sum / self.count 59 | 60 | def __str__(self): 61 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 62 | return fmtstr.format(**self.__dict__) 63 | 64 | class ProgressMeter(object): 65 | def __init__(self, num_batches, meters, prefix=""): 66 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 67 | self.meters = meters 68 | self.prefix = prefix 69 | 70 | def display(self, batch): 71 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 72 | entries += [str(meter) for meter in self.meters] 73 | print(' '.join(entries)) 74 | 75 | def _get_batch_fmtstr(self, num_batches): 76 | num_digits = len(str(num_batches // 1)) 77 | fmt = '{:' + str(num_digits) + 'd}' 78 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 79 | 80 | 81 | ## model parameters 82 | def summary(model): 83 | total_params = sum(p.numel() for p in model.get_parameters()) 84 | trainable_params = sum(p.numel() for p in model.get_parameters() if p.requires_grad) 85 | print('Total - %.2fM' % (total_params/1e6)) 86 | print('Trainable - %.2fM' % (trainable_params/1e6)) 87 | -------------------------------------------------------------------------------- /Pytorch/config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | def __init__(self, datalabel, recipes=[], **params): 3 | self.datalabel = datalabel 4 | self.workers = 8 5 | self.start_epoch = 0 6 | self.epochs = 30 7 | self.mask_strategy = 'min' 8 | self.mask_rate = 0.5 9 | 10 | 
self.lr = 1e-4 11 | self.scheduler_step = 5 12 | self.scheduler_gamma = 0.5 13 | self.optimizer = 'adamw' 14 | 15 | self.batch_size = 30 16 | self.resize=112 17 | self.a_dim = 8 18 | self.h_dim = 64 19 | self.embed_dim = 768 20 | 21 | self.dropout_rate = 0.0 22 | 23 | for i in recipes: 24 | self.recipe(i) 25 | 26 | for i in params: 27 | self.__setattr__(i, params[i]) 28 | 29 | self.train_dataset = dict(datalabel=self.datalabel, resize=self.resize, augment=self.augment) 30 | self.val_dataset = dict(resize=self.resize, augment='augment_test') 31 | 32 | def recipe(self, name): 33 | if 'r3d' in name: 34 | self.batch_size=30 35 | if 'resnet' in name: 36 | self.batch_size=48 37 | if 'resnet18' in name: 38 | self.batch_size=2048 39 | if 'xception' in name: 40 | self.net='xception' 41 | self.batch_size=64 42 | if 'no-aug' in name: 43 | self.augment='augment_test' 44 | 45 | if 'ff' in name: 46 | self.augment='augment0' 47 | if 'celeb' in name: 48 | self.augment='augment0' 49 | if 'dfdc' in name: 50 | self.augment='augment0' 51 | -------------------------------------------------------------------------------- /Pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations=1.2.1 2 | cudatoolkit=11.6.0 3 | einops=0.5.0 4 | matplotlib=3.5.1 5 | matplotlib-base=3.5.1 6 | mkl=2021.4.0 7 | mkl-service=2.4.0 8 | mkl_fft=1.3.1 9 | mkl_random=1.2.2 10 | numpy=1.22.0 11 | opencv-python=4.7.0.68 12 | pytorch=1.12.0 13 | scikit-image=0.19.3 14 | scikit-learn=1.1.1 15 | scipy=1.7.3 16 | tensorboard=2.8.0 17 | torchvision=0.13.0 18 | -------------------------------------------------------------------------------- /Pytorch/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import argparse 4 | import numpy as np 5 | from sklearn.metrics import roc_auc_score, precision_recall_fscore_support 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from training.dataset import * 11 | from training.network import define_network 12 | from training.evaluate import * 13 | from utils.util import * 14 | from config import Config 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--results', type=str, default='results/') 18 | parser.add_argument('--resume', type=str, default='') 19 | 20 | parser.add_argument('--print_iter', type=int, default=50, help='print frequency') 21 | parser.add_argument('--gpu', type=int, default=None, help='if DDP, set None.') 22 | parser.add_argument('--amp', action='store_true', help='if True, use fp16.') 23 | 24 | def setup_seed(seed): 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | random.seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | 30 | # +---------------------------------------------------------+ 31 | """ 32 | Setup Configuration 33 | """ 34 | config = Config(datalabel="ff-all", recipes=['ff-all', 'r3d']) 35 | 36 | def main(): 37 | args = parser.parse_args() 38 | setup_seed(25) 39 | os.makedirs(args.results, exist_ok=True) 40 | 41 | model = define_network(f_dim=512, h_didm=128, a_dim=12, config=config) 42 | 43 | if args.resume: 44 | weights = torch.load(args.resume) 45 | pretrained_dict = weights['state_dict'] 46 | model.load_state_dict(pretrained_dict) 47 | 48 | if args.gpu is not None: 49 | torch.cuda.set_device(args.gpu) 50 | model = model.cuda(args.gpu) 51 | print("Use Single GPU", args.gpu) 52 | else: 53 | model = torch.nn.DataParallel(model).cuda() 54 | print("Use Data 
Parallel.") 55 | 56 | criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu) 57 | 58 | val_data = DFDataset(phase='test', datalabel="ff", tag="deepfakes", codec="c23", **config.val_dataset) 59 | ## deepfakes; face2face; faceswap; neural_textures 60 | val_dataloader = DataLoader(val_data, batch_size=2, shuffle=True, num_workers=config.workers, pin_memory=True, drop_last=True) 61 | 62 | auc1 = valid(val_dataloader, model, criterion, 1, args) 63 | 64 | def valid(val_dataloader, model, criterion, ngpus_per_node, args): 65 | batch_time = AverageMeter('Batch', ':2.2f') 66 | data_time = AverageMeter('Data', ':1.2f') 67 | losses = AverageMeter('Loss', ':.2e') 68 | acc = AverageMeter('Acc', ':3.2f') 69 | 70 | progress = ProgressMeter( 71 | len(val_dataloader), 72 | [batch_time, data_time, losses, acc], 73 | prefix='Test:') 74 | 75 | # switch to evaluate mode 76 | model.eval() 77 | 78 | nplabels = [] 79 | softmax_logits = [] 80 | int_logits = [] 81 | 82 | with torch.no_grad(): 83 | end = time.time() 84 | for idx, data in enumerate(val_dataloader): 85 | img, label = data['img'], data['label'] 86 | 87 | _batch_size = label.shape[0] 88 | 89 | _batch_size = img.shape[0] 90 | data_time.update(time.time() - end) 91 | 92 | if args.amp: 93 | img = img.half() # fp16 94 | else: 95 | img = img.float() # fp32 96 | label = label.long() 97 | 98 | if torch.cuda.is_available(): 99 | img = img.cuda(args.gpu, non_blocking=True) 100 | label = label.cuda(args.gpu, non_blocking=True) 101 | 102 | logits = model(img) 103 | 104 | loss_ce = criterion(logits, label) 105 | loss = loss_ce 106 | 107 | pred = torch.nn.functional.softmax(logits, dim=1) 108 | acc1 = accuracy(pred, label) 109 | 110 | softmax_logits.append(logits.softmax(1)[:,1].detach().cpu().numpy()) 111 | nplabels.append(label.cpu().numpy().astype(np.int8)) 112 | int_logits.append(logits.max(1)[1].cpu().numpy().astype(np.int8)) 113 | 114 | losses.update(loss.item(), _batch_size) 115 | acc.update(acc1[0].item(), _batch_size) 116 | 117 | batch_time.update(time.time() - end) 118 | end = time.time() 119 | 120 | epoch_losses = loss.clone().detach().cuda(args.gpu) 121 | epoch_acc = torch.tensor(acc1).cuda(args.gpu) 122 | 123 | epoch_losses = epoch_losses.item() / ngpus_per_node 124 | epoch_acc = epoch_acc.cpu().numpy() / ngpus_per_node 125 | 126 | progress.display(idx) 127 | 128 | y_true = np.concatenate(nplabels) 129 | y_pred = np.concatenate(softmax_logits) 130 | int_pred = np.concatenate(int_logits) 131 | 132 | auc = roc_auc_score(y_true, y_pred) 133 | auc = 100. * auc 134 | 135 | eer1, _ = calculate_eer(y_true, y_pred) 136 | 137 | pre, rec, f1, _ = precision_recall_fscore_support(y_true, int_pred, average='binary') 138 | pre = 100.* pre; rec = 100. * rec; f1 = 100. 
* f1 139 | 140 | print(" * Acc:{Acc.avg:.3f} Auc:{Auc:.3f} Pre:{Pre:.3f} Rec:{Rec:.3f} F1:{F1:.3f} EER:{Eer:.3f}".format(Acc=acc, Auc=auc, Pre=pre, Rec=rec, F1=f1, Eer=eer1)) 141 | 142 | return auc 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /Pytorch/train_ffpp.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from os.path import join 4 | import argparse 5 | 6 | import logging 7 | import numpy as np 8 | from sklearn.metrics import roc_auc_score, precision_recall_fscore_support 9 | import torch 10 | import torch.multiprocessing as mp 11 | import torch.optim as optim 12 | from torch.cuda.amp import autocast, GradScaler 13 | from torch.utils.data import DataLoader 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.utils.data.distributed import DistributedSampler 17 | 18 | from training.dataset import * 19 | from training.network import define_network 20 | from training.evaluate import * 21 | from utils.util import * 22 | from config import Config 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--checkpoint', type=str, default='checkpoint/') 26 | parser.add_argument('--results', type=str, default='results/') 27 | parser.add_argument('--resume', type=str, default=None) 28 | 29 | parser.add_argument('--lambda_bce', type=float, default=1., help="1") 30 | 31 | parser.add_argument('--print_iter', type=int, default=100, help='print frequency') 32 | parser.add_argument('--save_epoch', default=1, type=int) 33 | 34 | parser.add_argument('--amp', action='store_true', help='if True, use fp16.') 35 | parser.add_argument('--local_rank', type=int, default=0) 36 | parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:23504') 37 | parser.add_argument('--world_size', type=int, default=1) 38 | parser.add_argument('--gpu', type=int, default=None, help='if DDP, set None.') 39 | parser.add_argument('--multiprocessing_distributed', type=bool, default=False, 40 | help='Use multi-processing distributed training to launch ' 41 | 'N processes per node, which has N GPUs. 
This is the ' 42 | 'fastest way to use PyTorch for either single node or ' 43 | 'multi node data parallel training') 44 | # +---------------------------------------------------------+ 45 | 46 | def setup_seed(seed, deterministic=False): 47 | torch.manual_seed(seed) 48 | np.random.seed(seed) 49 | random.seed(seed) 50 | torch.manual_seed(seed) 51 | torch.cuda.manual_seed(seed) 52 | torch.cuda.manual_seed_all(seed) 53 | if deterministic: 54 | torch.backends.cudnn.benchmark = True 55 | 56 | # +---------------------------------------------------------+ 57 | """ 58 | Setup Configuration 59 | """ 60 | config = Config(datalabel='ff-all', recipes=['ff-all','r3d'], ckpt='CKPT_FILE_NAME.pth.tar', best_ckpt='BEST_CKPT_FILE_NAME.pth.tar', epochs=100) 61 | best_auc1 = 0 62 | 63 | def main(): 64 | args = parser.parse_args() 65 | setup_seed(1000) 66 | os.makedirs(args.checkpoint, exist_ok=True) 67 | os.makedirs(args.results, exist_ok=True) 68 | 69 | if args.dist_url == "env://" and args.world_size == -1: 70 | args.world_size = int(os.environ["WORLD_SIZE"]) 71 | 72 | if args.multiprocessing_distributed: 73 | print("args.multiprocessing_distributed==True") 74 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 75 | ngpus_per_node = torch.cuda.device_count() 76 | 77 | if args.multiprocessing_distributed: 78 | args.world_size = ngpus_per_node * args.world_size 79 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 80 | else: 81 | print('No Distributed Data Processing. GPU => ', args.gpu) 82 | main_worker(args.gpu, ngpus_per_node, args) 83 | 84 | def main_worker(gpu, ngpus_per_node, args): 85 | global best_auc1 86 | if args.local_rank == 0: 87 | logging.basicConfig(filename=join(args.results, 'train.log'), filemode='w', format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s', level=logging.INFO) 88 | args.gpu = gpu 89 | if args.gpu is not None: 90 | print('Use GPU: {} for training.'.format(args.gpu)) 91 | if args.distributed: 92 | if args.dist_url == "env://" and args.local_rank == -1: 93 | args.local_rank = int(os.environ["RANK"]) 94 | if args.multiprocessing_distributed: 95 | # For multiprocessing distributed training, rank needs to be the 96 | # global rank among all the processes 97 | args.local_rank = args.local_rank * ngpus_per_node + gpu 98 | dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.local_rank) 99 | 100 | model = define_network(f_dim=512, h_didm=128, a_dim=12, config=config) 101 | 102 | if args.resume: 103 | if args.local_rank % ngpus_per_node == 0: 104 | print('=> Use previously trained model from {}'.format(args.resume)) 105 | 106 | config.start_epoch, best_auc1 = load_model(model, join(args.resume, f'{config.best_ckpt}')) 107 | 108 | if args.distributed: 109 | if args.gpu is not None: 110 | args.batch_size = int(config.batch_size / ngpus_per_node) 111 | args.workers = int((config.workers + ngpus_per_node -1) / ngpus_per_node) 112 | torch.cuda.set_device(args.gpu) 113 | print('GPU:', args.gpu) 114 | 115 | model = model.cuda(args.gpu) 116 | model = DDP(model, device_ids=[args.gpu]) 117 | else: 118 | ''' 119 | DistributedDataParallel will divide and allocate batch_size to all 120 | available GPUs if device_ids are not set 121 | ''' 122 | model = model.cuda() 123 | 124 | model = DDP(model) 125 | print('Allocate batch-size to all available GPUs') 126 | 127 | elif args.gpu is not None: 128 | # Single GPU. 
129 | torch.cuda.set_device(args.gpu) 130 | model = model.cuda(args.gpu) 131 | print("Use Single GPU", args.gpu) 132 | else: 133 | model = torch.nn.DataParallel(model).cuda() 134 | print("Use Data Parallel.") 135 | 136 | opt = optim.AdamW(model.parameters(), lr=config.lr) 137 | 138 | scheduler = optim.lr_scheduler.StepLR(opt, step_size=config.scheduler_step, gamma=config.scheduler_gamma) 139 | 140 | criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu) 141 | 142 | if args.amp: 143 | scaler = GradScaler() 144 | else: 145 | scaler = None 146 | 147 | if config.train_dataset['datalabel'] == 'celeb': 148 | dataset = DFDataset(phase='train', tag="", codec='', **config.train_dataset) 149 | len_dataset = len(dataset) 150 | train_size = int(0.8 * len_dataset) 151 | val_size = len_dataset - train_size 152 | train_data, val_data = torch.utils.data.random_split( 153 | dataset=dataset, 154 | lengths=[train_size, val_size], 155 | generator=torch.Generator().manual_seed(1000) 156 | ) 157 | else: 158 | train_data = DFDataset(phase='train', tag="", codec='c23', **config.train_dataset) 159 | val_data = DFDataset(phase='val', datalabel="ff-all", tag="", codec="c23", **config.val_dataset) 160 | 161 | if args.distributed: 162 | train_sampler = DistributedSampler(train_data, shuffle=True) 163 | val_sampler = DistributedSampler(val_data, shuffle=True) 164 | else: 165 | train_sampler = val_sampler = None 166 | 167 | train_dataloader = DataLoader(train_data, batch_size=config.batch_size, shuffle=(train_sampler is None), num_workers=config.workers, pin_memory=True, drop_last=True, sampler=train_sampler) 168 | val_dataloader = DataLoader(val_data, batch_size=16, shuffle=(val_sampler is None), num_workers=config.workers, pin_memory=True, drop_last=True, sampler=val_sampler) 169 | 170 | for epoch in range(config.start_epoch, config.epochs): 171 | if args.distributed: 172 | train_sampler.set_epoch(epoch) 173 | val_sampler.set_epoch(epoch) 174 | 175 | train(train_dataloader, model, opt, criterion, scaler, epoch, args) 176 | val_auc = valid(val_dataloader, model, criterion, ngpus_per_node, epoch, args) 177 | scheduler.step() 178 | 179 | is_best = val_auc > best_auc1 180 | best_auc1 = max(val_auc, best_auc1) 181 | 182 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.local_rank % ngpus_per_node==0): 183 | save_checkpoint({ 184 | 'epoch': epoch + 1, 185 | 'state_dict': model.state_dict(), 186 | 'best_auc1': best_auc1 187 | }, is_best, args.checkpoint, filename=config.ckpt, best=config.best_ckpt) 188 | 189 | def train(train_dataloader, model, optimizer, criterion, scaler, epoch, args): 190 | batch_time = AverageMeter('Batch', ':2.2f') 191 | data_time = AverageMeter('Data', ':1.2f') 192 | 193 | loss_ce = AverageMeter('CE', ':.3f') 194 | loss_time = AverageMeter('Loss_tc', ':.3f') 195 | loss_org = AverageMeter('Loss_oc', ':.3f') 196 | loss_reg = AverageMeter('Reg', ':.3f') 197 | acc = AverageMeter('Acc', ':3.2f') 198 | 199 | progress = ProgressMeter( 200 | len(train_dataloader), 201 | [batch_time, data_time, loss_ce, loss_time, loss_org, loss_reg, acc], 202 | prefix="Epoch:{}".format(epoch) 203 | ) 204 | 205 | end = time.time() 206 | model.train() 207 | 208 | for idx, data in enumerate(train_dataloader): 209 | 210 | data_time.update(time.time() - end) 211 | 212 | imgs, labels = data['img'], data['label'] 213 | _batch_size = labels.shape[0] 214 | 215 | if args.amp: 216 | imgs = imgs.half() 217 | else: 218 | imgs = imgs.float() 219 | 220 | labels = labels.long() 221 | 222 | if 
torch.cuda.is_available(): 223 | imgs = imgs.cuda(args.gpu, non_blocking=True) 224 | labels = labels.cuda(args.gpu, non_blocking=True) 225 | 226 | ## Deepfakes Detection branch 227 | optimizer.zero_grad() 228 | with autocast(args.amp): 229 | y_hat, loss_tc, loss_oc = model(imgs) 230 | 231 | reg = model.l2_reg() 232 | 233 | loss_bce = args.lambda_bce * criterion(y_hat, labels) 234 | 235 | loss = loss_bce + reg + 0.75*loss_tc + 0.75*loss_oc 236 | 237 | if args.amp: 238 | scaler.scale(loss).backward() 239 | scaler.step(optimizer) 240 | scaler.update() 241 | else: 242 | loss.backward() 243 | optimizer.step() 244 | 245 | acc1 = accuracy(y_hat, labels) 246 | 247 | acc.update(acc1[0].item(), _batch_size) 248 | loss_ce.update(loss_bce.item(), _batch_size) 249 | loss_time.update(loss_tc.item(), _batch_size) 250 | loss_org.update(loss_oc.item(), _batch_size) 251 | loss_reg.update(reg.item(), _batch_size) 252 | 253 | batch_time.update(time.time() - end) 254 | end = time.time() 255 | 256 | if idx % args.print_iter == 0 and args.local_rank == 0 : 257 | progress.display(idx) 258 | 259 | 260 | def valid(val_dataloader, model, criterion, ngpus_per_node, epoch, args): 261 | batch_time = AverageMeter('Batch', ':2.2f') 262 | data_time = AverageMeter('Data', ':1.2f') 263 | losses = AverageMeter('Loss', ':.2e') 264 | acc = AverageMeter('Acc', ':3.2f') 265 | 266 | progress = ProgressMeter( 267 | len(val_dataloader), 268 | [batch_time, data_time, losses, acc], 269 | prefix='Test:') 270 | 271 | # switch to evaluate mode 272 | model.eval() 273 | 274 | nplabels = [] 275 | softmax_logits = [] 276 | int_logits = [] 277 | 278 | with torch.no_grad(): 279 | end = time.time() 280 | 281 | for idx, data in enumerate(val_dataloader): 282 | img, label = data['img'], data['label'] 283 | 284 | _batch_size = label.shape[0] 285 | 286 | data_time.update(time.time() - end) 287 | 288 | img = img.float() 289 | label = label.long() 290 | 291 | if torch.cuda.is_available(): 292 | img = img.cuda(args.gpu, non_blocking=True) 293 | label = label.cuda(args.gpu, non_blocking=True) 294 | 295 | logits = model(img) 296 | 297 | loss_ce = args.lambda_bce * criterion(logits, label) 298 | loss = loss_ce 299 | 300 | acc1 = accuracy(logits, label) 301 | 302 | nplabels.append(label.cpu().numpy().astype(np.int8)) 303 | int_logits.append(logits.max(1)[1].cpu().numpy().astype(np.int8)) 304 | softmax_logits.append(logits.softmax(1)[:,1].detach().cpu().numpy()) 305 | 306 | losses.update(loss.item(), _batch_size) 307 | acc.update(acc1[0].item(), _batch_size) 308 | 309 | batch_time.update(time.time() - end) 310 | end = time.time() 311 | 312 | epoch_losses = loss.clone().detach().cuda(args.gpu) 313 | epoch_acc = torch.tensor(acc1).cuda(args.gpu) 314 | 315 | if args.distributed: 316 | dist.all_reduce(epoch_losses, op=dist.ReduceOp.SUM) 317 | dist.all_reduce(epoch_acc, op=dist.ReduceOp.SUM) 318 | 319 | epoch_losses = epoch_losses.item() / ngpus_per_node 320 | epoch_acc = epoch_acc.cpu().numpy() / ngpus_per_node 321 | 322 | batch_info = 'Loss:{:.4f} Acc:{:.2f}'.format(epoch_losses, acc.avg) 323 | 324 | if idx % args.print_iter == 0 and args.local_rank == 0: 325 | progress.display(idx) 326 | 327 | y_true = np.concatenate(nplabels) 328 | y_pred = np.concatenate(softmax_logits) 329 | int_pred = np.concatenate(int_logits) 330 | 331 | auc = roc_auc_score(y_true, y_pred) 332 | auc = 100. * auc 333 | pre, rec, f1, _ = precision_recall_fscore_support(y_true, int_pred, average='binary') 334 | pre = 100.* pre; rec = 100. * rec; f1 = 100. 
* f1 335 | 336 | 337 | if args.local_rank % ngpus_per_node == 0: 338 | logging.info('Train Epoch:{} Time: {} {}'.format(epoch, batch_time, batch_info)) 339 | 340 | print(" * Acc:{Acc.avg:.3f} Auc:{Auc:.3f} Pre:{Pre:.3f} Rec:{Rec:.3f} F1:{F1:.3f}".format(Acc=acc, Auc=auc, Pre=pre, Rec=rec, F1=f1)) 341 | 342 | return acc.avg 343 | 344 | if __name__ == '__main__': 345 | main() 346 | -------------------------------------------------------------------------------- /Pytorch/training/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | 6 | import random 7 | import torch 8 | from torch.utils.data.dataset import Dataset 9 | from albumentations import * 10 | import numpy as np 11 | # Global 12 | seq_len = 100 13 | 14 | def augmentation(type='augment0', resize=112): 15 | augment0 = Compose([Resize(resize, resize),HorizontalFlip(),Normalize(mean=(0.43216,0.394666,0.37645),std=(0.22803,0.22145,0.216989))]) 16 | augment_test = Compose([Resize(resize, resize),Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))], p=1) 17 | augment_set = {'augment0': augment0, 'augment_test': augment_test} 18 | return augment_set[type] 19 | 20 | class DFDataset(Dataset): 21 | def __init__(self, phase, datalabel, resize, tag, codec, augment='augment0'): 22 | assert phase in ['train', 'val', 'test'] 23 | self.datalabel = datalabel 24 | self.resize = resize 25 | self.phase = phase 26 | self.epoch = 0 27 | self.len = 0 28 | self.fake = [] 29 | self.real = [] 30 | self.dataset = [] 31 | self.aug = augmentation(augment, resize) 32 | 33 | if phase == 'train': 34 | print("------Train Set------") 35 | elif phase == 'val': 36 | print("------Validation Set------") 37 | elif phase == 'test': 38 | print("------Test Set------") 39 | else: 40 | print("Error: The phase is None") 41 | if 'ff-all' in self.datalabel: 42 | if tag == "": 43 | print("Load ff-all") 44 | for subtag in ['deepfakes', 'face2face', 'faceswap', 'neural_textures', 'original', 'FaceShifter']: 45 | subdataset = FF_dataset(subtag, codec, phase) 46 | self.dataset += subdataset 47 | if len(subdataset) > 0: 48 | print(f'load {subtag}-{codec} len: {len(subdataset)}') 49 | else: 50 | for subtag in ['deepfakes', 'face2face', 'faceswap', 'neural_textures', 'original']: 51 | if tag != subtag: 52 | subdataset = FF_dataset(subtag, codec, phase) 53 | self.dataset += subdataset 54 | if len(subdataset) > 0: 55 | print(f'load {subtag}-{codec} len: {len(subdataset)}') 56 | if phase != 'test': 57 | self.dataset = make_balance(self.dataset) 58 | 59 | elif 'ff' in self.datalabel: 60 | self.dataset = FF_dataset(tag, codec, phase) 61 | self.dataset += FF_dataset("original", codec, phase) 62 | print(f'load {tag}-{codec} len: {len(self.dataset)}') 63 | 64 | elif 'celeb' in self.datalabel: 65 | self.dataset = CelebDF(phase) 66 | print(f'load {self.datalabel} len: {len(self.dataset)}') 67 | elif 'dfdc' in self.datalabel: 68 | self.dataset = DFDC(phase) 69 | print(f'load {self.datalabel} len: {len(self.dataset)}') 70 | 71 | else: 72 | raise(Exception(f'Error: Dataset {self.datalabel} does not exist!')) 73 | self.len = len(self.dataset) 74 | 75 | def __getitem__(self, index): 76 | fpath_list, label = self.dataset[index] 77 | 78 | # Total number of sampled frames. 
79 | len_list = len(fpath_list) 80 | frame_N = len_list 81 | 82 | buffer = np.empty(shape=(seq_len, self.resize, self.resize, 3), dtype=np.float64) 83 | idx = 0 84 | for idx, i in enumerate(range(frame_N)): 85 | fpath = fpath_list[i] 86 | img = cv2.imread(fpath) 87 | img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 88 | img = self.aug(image=img)['image'] 89 | buffer[idx] = img 90 | idx += 1; cur_idx = 0 91 | while idx < seq_len: 92 | buffer[idx] = buffer[cur_idx % frame_N] 93 | cur_idx += 1; idx += 1 94 | buffer = self.ToTensor(buffer) 95 | 96 | return {'img': buffer, 'label': label} 97 | 98 | def __len__(self): 99 | return self.len 100 | 101 | def ToTensor(self, ndarray): 102 | tensor = ndarray.transpose((3, 0, 1, 2)) 103 | return torch.from_numpy(tensor) 104 | 105 | def make_balance(data): 106 | tr = list(filter(lambda x:x[1]==0, data)) 107 | tf = list(filter(lambda x:x[1]==1, data)) 108 | if len(tr) > len(tf): 109 | tr, tf = tf, tr 110 | rate = len(tf) // len(tr) 111 | res = len(tf) - rate * len(tr) 112 | tr = tr * rate + random.sample(tr,res) 113 | return tr + tf 114 | 115 | ## Face Forensics++ 116 | ffpp_raw_train = "/data2/ziming.yang/datasets/ffpp/ffpp_raw_train.json" 117 | ffpp_raw_val = "/data2/ziming.yang/datasets/ffpp/ffpp_raw_val.json" 118 | ffpp_raw_test = "/data2/ziming.yang/datasets/ffpp/ffpp_raw_test.json" 119 | ffpp_c23_train = "/data2/ziming.yang/datasets/ffpp/ffpp_c23_train.json" 120 | ffpp_c23_val = "/data2/ziming.yang/datasets/ffpp/ffpp_c23_val.json" 121 | ffpp_c23_test = "/data2/ziming.yang/datasets/ffpp/ffpp_c23_test.json" 122 | ffpp_c40_train = "/data2/ziming.yang/datasets/ffpp/ffpp_c40_train.json" 123 | ffpp_c40_val = "/data2/ziming.yang/datasets/ffpp/ffpp_c40_val.json" 124 | ffpp_c40_test = "/data2/ziming.yang/datasets/ffpp/ffpp_c40_test.json" 125 | 126 | ffpp_proto = {"raw": {'train': ffpp_raw_train, 'val': ffpp_raw_val, 'test': ffpp_raw_test}, 127 | "c23": {'train': ffpp_c23_train, 'val': ffpp_c23_val, 'test': ffpp_c23_test}, 128 | "c40": {'train': ffpp_c40_train, 'val': ffpp_c40_val, 'test': ffpp_c40_test} 129 | } 130 | def FF_dataset(tag, codec, phase='train'): 131 | assert(tag in ['deepfakes', 'face2face', 'faceswap', 'neural_textures', 'original', 'FaceShifter']) 132 | assert(codec in ['raw','c23','c40','all']) 133 | assert(phase in ['train','val','test','all']) 134 | if phase=="all": 135 | return FF_dataset(tag, codec, 'train') + FF_dataset(tag, codec, 'val') + FF_dataset(tag, codec, 'test') 136 | if codec=="all": 137 | return FF_dataset(tag,'raw', phase) + FF_dataset(tag,'c23',phase) + FF_dataset(tag,'c40',phase) 138 | 139 | _dataset = [] 140 | path = ffpp_proto[codec][phase] 141 | with open(path, 'r') as f: 142 | data_dict = json.load(f) 143 | for k,v in data_dict.items(): 144 | if tag not in k: 145 | continue 146 | video_label = v['label'] 147 | file_list = v['list'] 148 | 149 | for i in range(0, len(file_list), seq_len): 150 | if i+seq_len>=len(file_list): 151 | _dataset.append([file_list[-seq_len:], video_label]) 152 | else: 153 | _dataset.append([file_list[i:i+ seq_len], video_label]) 154 | return _dataset 155 | 156 | ## Celeb-DF v2 157 | celebtrain = "/data2/ziming.yang/datasets/Celeb-DF/celeb_train.json" 158 | celebtest = "/data2/ziming.yang/datasets/Celeb-DF/celeb_test.json" 159 | celeb_proto = {'train': celebtrain, 'test': celebtest} 160 | def CelebDF(phase='train'): 161 | assert(phase in ['train', 'test', 'all']) 162 | if phase=='all': 163 | return CelebDF('train') + CelebDF('tests') 164 | _dataset = [] 165 | path = celeb_proto[phase] 166 | with 
open(path, 'r') as f: 167 | data_dict = json.load(f) 168 | num_frames = 100 if phase!='train' else None 169 | for k,v in data_dict.items(): 170 | if len(v['list']) != 30: ## Dataset analysis 171 | continue 172 | video_label = v['label'] 173 | file_list = v['list'][:num_frames] 174 | 175 | len_list = len(file_list) 176 | for i in range(0, len_list, seq_len): 177 | if i+seq_len >= len_list: 178 | _dataset.append([file_list[-seq_len:], video_label]) 179 | else: 180 | _dataset.append([file_list[i:i+seq_len], video_label]) 181 | 182 | return _dataset 183 | 184 | ## Deepfakes Detection Challenge 185 | dfdctrain = "/data2/ziming.yang/datasets/DFDC/dfdc_train.json" 186 | dfdcval = "/data2/ziming.yang/datasets/DFDC/dfdc_val.json" 187 | dfdctest = "/data2/ziming.yang/datasets/DFDC/dfdc_test.json" 188 | dfdc_proto = {'train': dfdctrain, 'val': dfdcval, 'test': dfdctest} 189 | def DFDC(phase='train'): 190 | assert(phase in ['train', 'val', 'test', 'all']) 191 | if phase=='all': 192 | return DFDC('train') + DFDC('val') + DFDC('tests') 193 | _dataset = [] 194 | path = dfdc_proto[phase] 195 | num_frames = 100 if phase!='train' else None 196 | with open(path, 'r') as f: 197 | data_dict = json.load(f) 198 | for k,v in data_dict.items(): 199 | if len(v['list']) < 100: 200 | continue 201 | 202 | video_label = v['label'] 203 | 204 | file_list = v['list'][:num_frames] 205 | 206 | len_list = len(file_list) 207 | 208 | for i in range(0, len_list, seq_len): 209 | if i+seq_len >= len_list: 210 | _dataset.append([file_list[-seq_len:], video_label]) 211 | else: 212 | _dataset.append([file_list[i:i+ seq_len], video_label]) 213 | return _dataset -------------------------------------------------------------------------------- /Pytorch/training/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import brentq 3 | from sklearn import metrics 4 | from sklearn.metrics import roc_curve 5 | from scipy.interpolate import interp1d 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | """Computes the accuracy over the k top predictions for the specified values of k""" 9 | with torch.no_grad(): 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | res = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 20 | res.append(correct_k.mul_(100.0 / batch_size)) 21 | return res 22 | 23 | def calculate_eer(y_true, y_score): 24 | fpr, tpr, thresholds = roc_curve(y_true, y_score) 25 | eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 26 | thresh = interp1d(fpr, thresholds)(eer) 27 | return eer, thresh 28 | 29 | def compute_video_level_auc(video_to_logits, video_to_labels): 30 | """ " 31 | Compute video-level area under ROC curve. Averages the logits across the video for non-overlapping clips. 
32 | 33 | Parameters 34 | ---------- 35 | video_to_logits : dict 36 | Maps video ids to list of logit values 37 | video_to_labels : dict 38 | Maps video ids to label 39 | """ 40 | output_batch = torch.stack( 41 | [torch.mean(torch.stack(video_to_logits[video_id]), 0, keepdim=False) for video_id in video_to_logits.keys()] 42 | ) 43 | output_labels = torch.stack([video_to_labels[video_id] for video_id in video_to_logits.keys()]) 44 | 45 | fpr, tpr, _ = metrics.roc_curve(output_labels.cpu().numpy(), output_batch.cpu().numpy()) 46 | return metrics.auc(fpr, tpr) 47 | 48 | def compute_video_level_acc(video_to_logits, video_to_labels): 49 | output_batch = torch.stack( 50 | [torch.mean(torch.stack(video_to_logits[video_id]), 0, keepdim=False) for video_id in video_to_logits.keys()] 51 | ) 52 | prediction = (output_batch>=0.5).long() 53 | output_labels = torch.stack([video_to_labels[video_id] for video_id in video_to_logits.keys()]) 54 | acc = metrics.accuracy_score(output_labels.cpu().numpy(), prediction.cpu().numpy()) 55 | return acc 56 | 57 | def compute_video_level_prf(video_to_logits, video_to_labels): 58 | output_batch = torch.stack( 59 | [torch.mean(torch.stack(video_to_logits[video_id]), 0, keepdim=False) for video_id in video_to_logits.keys()] 60 | ) 61 | prediction = (output_batch>=0.5).long() 62 | output_labels = torch.stack([video_to_labels[video_id] for video_id in video_to_logits.keys()]) 63 | pre, rec, f1, support = metrics.precision_recall_fscore_support(output_labels.cpu().numpy(), prediction.cpu().numpy(), average='binary') 64 | return pre, rec, f1, support -------------------------------------------------------------------------------- /Pytorch/training/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | from pytorchcv.model_provider import get_model 7 | from timm.models.layers import trunc_normal_ 8 | 9 | from training.r3d import mc3_18 10 | from utils.graph_conv import * 11 | from utils.util import pretrained_model 12 | from einops import rearrange 13 | 14 | def define_network(f_dim, h_didm, a_dim, config): 15 | model = Model(f_dim, h_didm, a_dim, config) 16 | return model 17 | 18 | class MC3(nn.Module): 19 | def __init__(self): 20 | super(MC3, self).__init__() 21 | base = models.video.mc3_18(pretrained=False) 22 | r3d_state = torch.load(pretrained_path['mc3_18']) 23 | base.load_state_dict(r3d_state) 24 | 25 | self.base = nn.Sequential(*list(base.children())[:-1]) 26 | self.fc = nn.Linear(512, 2) 27 | 28 | def forward(self, clip): 29 | x = self.base(clip).flatten(1) 30 | return self.fc(x) 31 | 32 | 33 | class Model(nn.Module): 34 | def __init__(self, f_dim, h_dim, a_dim, config): 35 | super(Model, self).__init__() 36 | self.f_dim = f_dim 37 | self.h_dim = h_dim 38 | self.a_dim = a_dim 39 | self.mask_strategy = config.mask_strategy 40 | self.masked_drop = config.mask_rate 41 | 42 | self.edges = nn.Parameter(random_edges(a_dim)) 43 | 44 | self.enc = Encoder() 45 | 46 | self.atn = AttentionMap(in_channels=f_dim, out_channels=a_dim) 47 | self.drop = nn.Dropout1d(config.dropout_rate) 48 | self.atp = AttentionPooling() # MaxAttentionPooling() 49 | self.drop2d = nn.Dropout(config.dropout_rate) 50 | 51 | self.cell = TGCN(f_dim, h_dim) 52 | 53 | self.bn = nn.BatchNorm1d(a_dim * h_dim) 54 | self.lrelu = nn.ReLU(True) 55 | self.fc = nn.Linear(a_dim * h_dim, 2) 56 | 57 | self.init_weights() 58 | 59 | def 
init_weights(self): 60 | trunc_normal_(self.fc.weight, std=.02) 61 | nn.init.zeros_(self.fc.bias) 62 | 63 | def forward(self, clip): 64 | adj = self.edges 65 | if self.training and self.mask_strategy!='none': 66 | adj = MaskConnection(adj, self.mask_strategy, self.masked_drop) 67 | adj = calculate_laplacian_with_self_loop(adj) 68 | 69 | batch_size = clip.shape[0] 70 | h1 = torch.zeros(batch_size, self.a_dim, self.h_dim).type_as(clip) 71 | h2 = torch.zeros(batch_size, self.a_dim, self.h_dim).type_as(clip) 72 | 73 | out = None 74 | snippets = torch.chunk(clip, chunks=5, dim=2) 75 | for index, input in enumerate(snippets): 76 | ## Features 77 | feature_map = self.enc(input) 78 | 79 | attention_maps_ = self.atn(feature_map) 80 | 81 | dropout_mask = self.drop(torch.ones([attention_maps_.shape[0], self.a_dim, 1], device=input.device)) 82 | 83 | attention_maps = attention_maps_ * torch.unsqueeze(dropout_mask, -1) 84 | feature_matrix_ = self.atp(feature_map, attention_maps) 85 | feature_matrix = feature_matrix_ * dropout_mask 86 | 87 | h1, h2 = self.cell(feature_matrix, h1, h2, adj) 88 | 89 | if index > 0: 90 | T_attention_maps = torch.cat((T_attention_maps, attention_maps_.unsqueeze(1)), dim=1) 91 | T_att_features = torch.cat((T_att_features, attention_maps_.unsqueeze(1).flatten(-2)), dim=1) 92 | 93 | else: 94 | T_attention_maps = attention_maps_.unsqueeze(1) 95 | T_att_features = attention_maps_.unsqueeze(1).flatten(-2) 96 | 97 | out = h2 98 | T_att_features = T_att_features.flatten(0,1) 99 | 100 | x = out.flatten(1) 101 | 102 | x = self.bn(x) 103 | x = self.lrelu(x) 104 | 105 | logits = self.fc(x) 106 | 107 | if self.training: 108 | return logits, self.loss_tc(T_attention_maps), self.loss_od(T_att_features) 109 | return logits 110 | 111 | def loss_tc(self, T_attention): 112 | batch_size, T, A, H, W = T_attention.size() 113 | loss_tc = 0 114 | for t in range(T-1): 115 | mapi = T_attention[:, t, :, :, :] 116 | mapj = T_attention[:, t+1, :, :, :] 117 | loss_tc += torch.dist(mapi, mapj,p=1) 118 | loss_tc = loss_tc / T / batch_size 119 | return loss_tc 120 | 121 | def loss_od(self, T_att_features): 122 | eps=1e-8 123 | # [B, T, A, H, W] 124 | # B, T, A, C = T_att_features.shape 125 | matrix_A = T_att_features 126 | a_n = matrix_A.norm(dim=2).unsqueeze(2) 127 | # a_n [B, N, 1] 128 | # Normalize 129 | a_norm = matrix_A / torch.max(a_n, eps * torch.ones_like(a_n)) 130 | # patch-wise absolute value of cosine similarity 131 | sim_matrix = torch.einsum('abc,acd->abd', a_norm, a_norm.transpose(1,2)) 132 | loss_rc = sim_matrix.mean() 133 | return loss_rc 134 | 135 | def l2_reg(self): 136 | reg_loss = 0.0 137 | for param in self.cell.parameters(): 138 | reg_loss += torch.sum(param ** 2) / 2 139 | reg_loss = 1.5e-3 * reg_loss 140 | return reg_loss 141 | 142 | def attention(self): 143 | return {'img': self.img[:, :, :, :, :], 'attention': F.interpolate(self.T_attention[:, :, :, :, :], (self.T_attention.shape[-3], self.img.shape[-2], self.img.shape[-1]))} 144 | 145 | 146 | class Encoder(nn.Module): 147 | def __init__(self): 148 | super(Encoder, self).__init__() 149 | 150 | base = mc3_18(pretrained=False) 151 | pretrain_path = pretrained_path['mc3_18'] 152 | 153 | pretrained_model(base, pretrain_path) 154 | 155 | self.base = base 156 | 157 | def forward(self, clip): 158 | return self.base(clip) 159 | 160 | class AttentionMap(nn.Module): 161 | def __init__(self, in_channels, out_channels): 162 | super(AttentionMap, self).__init__() 163 | 164 | self.num_attentions = out_channels 165 | 166 | self.conv_extract = 
nn.Conv3d(in_channels, in_channels, kernel_size=(7,3,3), stride=(3, 1, 1), padding=1) 167 | self.bn1 = nn.BatchNorm3d(in_channels) 168 | self.conv2 = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), bias=False) 169 | self.bn2 = nn.BatchNorm3d(out_channels) 170 | 171 | def forward(self, x): 172 | if self.num_attentions==0: 173 | return torch.ones([x.shape[0],1,1,1],device=x.device) 174 | x = self.conv_extract(x) 175 | x = self.bn1(x) 176 | x = F.relu(x, inplace=True) 177 | x = self.conv2(x) 178 | x = self.bn2(x) 179 | 180 | x = F.adaptive_avg_pool3d(x, (1, None, None)) 181 | x = x.squeeze(2) 182 | x = F.gelu(x) 183 | return x 184 | 185 | class AttentionPooling(nn.Module): 186 | def __init__(self): 187 | super().__init__() 188 | def forward(self, features, attentions,norm=2): 189 | H, W = features.size()[-2:] 190 | B, M, AH, AW = attentions.size() 191 | if AH != H or AW != W: 192 | attentions=F.interpolate(attentions,size=(H,W), mode='bilinear', align_corners=True) 193 | if norm==1: 194 | attentions=attentions+1e-8 195 | if len(features.shape)==4: 196 | feature_matrix=torch.einsum('imjk,injk->imn', attentions, features) 197 | else: 198 | feature_matrix=torch.einsum('imjk,indjk->imn', attentions, features) 199 | if norm==1: 200 | w=torch.sum(attentions,dim=(2,3)).unsqueeze(-1) 201 | feature_matrix/=w 202 | if norm==2: 203 | feature_matrix = F.normalize(feature_matrix,p=2,dim=-1) 204 | if norm==3: 205 | w=torch.sum(attentions,dim=(2,3)).unsqueeze(-1)+1e-8 206 | feature_matrix/=w 207 | return feature_matrix 208 | 209 | class MaxAttentionPooling(nn.Module): 210 | def __init__(self): 211 | super().__init__() 212 | def forward(self, features, attentions,norm=2): 213 | H, W = features.size()[-2:] 214 | 215 | B, M, AH, AW = attentions.size() 216 | if AH != H or AW != W: 217 | attentions=F.interpolate(attentions,size=(H,W), mode='bilinear', align_corners=True) 218 | if norm==1: 219 | attentions=attentions+1e-8 220 | if len(features.shape)==4: 221 | feature_matrix=torch.einsum('imjk,injk->imnjk', attentions, features) 222 | else: 223 | feature_matrix=torch.einsum('imjk,indjk->imnjk', attentions, features) 224 | if norm==1: 225 | w=torch.sum(attentions,dim=(2,3)).unsqueeze(-1) 226 | feature_matrix/=w 227 | if norm==2: 228 | feature_matrix = F.max_pool3d(feature_matrix, [1,2,2], [1,2,2]) 229 | feature_matrix = F.max_pool3d(feature_matrix, [1,2,2], [1,2,2]) 230 | feature_matrix = feature_matrix.squeeze() 231 | feature_matrix = F.normalize(feature_matrix, p=2, dim=-1) 232 | if norm==3: 233 | w=torch.sum(attentions,dim=(2,3)).unsqueeze(-1)+1e-8 234 | feature_matrix/=w 235 | return feature_matrix 236 | 237 | class TGCN(nn.Module): 238 | def __init__(self, in_dim, h_dim): 239 | super(TGCN, self).__init__() 240 | self.in_dim = in_dim 241 | self.h_dim = h_dim 242 | self.tgcn_layer1 = TGCNCell(self.in_dim, self.h_dim) 243 | self.tgcn_layer2 = TGCNCell(self.h_dim, self.h_dim) 244 | 245 | def forward(self, inputs, H1, H2, adj): 246 | batch_size, a_dim, c = inputs.shape 247 | assert self.in_dim == c 248 | H1 = self.tgcn_layer1(inputs, H1, adj) 249 | H2 = self.tgcn_layer2(H1, H2, adj) 250 | return H1, H2 251 | 252 | class TGCNCell(nn.Module): 253 | def __init__(self, in_dim, h_dim): 254 | super(TGCNCell, self).__init__() 255 | self.in_dim = in_dim 256 | self.h_dim = h_dim 257 | 258 | self.gconv1 = GConv(in_dim + h_dim, h_dim * 2, bias=True) 259 | self.gconv2 = GConv(in_dim + h_dim, h_dim, bias=False) 260 | 261 | def forward(self, inputs, hidden_state, adj): 262 | concatenation = 
torch.sigmoid(self.gconv1(inputs, hidden_state, adj)) 263 | 264 | r, u = torch.chunk(concatenation, chunks=2, dim=2) 265 | 266 | c = torch.tanh(self.gconv2(inputs, r * hidden_state, adj)) 267 | 268 | new_hidden_state = u * hidden_state + (1.0 - u) * c 269 | 270 | return new_hidden_state 271 | 272 | class GConv(nn.Module): 273 | """ 274 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 275 | """ 276 | def __init__(self, in_features, out_features, bias=True): 277 | super(GConv, self).__init__() 278 | self.in_features = in_features 279 | self.out_features = out_features 280 | 281 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) 282 | if bias: 283 | self.bias = nn.Parameter(torch.FloatTensor(out_features)) 284 | else: 285 | self.register_parameter('bias', None) 286 | 287 | self.reset_parameters() 288 | 289 | def reset_parameters(self): 290 | stdv = 1. / math.sqrt(self.weight.size(1)) 291 | nn.init.xavier_normal_(self.weight) 292 | if self.bias is not None: 293 | self.bias.data.uniform_(-stdv, stdv) 294 | 295 | def forward(self, inputs, hidden_state, laplacian): 296 | # [x, h] (batch_size, a_dim, input_dim + h_dim) 297 | concatenation = torch.cat((inputs, hidden_state), dim=2) 298 | batch_size, a_dim, ih_dim = concatenation.shape 299 | # [x, h] (a_dim, batch_size * (input_dim + h_dim)) 300 | concatenation = rearrange(concatenation, 'b a c -> a (b c)') 301 | # A[x, h] (a_dim, batch_size * (input_dim + h_dim)) 302 | a_times_concat = laplacian @ concatenation 303 | # A[x, h] (batch_size, a_dim, input_dim + h_dim) 304 | a_times_concat = rearrange(a_times_concat, 'a (b c) -> b a c', b=batch_size, c=ih_dim) 305 | output = a_times_concat @ self.weight 306 | if self.bias is not None: 307 | return output + self.bias 308 | else: 309 | return output 310 | 311 | def __repr__(self): 312 | return self.__class__.__name__ + ' (' \ 313 | + str(self.in_features) + ' -> ' \ 314 | + str(self.out_features) + ')' 315 | 316 | 317 | class Xception(nn.Module): 318 | def __init__(self): 319 | super(Xception, self).__init__() 320 | base = get_model('Xception', pretrained=False) 321 | pretrain_state = torch.load(pretrained_path['xception']) 322 | base.load_state_dict(pretrain_state) 323 | 324 | base = nn.Sequential(*list(base.children())[:-1]) 325 | base[0].final_block.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1))) 326 | self.base = base 327 | self.flatten = nn.Flatten(1) 328 | self.bn = nn.BatchNorm1d(2048) 329 | self.fc = nn.Linear(2048, 2) 330 | 331 | def forward(self, images): 332 | out = torch.zeros((images.shape[0], 2048), device=images.device) 333 | for img in torch.split(images, 1, 2): 334 | img = img.squeeze(2) 335 | f = self.base(img) 336 | f = self.flatten(f) 337 | out += f 338 | out /= images.shape[2] 339 | out = self.bn(out) 340 | return self.fc(out) 341 | 342 | 343 | pretrained_path = { 344 | 'resnet18': 'pretrained/resnet18.pth', 345 | 'resnet50': 'pretrained/resnet50.pth', 346 | 'xception': 'pretrained/xception.pth', 347 | 'arcface': 'pretrained/model_ir_se50.pth', 348 | 'r3d_18': 'pretrained/r3d_18.pth', 349 | 'r2plus1d_18': 'pretrained/r2plus1d_18.pth', 350 | 'mc3_18': 'pretrained/mc3_18.pth' 351 | } 352 | 353 | def MaskConnection(adj, mask_strategy, masked_rate): 354 | if masked_rate > 0.: 355 | if mask_strategy == 'min': 356 | min_ = adj.min() 357 | max_ = adj.max() 358 | q = min_ + (max_ - min_) * masked_rate 359 | drop_matrix = torch.gt(adj, q) 360 | elif mask_strategy == 'random': 361 | q_matrix = F.dropout(adj, masked_rate) 362 | drop_matrix = 
torch.gt(q_matrix, 0) 363 | else: 364 | print(f"No such strategy {mask_strategy}.") 365 | adj = drop_matrix * adj 366 | return adj 367 | -------------------------------------------------------------------------------- /Pytorch/training/r3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | try: 4 | from torchvision.models.utils import load_state_dict_from_url 5 | # torchvison < 0.11 6 | except: 7 | from torch.hub import load_state_dict_from_url 8 | # torchvison 0.11 9 | 10 | __all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] 11 | 12 | model_urls = { 13 | 'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth', 14 | 'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth', 15 | 'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth', 16 | } 17 | 18 | 19 | class Conv3DSimple(nn.Conv3d): 20 | def __init__(self, 21 | in_planes, 22 | out_planes, 23 | midplanes=None, 24 | stride=1, 25 | padding=1): 26 | 27 | super(Conv3DSimple, self).__init__( 28 | in_channels=in_planes, 29 | out_channels=out_planes, 30 | kernel_size=(3, 3, 3), 31 | stride=stride, 32 | padding=padding, 33 | bias=False) 34 | 35 | @staticmethod 36 | def get_downsample_stride(stride): 37 | return (stride, stride, stride) 38 | 39 | 40 | class Conv2Plus1D(nn.Sequential): 41 | 42 | def __init__(self, 43 | in_planes, 44 | out_planes, 45 | midplanes, 46 | stride=1, 47 | padding=1): 48 | super(Conv2Plus1D, self).__init__( 49 | nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), 50 | stride=(1, stride, stride), padding=(0, padding, padding), 51 | bias=False), 52 | nn.BatchNorm3d(midplanes), 53 | nn.ReLU(inplace=True), 54 | nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), 55 | stride=(stride, 1, 1), padding=(padding, 0, 0), 56 | bias=False)) 57 | 58 | @staticmethod 59 | def get_downsample_stride(stride): 60 | return (stride, stride, stride) 61 | 62 | 63 | class Conv3DNoTemporal(nn.Conv3d): 64 | 65 | def __init__(self, 66 | in_planes, 67 | out_planes, 68 | midplanes=None, 69 | stride=1, 70 | padding=1): 71 | 72 | super(Conv3DNoTemporal, self).__init__( 73 | in_channels=in_planes, 74 | out_channels=out_planes, 75 | kernel_size=(1, 3, 3), 76 | stride=(1, stride, stride), 77 | padding=(0, padding, padding), 78 | bias=False) 79 | 80 | @staticmethod 81 | def get_downsample_stride(stride): 82 | return (1, stride, stride) 83 | 84 | 85 | class BasicBlock(nn.Module): 86 | 87 | expansion = 1 88 | 89 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 90 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 91 | 92 | super(BasicBlock, self).__init__() 93 | self.conv1 = nn.Sequential( 94 | conv_builder(inplanes, planes, midplanes, stride), 95 | nn.BatchNorm3d(planes), 96 | nn.ReLU(inplace=True) 97 | ) 98 | self.conv2 = nn.Sequential( 99 | conv_builder(planes, planes, midplanes), 100 | nn.BatchNorm3d(planes) 101 | ) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | out = self.conv1(x) 110 | out = self.conv2(out) 111 | if self.downsample is not None: 112 | residual = self.downsample(x) 113 | 114 | out += residual 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class Bottleneck(nn.Module): 121 | expansion = 4 122 | 123 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 124 | 125 | super(Bottleneck, 
self).__init__() 126 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 127 | 128 | # 1x1x1 129 | self.conv1 = nn.Sequential( 130 | nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), 131 | nn.BatchNorm3d(planes), 132 | nn.ReLU(inplace=True) 133 | ) 134 | # Second kernel 135 | self.conv2 = nn.Sequential( 136 | conv_builder(planes, planes, midplanes, stride), 137 | nn.BatchNorm3d(planes), 138 | nn.ReLU(inplace=True) 139 | ) 140 | 141 | # 1x1x1 142 | self.conv3 = nn.Sequential( 143 | nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), 144 | nn.BatchNorm3d(planes * self.expansion) 145 | ) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.downsample = downsample 148 | self.stride = stride 149 | 150 | def forward(self, x): 151 | residual = x 152 | 153 | out = self.conv1(x) 154 | out = self.conv2(out) 155 | out = self.conv3(out) 156 | 157 | if self.downsample is not None: 158 | residual = self.downsample(x) 159 | 160 | out += residual 161 | out = self.relu(out) 162 | 163 | return out 164 | 165 | 166 | class BasicStem(nn.Sequential): 167 | """The default conv-batchnorm-relu stem 168 | """ 169 | def __init__(self): 170 | super(BasicStem, self).__init__( 171 | nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), 172 | padding=(1, 3, 3), bias=False), 173 | nn.BatchNorm3d(64), 174 | nn.ReLU(inplace=True), 175 | # nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 176 | ) 177 | 178 | 179 | class R2Plus1dStem(nn.Sequential): 180 | """R(2+1)D stem is different than the default one as it uses separated 3D convolution 181 | """ 182 | def __init__(self): 183 | super(R2Plus1dStem, self).__init__( 184 | nn.Conv3d(3, 45, kernel_size=(1, 7, 7), 185 | stride=(1, 2, 2), padding=(0, 3, 3), 186 | bias=False), 187 | nn.BatchNorm3d(45), 188 | nn.ReLU(inplace=True), 189 | nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 190 | stride=(1, 1, 1), padding=(1, 0, 0), 191 | bias=False), 192 | nn.BatchNorm3d(64), 193 | nn.ReLU(inplace=True)) 194 | 195 | 196 | class VideoResNet(nn.Module): 197 | 198 | def __init__(self, block, conv_makers, layers, 199 | stem, num_classes=400, 200 | zero_init_residual=False): 201 | """Generic resnet video generator. 202 | 203 | Args: 204 | block (nn.Module): resnet building block 205 | conv_makers (list(functions)): generator function for each layer 206 | layers (List[int]): number of blocks per layer 207 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 208 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 209 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
210 | """ 211 | super(VideoResNet, self).__init__() 212 | self.inplanes = 64 213 | 214 | self.stem = stem() 215 | 216 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 217 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 218 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 219 | self.layer = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) 220 | 221 | # self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 222 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 223 | 224 | # init weights 225 | self._initialize_weights() 226 | 227 | if zero_init_residual: 228 | for m in self.modules(): 229 | if isinstance(m, Bottleneck): 230 | nn.init.constant_(m.bn3.weight, 0) 231 | 232 | def forward(self, x): 233 | x = self.stem(x) 234 | x = self.layer1(x) 235 | x = self.layer2(x) 236 | x = self.layer3(x) 237 | x = self.layer(x) 238 | 239 | # x = self.avgpool(x) 240 | # # Flatten the layer to fc 241 | # x = x.flatten(1) 242 | # x = self.fc(x) 243 | 244 | return x 245 | 246 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1): 247 | downsample = None 248 | 249 | if stride != 1 or self.inplanes != planes * block.expansion: 250 | ds_stride = conv_builder.get_downsample_stride(stride) 251 | downsample = nn.Sequential( 252 | nn.Conv3d(self.inplanes, planes * block.expansion, 253 | kernel_size=1, stride=ds_stride, bias=False), 254 | nn.BatchNorm3d(planes * block.expansion) 255 | ) 256 | layers = [] 257 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) 258 | 259 | self.inplanes = planes * block.expansion 260 | for i in range(1, blocks): 261 | layers.append(block(self.inplanes, planes, conv_builder)) 262 | 263 | return nn.Sequential(*layers) 264 | 265 | def _initialize_weights(self): 266 | for m in self.modules(): 267 | if isinstance(m, nn.Conv3d): 268 | nn.init.kaiming_normal_(m.weight, 269 | mode='fan_out', 270 | nonlinearity='relu') 271 | elif isinstance(m, nn.BatchNorm3d): 272 | nn.init.constant_(m.weight, 1) 273 | nn.init.constant_(m.bias, 0) 274 | 275 | 276 | def _video_resnet(arch, pretrained=False, progress=True, **kwargs): 277 | model = VideoResNet(**kwargs) 278 | 279 | if pretrained: 280 | state_dict = load_state_dict_from_url(model_urls[arch], 281 | progress=progress) 282 | model.load_state_dict(state_dict) 283 | return model 284 | 285 | def r3d_50(pretrained=False, progress=True, **kwargs): 286 | return _video_resnet('r3d_50', 287 | pretrained, progress, 288 | block=Bottleneck, 289 | conv_makers=[Conv3DSimple] * 4, 290 | layers=[3, 4, 6, 3], 291 | stem=BasicStem, ** kwargs) 292 | 293 | def r3d_18(pretrained=False, progress=True, **kwargs): 294 | """Construct 18 layer Resnet3D model as in 295 | https://arxiv.org/abs/1711.11248 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | 301 | Returns: 302 | nn.Module: R3D-18 network 303 | """ 304 | 305 | return _video_resnet('r3d_18', 306 | pretrained, progress, 307 | block=BasicBlock, 308 | conv_makers=[Conv3DSimple] * 4, 309 | layers=[2, 2, 2, 2], 310 | stem=BasicStem, **kwargs) 311 | 312 | 313 | def mc3_18(pretrained=False, progress=True, **kwargs): 314 | """Constructor for 18 layer Mixed Convolution network as in 315 | https://arxiv.org/abs/1711.11248 316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 319 | progress (bool): 
If True, displays a progress bar of the download to stderr 320 | 321 | Returns: 322 | nn.Module: MC3 Network definition 323 | """ 324 | return _video_resnet('mc3_18', 325 | pretrained, progress, 326 | block=BasicBlock, 327 | conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, 328 | layers=[2, 2, 2, 2], 329 | stem=BasicStem, **kwargs) 330 | 331 | 332 | def rmc3_18(pretrained=False, progress=True, **kwargs): 333 | """Constructor for 18 layer reversed Mixed Convolution network as in 334 | https://arxiv.org/abs/1711.11248 335 | 336 | Args: 337 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 338 | progress (bool): If True, displays a progress bar of the download to stderr 339 | 340 | Returns: 341 | nn.Module: MC3 Network definition 342 | """ 343 | return _video_resnet('mc3_18', 344 | pretrained, progress, 345 | block=BasicBlock, 346 | conv_makers=[Conv3DNoTemporal] + [Conv3DSimple] * 3, 347 | layers=[2, 2, 2, 2], 348 | stem=BasicStem, **kwargs) 349 | 350 | 351 | def r2plus1d_18(pretrained=False, progress=True, **kwargs): 352 | """Constructor for the 18 layer deep R(2+1)D network as in 353 | https://arxiv.org/abs/1711.11248 354 | 355 | Args: 356 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 357 | progress (bool): If True, displays a progress bar of the download to stderr 358 | 359 | Returns: 360 | nn.Module: R(2+1)D-18 network 361 | """ 362 | return _video_resnet('r2plus1d_18', 363 | pretrained, progress, 364 | block=BasicBlock, 365 | conv_makers=[Conv2Plus1D] * 4, 366 | layers=[2, 2, 2, 2], 367 | stem=R2Plus1dStem, **kwargs) 368 | -------------------------------------------------------------------------------- /Pytorch/utils/graph_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def calculate_laplacian_with_self_loop(matrix): 5 | matrix = matrix + torch.eye(matrix.size(0), device=matrix.device) 6 | row_sum = matrix.sum(1) 7 | 8 | d_inv_sqrt = torch.pow(row_sum, -0.5).flatten() 9 | 10 | d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0 11 | 12 | d_mat_inv_sqrt = torch.diag(d_inv_sqrt) 13 | normalized_laplacian = ( 14 | matrix.matmul(d_mat_inv_sqrt).transpose(0, 1).matmul(d_mat_inv_sqrt) 15 | ) 16 | normalized_laplacian = normalized_laplacian.to(matrix.device) 17 | return normalized_laplacian 18 | 19 | def random_edges(dim): 20 | matrix = np.random.rand(dim, dim) 21 | matrix = torch.tensor(matrix,dtype=torch.float32) 22 | matrix = matrix > 0.5 23 | matrix = matrix.int() 24 | 25 | return matrix -------------------------------------------------------------------------------- /Pytorch/utils/util.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import torch 3 | import shutil 4 | 5 | 6 | def set_lr(opt, new_lr): 7 | for param_group in opt.param_groups: 8 | param_group["lr"] = new_lr 9 | 10 | def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar', best='best.pth.tar'): 11 | torch.save(state, join(checkpoint, filename)) 12 | if is_best: 13 | shutil.copyfile(join(checkpoint, filename), join(checkpoint, best)) 14 | 15 | def load_model(model, pretrained): 16 | weights = torch.load(pretrained) 17 | epoch = weights['epoch'] 18 | best_auc1 = weights['best_auc1'] 19 | pretrained_dict = weights["state_dict"] 20 | model_dict = model.state_dict() 21 | # 1. filter out unnecessary keys 22 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 23 | # 2. 
overwrite entries in the existing state dict 24 | model_dict.update(pretrained_dict) 25 | # 3. load the new state dict 26 | model.load_state_dict(model_dict) 27 | del weights 28 | return epoch, best_auc1 29 | 30 | def pretrained_model(model, pretrained): 31 | pretrained_dict = torch.load(pretrained) 32 | model_dict = model.state_dict() 33 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 34 | model_dict.update(pretrained_dict) 35 | model.load_state_dict(model_dict) 36 | 37 | class AverageMeter(object): 38 | """Computes and stores the average and current value""" 39 | def __init__(self, name, fmt=':f'): 40 | self.name = name 41 | self.fmt = fmt 42 | self.reset() 43 | 44 | def reset(self): 45 | self.val = 0 46 | self.avg = 0 47 | self.sum = 0 48 | self.count = 0 49 | 50 | def update(self, val, n=1): 51 | self.val = val 52 | self.sum += val * n 53 | self.count += n 54 | self.avg = self.sum / self.count 55 | 56 | def __str__(self): 57 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 58 | return fmtstr.format(**self.__dict__) 59 | 60 | class ProgressMeter(object): 61 | def __init__(self, num_batches, meters, prefix=""): 62 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 63 | self.meters = meters 64 | self.prefix = prefix 65 | 66 | def display(self, batch): 67 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 68 | entries += [str(meter) for meter in self.meters] 69 | print(' '.join(entries)) 70 | 71 | def _get_batch_fmtstr(self, num_batches): 72 | num_digits = len(str(num_batches // 1)) 73 | fmt = '{:' + str(num_digits) + 'd}' 74 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 75 | 76 | 77 | ## model parameters 78 | def summary(model): 79 | total_params = sum(p.numel() for p in model.parameters()) 80 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 81 | print('Total - %.2fM' % (total_params/1e6)) 82 | print('Trainable - %.2fM' % (trainable_params/1e6)) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mask Relation 2 | The code of [Masked Relation Learning for DeepFake Detection](https://ieeexplore.ieee.org/document/10054130) (TIFS 2023). 3 | 4 | We provide Pytorch and MindSpore (昇思) versions of source code. 5 | 6 | ### Install 7 | - Python 3.8 8 | - Pytorch 1.12 9 | - MindSpore 1.10.1 10 | - CUDA 11.1 11 | 12 | ### Requirements 13 | ``` 14 | pip install albumentations --user 15 | pip install pytorchcv --user 16 | pip install timm --user 17 | pip install einops --user 18 | ``` 19 | 20 | 21 | ### Train 22 | ``` 23 | python train_ffpp.py --gpu 0 --amp [--multiprocessing_distributed True] 24 | ``` 25 | 26 | ### Test 27 | ``` 28 | python test.py --resume [path/to/checkpoint] 29 | ``` 30 | --------------------------------------------------------------------------------
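Beyond the command-line scripts, the PyTorch model can also be instantiated directly for quick inspection. The sketch below reuses `Config`, `define_network`, and `summary` exactly as `Pytorch/test.py` and `Pytorch/utils/util.py` use them; it is illustrative only, assumes it is run from the `Pytorch/` directory, and assumes the pretrained `mc3_18` weights expected by `training/network.py` are available under `pretrained/`. The dummy clip shape follows the repository's `seq_len` of 100 frames and `resize` of 112.

```python
# Illustrative sketch (not part of the repository). Run from Pytorch/.
# Assumes pretrained/mc3_18.pth exists, as required by training/network.py.
import torch

from config import Config
from training.network import define_network
from utils.util import summary

config = Config(datalabel="ff-all", recipes=["ff-all", "r3d"])
model = define_network(f_dim=512, h_didm=128, a_dim=12, config=config)
summary(model)  # prints total / trainable parameter counts in millions

model.eval()
with torch.no_grad():
    # One clip: 3 channels x 100 frames x 112 x 112 (seq_len x resize).
    clip = torch.randn(1, 3, 100, 112, 112)
    logits = model(clip)                    # (1, 2) real/fake logits
    prob_fake = logits.softmax(dim=1)[:, 1]
print(prob_fake)
```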