├── README.md ├── checkpoints └── Ours │ ├── DBBF.yaml │ ├── Model.txt │ └── Ours │ └── test │ ├── Celeb-DF.csv │ ├── DBBF.yaml │ ├── DFDC.csv │ ├── FFIW.csv │ ├── FaceForensics_c23.csv │ └── scores.txt ├── common ├── __init__.py ├── data │ ├── __init__.py │ ├── base_dataloader.py │ └── base_transform.py ├── evaluations │ └── video_evaluation.py ├── losses │ ├── __init__.py │ ├── edl_loss.py │ └── euclidean_loss.py ├── optimizers │ ├── __init__.py │ └── sam.py ├── schedulers │ └── __init__.py ├── task │ ├── __init__.py │ ├── base_task.py │ └── fas │ │ ├── __init__.py │ │ └── modules.py ├── utils │ ├── __init__.py │ ├── cli_utils.py │ ├── distribute_utils.py │ ├── face_utils.py │ ├── grad_cam.py │ ├── logger_utils.py │ ├── map_util.py │ ├── meters.py │ ├── metrics.py │ ├── misc.py │ ├── model_init.py │ └── parameters.py └── visualizer │ ├── html.py │ ├── util.py │ └── visualizer.py ├── configs ├── DBBF.yaml └── backbones │ └── SL-ResNet-50.yaml ├── datasets ├── Image_dataset.py ├── Image_dataset_test.py ├── common │ ├── __init__.py │ └── utils │ │ ├── __init__.py │ │ └── map_util.py ├── factory.py ├── lib │ ├── DeepFakeMask.py │ ├── bi_online_generation.py │ ├── bi_online_generation_sbv.py │ ├── blend.py │ ├── blend_sbv.py │ └── ct │ │ ├── detection │ │ ├── __init__.py │ │ ├── alignment.py │ │ ├── detector.py │ │ └── utils.py │ │ ├── face_alignment │ │ ├── __init__.py │ │ ├── basenet.py │ │ ├── predictor.py │ │ └── utils.py │ │ ├── faster_crop_align_xray.py │ │ ├── operations.py │ │ ├── tracking │ │ ├── __init__.py │ │ ├── sort.py │ │ └── tracker.py │ │ ├── utils.py │ │ └── warp_for_xray.py └── utils │ ├── blend.py │ ├── dataloader_util.py │ ├── funcs.py │ └── initialize.py ├── engine_finetune.py ├── eval.py ├── face_process ├── custom_data_list.py ├── extract.py ├── face_utils.py ├── lib │ ├── common.py │ ├── ct │ │ ├── detection │ │ │ ├── __init__.py │ │ │ ├── alignment.py │ │ │ ├── detector.py │ │ │ └── utils.py │ │ ├── face_alignment │ │ │ ├── __init__.py │ │ │ ├── basenet.py │ │ │ ├── predictor.py │ │ │ └── utils.py │ │ ├── operations.py │ │ ├── tracking │ │ │ ├── __init__.py │ │ │ ├── sort.py │ │ │ └── tracker.py │ │ └── utils.py │ ├── dfdc_utils.py │ ├── shape_predictor_81_face_landmarks.dat │ ├── utils.py │ ├── video_list.py │ └── xray │ │ ├── faster_crop_align_xray.py │ │ └── warp_for_xray.py ├── real_face_process.py ├── real_face_process_Frame.py └── sample_data.py ├── inference.py ├── inference_results.csv ├── models ├── __init__.py ├── custom.py ├── custom_sl.py ├── custom_ssl.py └── lib │ ├── BEiT_v2 │ ├── modeling_finetune.py │ └── modeling_pretrain.py │ ├── BEiT_v3 │ ├── modeling_finetune.py │ ├── modeling_utils.py │ └── utils.py │ ├── DINO │ ├── utils.py │ └── vision_transformer.py │ ├── MAE │ ├── models_mae.py │ └── util │ │ ├── crop.py │ │ ├── datasets.py │ │ ├── hibert.py │ │ ├── lars.py │ │ ├── lr_decay.py │ │ ├── lr_sched.py │ │ ├── misc.py │ │ └── pos_embed.py │ ├── MoCoV3 │ ├── __init__.py │ ├── builder.py │ ├── loader.py │ ├── optimizer.py │ └── vits.py │ ├── SIMMIM │ ├── __init__.py │ ├── build.py │ ├── config.py │ ├── main_finetune.py │ ├── simmim.py │ ├── simmim_finetune__vit_base__img224__800ep.yaml │ ├── swin_transformer.py │ ├── utils.py │ └── vision_transformer.py │ └── obow │ ├── ResNet50_OBoW_full.yaml │ ├── __init__.py │ ├── builder_obow.py │ ├── classification.py │ ├── config.py │ ├── datasets.py │ ├── feature_extractor.py │ ├── fewshot.py │ ├── solver.py │ └── utils.py ├── pretrained_weight └── pretrain_weight.txt ├── requirements.txt ├── test_img ├── 
Fake1.png ├── Fake1_aligned.png ├── Fake2.png ├── Fake2_aligned.png ├── Fake3.png ├── Fake3_aligned.png ├── Fake4.png ├── Fake4_aligned.png ├── Fake_ori.png ├── Fake_ori_aligned.png ├── Real1.png ├── Real1_aligned.png ├── Real2.png └── Real2_aligned.png ├── train.py ├── train_dualbranch.py └── utils ├── __init__.py ├── aucloss.py ├── ckpt_process.py ├── lr_utils.py ├── metrics.py └── pos_embed.py /checkpoints/Ours/DBBF.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | torch_home: 3 | 4 | method: FaceForensics_c23 5 | compression: c23 6 | checkpoints_dir: checkpoints 7 | name: ${model.label}_${model.name} 8 | exam_dir: ${checkpoints_dir}/${name} 9 | 10 | transform_params: 11 | image_size: 224 12 | mean: [0.485, 0.456, 0.406] 13 | std: [0.229, 0.224, 0.225] 14 | 15 | train: 16 | batch_size: 32 17 | num_workers: 16 18 | print_info_step_freq: 10 19 | save_model_interval: 10 20 | max_epoches: 1 21 | use_warmup: False 22 | warmup_epochs: 0 23 | last_epoch_max_acc: 0.0 24 | 25 | dataset: 26 | name: Image_dataset 27 | params: 28 | root: ../datasets_processed/ 29 | method: ${method} 30 | split: train 31 | num_segments: 8 #8 32 | cutout: True #True 33 | is_sbi: False 34 | image_size: ${transform_params.image_size} 35 | 36 | test: 37 | batch_size: ${train.batch_size} 38 | num_workers: ${train.num_workers} 39 | dataset: 40 | name: Image_dataset 41 | params: 42 | root: ../datasets_processed_8frames/ 43 | method: ${method} #FF-ALL 44 | split: test 45 | num_segments: ${train.dataset.params.num_segments} 46 | cutout: True 47 | is_sbi: False 48 | image_size: ${transform_params.image_size} 49 | 50 | final_test: 51 | batch_size: ${test.batch_size} 52 | num_workers: ${test.num_workers} 53 | dataset: 54 | name: Image_dataset_test 55 | params: 56 | root: ../datasets_processed_8frames/ 57 | method: ALL 58 | split: test 59 | is_sbi: False 60 | num_segments: ${train.dataset.params.num_segments} 61 | image_size: ${transform_params.image_size} 62 | 63 | model: 64 | name: DDBF_BEiT_v2 65 | label: Final 66 | backbone: BEiT_v2 67 | params: 68 | pretrained_path: pretrained_weight/BEiT-1k-Face-55w.tar 69 | image_size: ${transform_params.image_size} 70 | feature_dim: 768 71 | resume: 72 | only_resume_model: False 73 | 74 | 75 | optimizer: 76 | type: N #SAM 77 | name: lamb 78 | params: 79 | lr: 5e-5 80 | opt: ${optimizer.name} 81 | weight_decay: 0.05 #1.0e-5 82 | momentum: 0.9 83 | clip_mode: norm 84 | layer_decay: .75 85 | 86 | loss: 87 | name: EvidenceLoss 88 | params: 89 | num_classes: 2 90 | evidence: exp 91 | loss_type: log 92 | with_kldiv: False 93 | with_avuloss: True 94 | annealing_method: exp 95 | loss2: 96 | name: CrossEntropyLoss 97 | params: 98 | 99 | scheduler: 100 | sched: cosine 101 | lr: ${optimizer.params.lr} 102 | lr_noise_pct: 0.67 103 | lr_noise_std: 1.0 104 | lr_cycle_mul: 1.0 105 | lr_cycle_decay: 0.5 106 | lr_cycle_limit: 1 107 | lr_k_decay: 1.0 108 | warmup_lr: 1e-6 109 | min_lr: 1e-5 # 1e-5 110 | epochs: ${train.max_epoches} 111 | warmup_epochs: 1 112 | cooldown_epochs: 0 #1 #5 -------------------------------------------------------------------------------- /checkpoints/Ours/Ours/test/DBBF.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | torch_home: 3 | 4 | method: FaceForensics_c23 5 | compression: c23 6 | checkpoints_dir: checkpoints 7 | name: ${model.label}_${model.name} 8 | exam_dir: ${checkpoints_dir}/${name} 9 | 10 | transform_params: 11 | image_size: 224 12 | mean: [0.485, 0.456, 
0.406] 13 | std: [0.229, 0.224, 0.225] 14 | 15 | train: 16 | batch_size: 32 17 | num_workers: 16 18 | print_info_step_freq: 10 19 | save_model_interval: 10 20 | max_epoches: 1 21 | use_warmup: False 22 | warmup_epochs: 0 23 | last_epoch_max_acc: 0.0 24 | 25 | dataset: 26 | name: Image_dataset 27 | params: 28 | root: ../datasets_processed/ 29 | method: ${method} 30 | split: train 31 | num_segments: 8 #8 32 | cutout: True #True 33 | is_sbi: False 34 | image_size: ${transform_params.image_size} 35 | 36 | test: 37 | batch_size: ${train.batch_size} 38 | num_workers: ${train.num_workers} 39 | dataset: 40 | name: Image_dataset 41 | params: 42 | root: ../datasets_processed_8frames/ 43 | method: ${method} #FF-ALL 44 | split: test 45 | num_segments: ${train.dataset.params.num_segments} 46 | cutout: True 47 | is_sbi: False 48 | image_size: ${transform_params.image_size} 49 | 50 | final_test: 51 | batch_size: ${test.batch_size} 52 | num_workers: ${test.num_workers} 53 | dataset: 54 | name: Image_dataset_test 55 | params: 56 | root: ../datasets_processed_8frames/ 57 | method: ALL 58 | split: test 59 | is_sbi: False 60 | num_segments: ${train.dataset.params.num_segments} 61 | image_size: ${transform_params.image_size} 62 | 63 | model: 64 | name: DDBF_BEiT_v2 65 | label: Final 66 | backbone: BEiT_v2 67 | params: 68 | pretrained_path: pretrained_weight/BEiT-1k-Face-55w.tar 69 | image_size: ${transform_params.image_size} 70 | feature_dim: 768 71 | resume: 72 | only_resume_model: False 73 | 74 | 75 | optimizer: 76 | type: N #SAM 77 | name: lamb 78 | params: 79 | lr: 5e-5 80 | opt: ${optimizer.name} 81 | weight_decay: 0.05 #1.0e-5 82 | momentum: 0.9 83 | clip_mode: norm 84 | layer_decay: .75 85 | 86 | loss: 87 | name: EvidenceLoss 88 | params: 89 | num_classes: 2 90 | evidence: exp 91 | loss_type: log 92 | with_kldiv: False 93 | with_avuloss: True 94 | annealing_method: exp 95 | loss2: 96 | name: CrossEntropyLoss 97 | params: 98 | 99 | scheduler: 100 | sched: cosine 101 | lr: ${optimizer.params.lr} 102 | lr_noise_pct: 0.67 103 | lr_noise_std: 1.0 104 | lr_cycle_mul: 1.0 105 | lr_cycle_decay: 0.5 106 | lr_cycle_limit: 1 107 | lr_k_decay: 1.0 108 | warmup_lr: 1e-6 109 | min_lr: 1e-5 # 1e-5 110 | epochs: ${train.max_epoches} 111 | warmup_epochs: 1 112 | cooldown_epochs: 0 #1 #5 -------------------------------------------------------------------------------- /checkpoints/Ours/Ours/test/scores.txt: -------------------------------------------------------------------------------- 1 | ACC=96.14, AUC=99.36, D=[131 9 18 542], P=98.37, R=96.79, FaceForensics_c23 2 | ACC=83.20, AUC=90.46, D=[141 37 50 290], P=88.69, R=85.29, Celeb-DF 3 | ACC=76.00, AUC=84.90, D=[2113 387 813 1687], P=81.34, R=67.48, DFDC 4 | ACC=81.20, AUC=90.97, D=[202 48 46 204], P=80.95, R=81.60, FFIW -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/common/__init__.py -------------------------------------------------------------------------------- /common/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the following main classes/functions: 3 | - create_base_transforms (function): 4 | create base transforms for training, validation and testing 5 | - create_base_dataloader (function): 6 | create base dataloader for training, validation and 
testing 7 | """ 8 | from .base_transform import create_base_transforms,create_base_sbi_transforms 9 | from .base_dataloader import create_base_dataloader 10 | -------------------------------------------------------------------------------- /common/data/base_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torch.utils.data import DataLoader 3 | from prefetch_generator import BackgroundGenerator 4 | 5 | class DataLoaderX(DataLoader): 6 | 7 | def __iter__(self): 8 | return BackgroundGenerator(super().__iter__()) 9 | 10 | def create_base_dataloader(args, dataset, split): 11 | """Base data loader 12 | 13 | Args: 14 | args: Dataset config args 15 | split (string): Load "train", "val" or "test" 16 | 17 | Returns: 18 | [dataloader]: Corresponding Dataloader 19 | """ 20 | sampler = None 21 | if args.distributed: 22 | sampler = data.distributed.DistributedSampler(dataset) 23 | 24 | shuffle = True if sampler is None and split == 'train' else False 25 | batch_size = getattr(args, split).batch_size 26 | num_workers = args.num_workers if 'num_workers' in args else 8 27 | drop_last = False if split == 'test' else True 28 | # dataloader = data.DataLoader(dataset, 29 | # batch_size=batch_size, 30 | # shuffle=shuffle, 31 | # sampler=sampler, 32 | # num_workers=num_workers, 33 | # pin_memory=True, 34 | # drop_last=drop_last) 35 | dataloader = DataLoaderX(dataset, 36 | batch_size=batch_size, 37 | shuffle=shuffle, 38 | sampler=sampler, 39 | num_workers=num_workers, 40 | pin_memory=True, 41 | drop_last=drop_last) 42 | return dataloader 43 | -------------------------------------------------------------------------------- /common/data/base_transform.py: -------------------------------------------------------------------------------- 1 | import albumentations as alb 2 | from albumentations.pytorch.transforms import ToTensorV2 3 | 4 | 5 | def create_base_transforms(args, split='train'): 6 | """Base data transformation 7 | 8 | Args: 9 | args: Data transformation args 10 | split (str, optional): Defaults to 'train'. 
11 | 12 | Returns: 13 | [transform]: Data transform 14 | """ 15 | num_segments = args.num_segments if 'num_segments' in args else 1 16 | additional_targets = {} 17 | # for i in range(1, num_segments): 18 | # additional_targets[f'image{i}'] = 'image' 19 | if split == 'train': 20 | base_transform = alb.Compose([ 21 | alb.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), 22 | alb.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3), 23 | alb.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3), 24 | alb.ImageCompression(quality_lower=40,quality_upper=100,p=0.5), 25 | alb.HorizontalFlip(), 26 | alb.augmentations.transforms.ToGray(p=0.01), 27 | alb.Resize(args.image_size, args.image_size), 28 | # alb.RandomResizedCrop(args.image_size, args.image_size,scale=(0.2, 1), p=1), 29 | alb.Normalize(mean=args.mean, std=args.std), 30 | ToTensorV2(), 31 | ], additional_targets=additional_targets) 32 | 33 | elif split == 'val': 34 | base_transform = alb.Compose([ 35 | alb.Resize(args.image_size, args.image_size), 36 | alb.Normalize(mean=args.mean, std=args.std), 37 | ToTensorV2(), 38 | ], additional_targets=additional_targets) 39 | 40 | elif split == 'test': 41 | base_transform = alb.Compose([ 42 | alb.Resize(args.image_size, args.image_size), 43 | alb.Normalize(mean=args.mean, std=args.std), 44 | ToTensorV2(), 45 | ], additional_targets=additional_targets) 46 | 47 | return base_transform 48 | 49 | 50 | 51 | def create_base_sbi_transforms(args, split='train'): 52 | """Base data transformation 53 | 54 | Args: 55 | args: Data transformation args 56 | split (str, optional): Defaults to 'train'. 57 | 58 | Returns: 59 | [transform]: Data transform 60 | """ 61 | 62 | if split == 'train': 63 | base_transform = alb.Compose([ 64 | 65 | alb.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), 66 | alb.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3), 67 | alb.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3), 68 | alb.ImageCompression(quality_lower=40,quality_upper=100,p=0.5), 69 | alb.Resize(args.image_size, args.image_size), 70 | alb.Normalize(mean=args.mean, std=args.std), 71 | ToTensorV2(), 72 | ], 73 | additional_targets={f'image1': 'image'}, 74 | p=1.) 
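# Editor's illustration (not part of base_transform.py): how the SBI transform built
# above is typically applied to a real/blended frame pair. `args` is assumed to carry
# image_size/mean/std (as in transform_params of configs/DBBF.yaml); `real_frame` and
# `blended_frame` are hypothetical HxWx3 uint8 numpy arrays loaded by the caller.
#     sbi_transform = create_base_sbi_transforms(args, split='train')
#     out = sbi_transform(image=real_frame, image1=blended_frame)  # 'image1' maps to 'image' via additional_targets
#     real_t, blended_t = out['image'], out['image1']              # CHW torch.Tensors after ToTensorV2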
75 | 76 | elif split == 'val': 77 | base_transform = alb.Compose([ 78 | alb.Resize(args.image_size, args.image_size), 79 | alb.Normalize(mean=args.mean, std=args.std), 80 | ToTensorV2(), 81 | ]) 82 | 83 | elif split == 'test': 84 | base_transform = alb.Compose([ 85 | alb.Resize(args.image_size, args.image_size), 86 | alb.Normalize(mean=args.mean, std=args.std), 87 | ToTensorV2(), 88 | ]) 89 | 90 | return base_transform 91 | -------------------------------------------------------------------------------- /common/losses/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the loss functions used in the project: 3 | - pytorch official losses 4 | - euclidean_loss 5 | """ 6 | from torch.nn.modules.loss import * 7 | from .euclidean_loss import EuclideanLoss 8 | from .edl_loss import EvidenceLoss -------------------------------------------------------------------------------- /common/losses/euclidean_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EuclideanLoss(nn.Module): 6 | '''Compute euclidean distance between two tensors 7 | ''' 8 | 9 | def __init__(self, reduction=None): 10 | super(EuclideanLoss, self).__init__() 11 | self.reduction = reduction 12 | 13 | def forward(self, x, y): 14 | 15 | n = x.size(0) 16 | m = n 17 | d = x.size(1) 18 | y = y.unsqueeze(0).expand(n, d) 19 | 20 | x = x.unsqueeze(1).expand(n, m, d) 21 | y = y.unsqueeze(0).expand(n, m, d) 22 | 23 | if self.reduction == 'mean': 24 | return torch.pow(x - y, 2).mean() 25 | 26 | elif self.reduction == 'sum': 27 | return torch.pow(x - y, 2).sum() 28 | -------------------------------------------------------------------------------- /common/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the optimizers used in the project: 3 | - pytorch official optimizers 4 | """ 5 | from torch.optim import * 6 | from common.optimizers.sam import SAM -------------------------------------------------------------------------------- /common/optimizers/sam.py: -------------------------------------------------------------------------------- 1 | # borrowed from 2 | 3 | import torch 4 | 5 | import torch 6 | import torch.nn as nn 7 | from common import optimizers 8 | from timm.optim import create_optimizer_v2, optimizer_kwargs 9 | def disable_running_stats(model): 10 | def _disable(module): 11 | if isinstance(module, nn.BatchNorm2d): 12 | module.backup_momentum = module.momentum 13 | module.momentum = 0 14 | 15 | model.apply(_disable) 16 | 17 | def enable_running_stats(model): 18 | def _enable(module): 19 | if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"): 20 | module.momentum = module.backup_momentum 21 | 22 | model.apply(_enable) 23 | 24 | class SAM(torch.optim.Optimizer): 25 | def __init__(self, params, base_optimizer, rho=0.05, **kwargs): 26 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 27 | defaults = dict(rho=rho, **kwargs) 28 | super(SAM, self).__init__(params, defaults) 29 | self.base_optimizer = create_optimizer_v2(self.param_groups, **kwargs) 30 | self.param_groups = self.base_optimizer.param_groups 31 | 32 | @torch.no_grad() 33 | def first_step(self, zero_grad=False): 34 | grad_norm = self._grad_norm() 35 | for group in self.param_groups: 36 | scale = group["rho"] / (grad_norm + 1e-12) 37 | for p in group["params"]: 38 | if p.grad is 
None: continue 39 | e_w = p.grad * scale.to(p) 40 | p.add_(e_w) # climb to the local maximum "w + e(w)" 41 | self.state[p]["e_w"] = e_w 42 | 43 | if zero_grad: self.zero_grad() 44 | 45 | @torch.no_grad() 46 | def second_step(self, zero_grad=False): 47 | for group in self.param_groups: 48 | for p in group["params"]: 49 | if p.grad is None: continue 50 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" 51 | 52 | self.base_optimizer.step() # do the actual "sharpness-aware" update 53 | 54 | if zero_grad: self.zero_grad() 55 | 56 | @torch.no_grad() 57 | def step(self, closure=None): 58 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 59 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 60 | 61 | self.first_step(zero_grad=True) 62 | closure() 63 | self.second_step() 64 | 65 | def _grad_norm(self): 66 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 67 | norm = torch.norm( 68 | torch.stack([ 69 | p.grad.norm(p=2).to(shared_device) 70 | for group in self.param_groups for p in group["params"] 71 | if p.grad is not None 72 | ]), 73 | p=2 74 | ) 75 | return norm -------------------------------------------------------------------------------- /common/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the schedulers used in the project: 3 | - timm schedulers 4 | """ 5 | from timm.scheduler import * 6 | -------------------------------------------------------------------------------- /common/task/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the following main classes/functions: 3 | - BaseTask (class): 4 | base task for training, validation and testing 5 | """ 6 | from .base_task import BaseTask 7 | -------------------------------------------------------------------------------- /common/task/fas/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the following main classes/functions: 3 | - test_module (class): 4 | function for validation and testing 5 | """ 6 | from .modules import test_module 7 | -------------------------------------------------------------------------------- /common/task/fas/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import torch 7 | 8 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..')) 9 | from common.utils import * 10 | 11 | 12 | def test_module(model, test_data_loaders, forward_function, device='cuda', distributed=False): 13 | """Test module for Face Anti-spoofing 14 | 15 | Args: 16 | model (nn.module): fas model 17 | test_data_loaders (torch.dataloader): list of test data loaders 18 | forward_function (function): model forward function 19 | device (str, optional): Defaults to 'cuda'. 20 | distributed (bool, optional): whether to use distributed training. Defaults to False. 
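A minimal usage sketch (editor's illustration, not part of modules.py): `model` and `test_loader` are assumed to exist, the loader is assumed to yield (images, labels, masks, paths) batches as this function expects, and the lambda assumes the model returns [B, 2] logits; cal_metrics comes from common/utils/metrics.py shown later in this listing.

        y_preds, y_trues = test_module(model, [test_loader],
                                       forward_function=lambda x: model(x).softmax(dim=-1)[:, 1])
        metrics = cal_metrics(y_trues, y_preds, threshold='auto')
        print(f"AUC={metrics.AUC:.4f}  ACER={metrics.ACER:.4f}")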
21 | 22 | Returns: 23 | y_preds (list): predictions 24 | y_trues (list): ground truth labels 25 | """ 26 | prob_dict = {} 27 | label_dict = {} 28 | 29 | y_preds = [] 30 | y_trues = [] 31 | 32 | model.eval() 33 | for loaders in test_data_loaders: 34 | for iter, datas in enumerate(tqdm(loaders)): 35 | with torch.no_grad(): 36 | images = datas[0].to(device) 37 | targets = datas[1].to(device) 38 | map_GT = datas[2].to(device) 39 | img_path = datas[3] 40 | probs = forward_function(images) 41 | 42 | if not distributed: 43 | probs = probs.cpu().data.numpy() 44 | label = targets.cpu().data.numpy() 45 | 46 | for i in range(len(probs)): 47 | # the image of the same video share the same video_path 48 | video_path = img_path[i].rsplit('/', 1)[0] 49 | if (video_path in prob_dict.keys()): 50 | prob_dict[video_path].append(probs[i]) 51 | label_dict[video_path].append(label[i]) 52 | else: 53 | prob_dict[video_path] = [] 54 | label_dict[video_path] = [] 55 | prob_dict[video_path].append(probs[i]) 56 | label_dict[video_path].append(label[i]) 57 | else: 58 | y_preds.extend(probs) 59 | y_trues.extend(targets) 60 | 61 | if not distributed: 62 | y_preds = [] 63 | y_trues = [] 64 | for key in prob_dict.keys(): 65 | # calculate the scores in video-level via averaging the scores of the images from the same videos 66 | avg_single_video_prob = sum(prob_dict[key]) / len(prob_dict[key]) 67 | avg_single_video_label = sum(label_dict[key]) / len(label_dict[key]) 68 | y_preds = np.append(y_preds, avg_single_video_prob) 69 | y_trues = np.append(y_trues, avg_single_video_label) 70 | 71 | return y_preds, y_trues 72 | -------------------------------------------------------------------------------- /common/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the following main classes/functions: 3 | - parameters: (deprecated) 4 | - cli_utils: 5 | client parameters utils 6 | - logger_utils: 7 | logger utils 8 | - distribute_utils: 9 | tensor reduce and gather utils in distributed training 10 | - face_utils: 11 | face crop functions 12 | - misc: 13 | training misc utils 14 | - meters: 15 | training meters 16 | - metrics: 17 | calculate metrics 18 | - model_init: 19 | model weight initialization functions 20 | """ 21 | from .parameters import get_parameters 22 | from .cli_utils import get_params 23 | from .logger_utils import get_logger 24 | from .distribute_utils import reduce_tensor, gather_tensor 25 | from .face_utils import add_face_margin, get_face_box 26 | from .misc import set_seed, setup, init_exam_dir, init_wandb_workspace, save_test_results 27 | from .meters import AverageMeter, ProgressMeter 28 | from .metrics import find_best_threshold, cal_metrics 29 | from .model_init import * 30 | -------------------------------------------------------------------------------- /common/utils/cli_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from omegaconf import OmegaConf 3 | 4 | 5 | def flatten_dict(dictionary, delimiter ='.'): 6 | flat_dict = dict() 7 | for key, value in dictionary.items(): 8 | if isinstance(value, dict): 9 | flatten_value_dict = flatten_dict(value, delimiter) 10 | for k, v in flatten_value_dict.items(): 11 | flat_dict[f"{key}{delimiter}{k}"] = v 12 | else: 13 | flat_dict[key] = value 14 | return flat_dict 15 | 16 | 17 | def nested_set(dic, keys, value): 18 | for key in keys[:-1]: 19 | dic = dic.setdefault(key, {}) 20 | dic[keys[-1]] = value 21 | 22 | 23 | def 
warn_print(x): 24 | """Print info in the highlighted effect. 25 | 26 | Args: 27 | x (any): Info to print. 28 | """ 29 | x = str(x) 30 | x = "\x1b[33;1m" + x + "\x1b[0m" 31 | print(x) 32 | 33 | 34 | def get_params(to_dict=False, **new_kwargs): 35 | """Parse the parameters from a yaml config file, and allow for param modification in the command line. 36 | The input '-c' or '--config' must be specified. 37 | Three arguments have default presets: distributed, local_rank and world_size. 38 | 39 | Argument priority: cmd args > new_kwargs > dict in config. 40 | 41 | Args: 42 | to_dict (bool, optional): 43 | Whether to return the parsed args in Python Dict type. Defaults to False. 44 | new_kwargs (**kwargs): 45 | Allow for overloading some params by specifying new params here. 46 | """ 47 | parser = argparse.ArgumentParser(add_help=False) 48 | parser.add_argument('-c', '--config', type=str, default='cfg.yaml') 49 | parser.add_argument('--distributed', type=int, default=1) 50 | parser.add_argument('--local_rank', type=int, default=0) 51 | parser.add_argument('--world_size', type=int, default=1) 52 | 53 | # parse the above cmd options 54 | args_tmp = parser.parse_known_args()[0] 55 | args_tmp_dict = vars(args_tmp) 56 | 57 | oc_cfg = OmegaConf.load(args_tmp.config) 58 | # args_tmp_dict.pop('config') 59 | oc_cfg.merge_with(args_tmp_dict) 60 | 61 | # append items from new_kwargs 62 | if new_kwargs: 63 | for k in new_kwargs: 64 | if k in oc_cfg: 65 | warn_print(f'{k} from `new_kwargs` found in original conf, will keep the one in `new_kwargs`') 66 | oc_cfg.merge_with(OmegaConf.create(new_kwargs)) 67 | 68 | oc_cfg_dict = OmegaConf.to_container(oc_cfg, resolve=True) 69 | 70 | oc_cfg_dict_flatten = flatten_dict(oc_cfg_dict) 71 | 72 | # add options from config.yaml to argparse 73 | for k, v in oc_cfg_dict_flatten.items(): 74 | if k in args_tmp_dict: 75 | continue 76 | if isinstance(v, bool): 77 | parser.add_argument('--{}'.format(k), dest=k.replace('.', '___'), 78 | type=lambda x: (str(x).lower() == 'true'), default=v) 79 | elif isinstance(v, list) or isinstance(v, tuple): 80 | parser.add_argument('--{}'.format(k), dest=k.replace('.', '___'), 81 | type=type(v[0]), default=v, nargs='+') 82 | else: 83 | parser.add_argument('--{}'.format(k), dest=k.replace('.', '___'), 84 | type=str if v is None else type(v), default=v) 85 | parser.add_argument('-h', '--help', action='help', help=('show this help message and exit')) 86 | args = parser.parse_args() 87 | 88 | var_args = vars(args) 89 | for k, v in var_args.items(): 90 | # if k == 'config': 91 | # continue 92 | ori_k = k.replace('___', '.') 93 | sub_ks = ori_k.split('.') 94 | nested_set(oc_cfg_dict, sub_ks, v) 95 | 96 | oc_cfg.merge_with(oc_cfg_dict) 97 | 98 | # print(OmegaConf.to_yaml(oc_cfg)) 99 | if to_dict: 100 | return OmegaConf.to_container(oc_cfg, resolve=True) 101 | return oc_cfg 102 | -------------------------------------------------------------------------------- /common/utils/distribute_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | def reduce_tensor(tensor, mean=True): 6 | """Reduce tensor in the distributed settting. 7 | 8 | Args: 9 | tensor (torch.tensor): 10 | Input torch tensor to reduce. 11 | mean (bool, optional): 12 | Whether to apply mean. Defaults to True. 13 | 14 | Returns: 15 | [torch.tensor]: Returned reduced torch tensor or. 16 | """ 17 | rt = tensor.clone() # The function operates in-place. 
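# --- Editor's note (illustrative, not part of distribute_utils.py) ---
# Both helpers in this file assume torch.distributed has been initialised
# (dist.init_process_group). A typical pattern in a DDP validation loop:
#     loss_avg  = reduce_tensor(loss.detach(), mean=True)               # average a scalar across ranks
#     all_probs = gather_tensor(prob_list, dist_=True, to_numpy=True)   # prob_list: a list of per-sample tensors
# `loss` and `prob_list` are hypothetical names from the calling code.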
18 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 19 | if mean: 20 | rt /= dist.get_world_size() 21 | return rt 22 | 23 | 24 | def gather_tensor(inp, world_size=None, dist_=True, to_numpy=False): 25 | """Gather tensor in the distributed setting. 26 | 27 | Args: 28 | inp (torch.tensor): 29 | Input torch tensor to gather. 30 | world_size (int, optional): 31 | Dist world size. Defaults to None. If None, world_size = dist.get_world_size(). 32 | dist_ (bool, optional): 33 | Whether to use all_gather method to gather all the tensors. Defaults to True. 34 | to_numpy (bool, optional): 35 | Whether to return numpy array. Defaults to False. 36 | 37 | Returns: 38 | (torch.tensor || numpy.ndarray): Returned tensor or numpy array. 39 | """ 40 | inp = torch.stack(inp) 41 | if dist_: 42 | if world_size is None: 43 | world_size = dist.get_world_size() 44 | gather_inp = [torch.ones_like(inp) for _ in range(world_size)] 45 | dist.all_gather(gather_inp, inp) 46 | gather_inp = torch.cat(gather_inp) 47 | else: 48 | gather_inp = inp 49 | 50 | if to_numpy: 51 | gather_inp = gather_inp.cpu().numpy() 52 | 53 | return gather_inp 54 | -------------------------------------------------------------------------------- /common/utils/meters.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | ''' 3 | The Class AverageMeter record the metrics during the training process 4 | Examples: 5 | >>> acces = AverageMeter('_Acc', ':.5f') 6 | >>> acc = (prediction == labels).float().mean() 7 | >>> acces.update(acc) 8 | ''' 9 | def __init__(self, name='metric', fmt=':f'): 10 | self.name = name 11 | self.fmt = fmt 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | if self.count == 0: 25 | self.avg = self.sum 26 | else: 27 | self.avg = self.sum / self.count 28 | 29 | def __str__(self): 30 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 31 | return fmtstr.format(**self.__dict__) 32 | 33 | 34 | class ProgressMeter(object): 35 | ''' 36 | The ProgressMeter to record all AverageMeter and print the results 37 | Examples: 38 | >>> acces = AverageMeter('_Acc', ':.5f') 39 | >>> progress = ProgressMeter(epoch_size, [acces]) 40 | >>> progress.display(iterations) 41 | ''' 42 | def __init__(self, num_batches, meters, prefix=""): 43 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 44 | self.meters = meters 45 | self.prefix = prefix 46 | 47 | def display(self, batch): 48 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 49 | entries += [str(meter) for meter in self.meters] 50 | # print('\t'.join(entries)) 51 | return '\t'.join(entries) 52 | 53 | def _get_batch_fmtstr(self, num_batches): 54 | num_digits = len(str(num_batches // 1)) 55 | fmt = '{:' + str(num_digits) + 'd}' 56 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 57 | -------------------------------------------------------------------------------- /common/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from easydict import EasyDict 3 | from sklearn.metrics import roc_curve, auc, confusion_matrix 4 | from scipy.optimize import brentq 5 | from scipy.interpolate import interp1d 6 | 7 | 8 | def find_best_threshold(y_trues, y_preds): 9 | ''' 10 | This function is utilized to find the threshold corresponding to the best ACER 11 | Args: 12 | 
y_trues (list): the list of the ground-truth labels, which contains the int data 13 | y_preds (list): the list of the predicted results, which contains the float data 14 | ''' 15 | print("Finding best threshold...") 16 | best_thre = 0.5 17 | best_metrics = None 18 | candidate_thres = list(np.unique(np.sort(y_preds))) 19 | for thre in candidate_thres: 20 | metrics = cal_metrics(y_trues, y_preds, threshold=thre) 21 | if best_metrics is None: 22 | best_metrics = metrics 23 | best_thre = thre 24 | elif metrics.ACER < best_metrics.ACER: 25 | best_metrics = metrics 26 | best_thre = thre 27 | print(f"Best threshold is {best_thre}") 28 | return best_thre, best_metrics 29 | 30 | 31 | def cal_metrics(y_trues, y_preds, threshold=0.5): 32 | ''' 33 | This function is utilized to calculate the performance of the methods 34 | Args: 35 | y_trues (list): the list of the ground-truth labels, which contains the int data 36 | y_preds (list): the list of the predicted results, which contains the float data 37 | threshold (float, optional): 38 | 'best': calculate the best results 39 | 'auto': calculate the results corresponding to the thresholds of EER 40 | float: calculate the results of the specific thresholds 41 | ''' 42 | 43 | metrics = EasyDict() 44 | 45 | fpr, tpr, thresholds = roc_curve(y_trues, y_preds) 46 | metrics.AUC = auc(fpr, tpr) 47 | 48 | metrics.EER = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 49 | metrics.Thre = float(interp1d(fpr, thresholds)(metrics.EER)) 50 | 51 | if threshold == 'best': 52 | _, best_metrics = find_best_threshold(y_trues, y_preds) 53 | return best_metrics 54 | 55 | elif threshold == 'auto': 56 | threshold = metrics.Thre 57 | 58 | prediction = (np.array(y_preds) > threshold).astype(int) 59 | 60 | res = confusion_matrix(y_trues, prediction, labels=[0, 1]) 61 | TP, FN = res[0, :] 62 | FP, TN = res[1, :] 63 | metrics.ACC = (TP + TN) / len(y_trues) 64 | 65 | TP_rate = float(TP / (TP + FN)) 66 | TN_rate = float(TN / (TN + FP)) 67 | 68 | metrics.APCER = float(FP / (TN + FP)) 69 | metrics.BPCER = float(FN / (FN + TP)) 70 | metrics.ACER = (metrics.APCER + metrics.BPCER) / 2 71 | 72 | return metrics 73 | -------------------------------------------------------------------------------- /common/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import wandb 4 | import shutil 5 | import time 6 | import datetime 7 | import warnings 8 | import torch 9 | import numpy as np 10 | from torch import Tensor 11 | from typing import Optional, List 12 | from timm.models.layers import DropPath, trunc_normal_ 13 | import torch.nn as nn 14 | 15 | def set_seed(SEED): 16 | """This function set the random seed for the training process 17 | 18 | Args: 19 | SEED (int): the random seed 20 | """ 21 | if SEED: 22 | random.seed(SEED) 23 | np.random.seed(SEED) 24 | torch.manual_seed(SEED) 25 | torch.cuda.manual_seed(SEED) 26 | torch.cuda.manual_seed_all(SEED) 27 | torch.backends.cudnn.deterministic = True 28 | 29 | 30 | def setup(cfg): 31 | if getattr(cfg, 'torch_home', None): 32 | os.environ['TORCH_HOME'] = cfg.torch_home 33 | warnings.filterwarnings("ignore") 34 | seed = cfg.seed 35 | set_seed(seed) 36 | 37 | 38 | def init_exam_dir(cfg): 39 | if cfg.local_rank == 0: 40 | if not os.path.exists(cfg.exam_dir): 41 | os.makedirs(cfg.exam_dir) 42 | ckpt_dir = os.path.join(cfg.exam_dir, 'ckpt') 43 | if not os.path.exists(ckpt_dir): 44 | os.makedirs(ckpt_dir) 45 | train_img_dir = os.path.join(cfg.exam_dir, 'train_img') 46 | 
if not os.path.exists(train_img_dir): 47 | os.makedirs(train_img_dir) 48 | 49 | 50 | def init_wandb_workspace(cfg): 51 | """This function initializes the wandb workspace 52 | """ 53 | if cfg.wandb.name is None: 54 | cfg.wandb.name = cfg.config.split('/')[-1].replace('.yaml', '') 55 | wandb.init(**cfg.wandb) 56 | allow_val_change = False if cfg.wandb.resume is None else True 57 | wandb.config.update(cfg, allow_val_change) 58 | wandb.save(cfg.config) 59 | if cfg.debug or wandb.run.dir == '/tmp': 60 | cfg.exam_dir = 'wandb/debug' 61 | if os.path.exists(cfg.exam_dir): 62 | shutil.rmtree(cfg.exam_dir) 63 | os.makedirs(cfg.exam_dir, exist_ok=True) 64 | else: 65 | cfg.exam_dir = os.path.dirname(wandb.run.dir) 66 | os.makedirs(os.path.join(cfg.exam_dir, 'ckpts'), exist_ok=True) 67 | return cfg 68 | 69 | 70 | def save_test_results(img_paths, y_preds, y_trues, filename='results.log'): 71 | assert len(y_trues) == len(y_preds) == len(img_paths) 72 | 73 | with open(filename, 'w') as f: 74 | for i in range(len(img_paths)): 75 | print(img_paths[i], end=' ', file=f) 76 | print(y_preds[i], file=f) 77 | print(y_trues[i], end=' ', file=f) 78 | 79 | def _max_by_axis(the_list): 80 | # type: (List[List[int]]) -> List[int] 81 | maxes = the_list[0] 82 | for sublist in the_list[1:]: 83 | for index, item in enumerate(sublist): 84 | maxes[index] = max(maxes[index], item) 85 | return maxes 86 | 87 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 88 | # TODO make this more general 89 | if tensor_list[0].ndim == 3: 90 | # TODO make it support different-sized images 91 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 92 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 93 | batch_shape = [len(tensor_list)] + max_size 94 | b, c, h, w = batch_shape 95 | dtype = tensor_list[0].dtype 96 | device = tensor_list[0].device 97 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 98 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 99 | for img, pad_img, m in zip(tensor_list, tensor, mask): 100 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 101 | m[: img.shape[1], :img.shape[2]] = False 102 | else: 103 | raise ValueError('not supported') 104 | return NestedTensor(tensor, mask) 105 | 106 | class NestedTensor(object): 107 | def __init__(self, tensors, mask: Optional[Tensor]): 108 | self.tensors = tensors 109 | self.mask = mask 110 | 111 | def to(self, device, non_blocking=False): 112 | # type: (Device) -> NestedTensor # noqa 113 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 114 | mask = self.mask 115 | if mask is not None: 116 | assert mask is not None 117 | cast_mask = mask.to(device, non_blocking=non_blocking) 118 | else: 119 | cast_mask = None 120 | return NestedTensor(cast_tensor, cast_mask) 121 | 122 | def record_stream(self, *args, **kwargs): 123 | self.tensors.record_stream(*args, **kwargs) 124 | if self.mask is not None: 125 | self.mask.record_stream(*args, **kwargs) 126 | 127 | def decompose(self): 128 | return self.tensors, self.mask 129 | 130 | def flatten(self): 131 | return NestedTensor(self.tensors.flatten(0,1), self.mask) 132 | 133 | def __repr__(self): 134 | return str(self.tensors) 135 | 136 | 137 | def _init_weights(m): 138 | if isinstance(m, nn.Linear): 139 | trunc_normal_(m.weight, std=.02) 140 | if isinstance(m, nn.Linear) and m.bias is not None: 141 | nn.init.constant_(m.bias, 0) 142 | elif isinstance(m, nn.LayerNorm): 143 | nn.init.constant_(m.bias, 0) 144 | 
nn.init.constant_(m.weight, 1.0) 145 | 146 | -------------------------------------------------------------------------------- /common/utils/model_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import init 3 | 4 | 5 | def weights_init_xavier(m): 6 | ''' Xavier initialization ''' 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv2d') != -1: 9 | init.xavier_normal_(m.weight.data, gain=1) 10 | elif classname.find('Linear') != -1: 11 | init.xavier_normal_(m.weight.data, gain=1) 12 | if m.bias is not None: 13 | init.constant_(m.bias.data, 0.0) 14 | elif classname.find('BatchNorm2d') != -1: 15 | init.uniform(m.weight.data, 0.02, 1.0) 16 | init.constant_(m.bias.data, 0.0) 17 | elif classname.find('BatchInstanceNorm2d') != -1: 18 | init.uniform(m.weight.data, 0.02, 1.0) 19 | init.constant_(m.bias.data, 0.0) 20 | elif classname.find('InstanceNorm2d') != -1: 21 | init.uniform(m.weight.data, 0.02, 1.0) 22 | init.constant_(m.bias.data, 0.0) 23 | else: 24 | pass 25 | 26 | 27 | def weights_init_normal(m): 28 | ''' Normal initialization ''' 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv') != -1: 31 | init.uniform(m.weight.data, 0.0, 0.02) 32 | elif classname.find('Linear') != -1: 33 | init.uniform(m.weight.data, 0.0, 0.02) 34 | if m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif classname.find('BatchNorm2d') != -1: 37 | init.uniform(m.weight.data, 0.02, 1.0) 38 | init.constant_(m.bias.data, 0.0) 39 | 40 | 41 | def weights_init_kaiming(m): 42 | ''' Kaiming initialization ''' 43 | classname = m.__class__.__name__ 44 | if classname.find('Conv') != -1: 45 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 46 | elif classname.find('Linear') != -1: 47 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 48 | if m.bias is not None: 49 | init.constant_(m.bias.data, 0.0) 50 | elif classname.find('BatchNorm2d') != -1: 51 | init.uniform(m.weight.data, 0.02, 1.0) 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | 55 | def weights_init_orthogonal(m): 56 | ''' Orthogonal initialization ''' 57 | classname = m.__class__.__name__ 58 | if classname.find('Conv') != -1: 59 | init.orthogonal_(m.weight.data, gain=1) 60 | elif classname.find('Linear') != -1: 61 | init.orthogonal_(m.weight.data, gain=1) 62 | if m.bias is not None: 63 | init.constant_(m.bias.data, 0.0) 64 | elif classname.find('BatchNorm2d') != -1: 65 | init.uniform(m.weight.data, 0.02, 1.0) 66 | init.constant_(m.bias.data, 0.0) 67 | 68 | 69 | def weights_init_orthogonal_rnn(m): 70 | ''' Orthogonal_RNN initialization ''' 71 | classname = m.__class__.__name__ 72 | if classname.find('LSTM') != -1: 73 | init.orthogonal_(m.all_weights[0][0], gain=1) 74 | init.orthogonal_(m.all_weights[0][1], gain=1) 75 | init.constant_(m.all_weights[0][2], 1) 76 | init.constant_(m.all_weights[0][3], 1) 77 | elif classname.find('Linear') != -1: 78 | init.xavier_normal_(m.weight.data, gain=1) 79 | init.constant_(m.bias.data, 0.0) 80 | 81 | 82 | def init_weights(net, init_type='normal'): 83 | if init_type == 'normal': 84 | net.apply(weights_init_normal) 85 | elif init_type == 'xavier': 86 | net.apply(weights_init_xavier) 87 | elif init_type == 'kaiming': 88 | net.apply(weights_init_kaiming) 89 | elif init_type == 'orthogonal': 90 | net.apply(weights_init_orthogonal) 91 | elif init_type == 'orthogonal_rnn': 92 | net.apply(weights_init_orthogonal_rnn) 93 | elif init_type == 'const': 94 | net.apply(weights_init_const) 95 | else: 96 | raise 
NotImplementedError('initialization method [%s] is not implemented' % init_type) 97 | 98 | 99 | def init_model(net, restore, init_type, init=True): 100 | """Init models with cuda and weights.""" 101 | # init weights of model 102 | if init: 103 | init_weights(net, init_type) 104 | 105 | # restore model weights 106 | if restore is not None: 107 | if os.path.exists(restore): 108 | 109 | # original saved file with DataParallel 110 | state_dict = torch.load(restore) 111 | # create new OrderedDict that does not contain `module.` 112 | from collections import OrderedDict 113 | new_state_dict = OrderedDict() 114 | for k, v in state_dict.items(): 115 | if 'module' in k: 116 | name = k[7:] # remove `module.` 117 | else: 118 | name = k 119 | new_state_dict[name] = v 120 | # load params 121 | net.load_state_dict(new_state_dict) 122 | 123 | net.restored = True 124 | print("*************Restore model from: {}".format(os.path.abspath(restore))) 125 | else: 126 | # raise ValueError('the path ' + restore +' does not exist') 127 | print('the path ' + restore + ' does not exist') 128 | print('init model') 129 | 130 | if torch.cuda.is_available(): 131 | cudnn.benchmark = True 132 | net.cuda() 133 | 134 | return net 135 | -------------------------------------------------------------------------------- /common/utils/parameters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from omegaconf import OmegaConf 4 | 5 | 6 | def get_parameters(): 7 | """define the parameter for training 8 | 9 | Args: 10 | --config (string): the path of config files 11 | --distributed (int): train the model in the mode of DDP or Not, default: 1 12 | --local_rank (int): define the rank of this process 13 | --world_size (int): define the Number of GPU 14 | """ 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-c', '--config', type=str, default='configs/train.yaml') 17 | parser.add_argument('--distributed', type=int, default=1) 18 | parser.add_argument('--local_rank', type=int, default=0) 19 | parser.add_argument('--world_size', type=int, default=1) 20 | parser.add_argument('--sync-bn', action='store_true', default=False) 21 | parser.add_argument('--debug', action='store_true', default=False) 22 | args = parser.parse_args() 23 | 24 | _C = OmegaConf.load(args.config) 25 | _C.merge_with(vars(args)) 26 | 27 | if _C.debug: 28 | _C.train.epochs = 2 29 | 30 | return _C 31 | 32 | 33 | if __name__ == '__main__': 34 | args = get_parameters() 35 | print(args) 36 | -------------------------------------------------------------------------------- /common/visualizer/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /common/visualizer/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 
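A short pairing of tensor2im with the save_image helper defined below (editor's sketch; `fake_img` is a hypothetical [1, 3, H, W] tensor in [-1, 1], the range tensor2im rescales from, and the output path is illustrative):

        img_np = tensor2im(fake_img)                   # HxWx3 uint8 numpy array
        save_image(img_np, 'web/images/fake_img.png')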
11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 25 | else: # if it is a numpy array, do nothing 26 | image_numpy = input_image 27 | return image_numpy.astype(imtype) 28 | 29 | 30 | def diagnose_network(net, name='network'): 31 | """Calculate and print the mean of average absolute(gradients) 32 | 33 | Parameters: 34 | net (torch network) -- Torch network 35 | name (str) -- the name of the network 36 | """ 37 | mean = 0.0 38 | count = 0 39 | for param in net.parameters(): 40 | if param.grad is not None: 41 | mean += torch.mean(torch.abs(param.grad.data)) 42 | count += 1 43 | if count > 0: 44 | mean = mean / count 45 | print(name) 46 | print(mean) 47 | 48 | 49 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 50 | """Save a numpy image to the disk 51 | 52 | Parameters: 53 | image_numpy (numpy array) -- input numpy array 54 | image_path (str) -- the path of the image 55 | """ 56 | 57 | image_pil = Image.fromarray(image_numpy) 58 | h, w, _ = image_numpy.shape 59 | 60 | if aspect_ratio > 1.0: 61 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 62 | if aspect_ratio < 1.0: 63 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 64 | image_pil.save(image_path) 65 | 66 | 67 | def print_numpy(x, val=True, shp=False): 68 | """Print the mean, min, max, median, std, and size of a numpy array 69 | 70 | Parameters: 71 | val (bool) -- if print the values of the numpy array 72 | shp (bool) -- if print the shape of the numpy array 73 | """ 74 | x = x.astype(np.float64) 75 | if shp: 76 | print('shape,', x.shape) 77 | if val: 78 | x = x.flatten() 79 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 80 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 81 | 82 | 83 | def mkdirs(paths): 84 | """create empty directories if they don't exist 85 | 86 | Parameters: 87 | paths (str list) -- a list of directory paths 88 | """ 89 | if isinstance(paths, list) and not isinstance(paths, str): 90 | for path in paths: 91 | mkdir(path) 92 | else: 93 | mkdir(paths) 94 | 95 | 96 | def mkdir(path): 97 | """create a single empty directory if it didn't exist 98 | 99 | Parameters: 100 | path (str) -- a single directory path 101 | """ 102 | if not os.path.exists(path): 103 | os.makedirs(path) 104 | -------------------------------------------------------------------------------- /configs/DBBF.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | torch_home: 3 | 4 | method: FaceForensics_c23 5 | compression: c23 6 | checkpoints_dir: checkpoints 7 | name: ${model.label}_${model.name} 8 | exam_dir: ${checkpoints_dir}/${name} 9 | 10 | transform_params: 11 | image_size: 224 12 | mean: [0.485, 0.456, 0.406] 13 | std: [0.229, 0.224, 0.225] 14 | 15 | train: 16 | batch_size: 32 17 | num_workers: 16 18 | print_info_step_freq: 10 19 | save_model_interval: 10 20 | max_epoches: 10 21 | 
use_warmup: False 22 | warmup_epochs: 0 23 | last_epoch_max_acc: 0.0 24 | 25 | dataset: 26 | name: Image_dataset 27 | params: 28 | root: ../data/ 29 | method: ${method} 30 | split: train 31 | num_segments: 8 #8 32 | cutout: True #True 33 | is_sbi: False 34 | image_size: ${transform_params.image_size} 35 | 36 | test: 37 | batch_size: ${train.batch_size} 38 | num_workers: ${train.num_workers} 39 | dataset: 40 | name: Image_dataset 41 | params: 42 | root: ../data/ 43 | method: ${method} #FF-ALL 44 | split: test 45 | num_segments: ${train.dataset.params.num_segments} 46 | cutout: True 47 | is_sbi: False 48 | image_size: ${transform_params.image_size} 49 | 50 | final_test: 51 | batch_size: ${test.batch_size} 52 | num_workers: ${test.num_workers} 53 | dataset: 54 | name: Image_dataset_test 55 | params: 56 | root: ../data/ 57 | method: ALL 58 | split: test 59 | is_sbi: False 60 | num_segments: ${train.dataset.params.num_segments} 61 | image_size: ${transform_params.image_size} 62 | 63 | model: 64 | name: DDBF_BEiT_v2 65 | label: Final 66 | backbone: BEiT_v2 67 | params: 68 | pretrained_path: pretrained_weight/BEiT-1k-Face-55w.tar 69 | image_size: ${transform_params.image_size} 70 | feature_dim: 768 71 | resume: 72 | only_resume_model: False 73 | 74 | 75 | optimizer: 76 | type: N #SAM 77 | name: lamb 78 | params: 79 | lr: 5e-5 80 | opt: ${optimizer.name} 81 | weight_decay: 0.05 #1.0e-5 82 | momentum: 0.9 83 | clip_mode: norm 84 | layer_decay: .75 85 | 86 | loss: 87 | name: EvidenceLoss 88 | params: 89 | num_classes: 2 90 | evidence: exp 91 | loss_type: log 92 | with_kldiv: False 93 | with_avuloss: True 94 | annealing_method: exp 95 | loss2: 96 | name: CrossEntropyLoss 97 | params: 98 | 99 | scheduler: 100 | sched: cosine 101 | lr: ${optimizer.params.lr} 102 | lr_noise_pct: 0.67 103 | lr_noise_std: 1.0 104 | lr_cycle_mul: 1.0 105 | lr_cycle_decay: 0.5 106 | lr_cycle_limit: 1 107 | lr_k_decay: 1.0 108 | warmup_lr: 1e-6 109 | min_lr: 1e-5 # 1e-5 110 | epochs: ${train.max_epoches} 111 | warmup_epochs: 1 112 | cooldown_epochs: 0 #1 #5 -------------------------------------------------------------------------------- /configs/backbones/SL-ResNet-50.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | torch_home: 3 | 4 | method: FaceForensics_c23 5 | compression: c23 6 | checkpoints_dir: checkpoints 7 | name: ${model.label}_${model.name}_${model.params.network} 8 | exam_dir: ${checkpoints_dir}/${name} 9 | is_aligned: True 10 | align_type: com 11 | 12 | transform_params: 13 | image_size: 224 14 | mean: [0.485, 0.456, 0.406] 15 | std: [0.229, 0.224, 0.225] 16 | 17 | train: 18 | batch_size: 64 19 | num_workers: 16 20 | print_info_step_freq: 10 21 | save_model_interval: 10 22 | max_epoches: 60 23 | use_warmup: False 24 | warmup_epochs: 0 25 | last_epoch_max_acc: 0.0 26 | 27 | dataset: 28 | name: Image_dataset 29 | params: 30 | root: ../datasets_processed/ 31 | method: ${method} 32 | split: train 33 | num_segments: 8 #8 34 | cutout: True #True 35 | is_sbi: False 36 | image_size: ${transform_params.image_size} 37 | 38 | test: 39 | batch_size: ${train.batch_size} 40 | num_workers: ${train.num_workers} 41 | dataset: 42 | name: Image_dataset 43 | params: 44 | root: ../datasets_processed_8frames/ 45 | method: ${method} #FF-ALL 46 | split: test 47 | num_segments: ${train.dataset.params.num_segments} 48 | cutout: True 49 | is_sbi: False 50 | image_size: ${transform_params.image_size} 51 | 52 | final_test: 53 | batch_size: ${test.batch_size} 54 | num_workers: 
${test.num_workers} 55 | dataset: 56 | name: Image_dataset_test 57 | params: 58 | root: ../datasets_processed_8frames/ 59 | methods: ['ALL'] 60 | method: ALL 61 | split: test 62 | is_sbi: False 63 | num_segments: ${train.dataset.params.num_segments} 64 | image_size: ${transform_params.image_size} 65 | 66 | model: 67 | name: ResNet 68 | label: SL-1k 69 | backbone: ${model.name} 70 | params: 71 | network: resnet50 72 | pretrained_path: './pretrained_weight/resnet50_a1_0-14fe96d1.pth' 73 | image_size: ${transform_params.image_size} 74 | feature_dim: 2048 75 | resume: 76 | only_resume_model: False 77 | 78 | 79 | optimizer: 80 | type: N #SAM 81 | name: lamb #AdamW #Adam # SAM #Adam 82 | params: 83 | lr: 1e-4 84 | opt: ${optimizer.name} 85 | weight_decay: 0.05 #1.0e-5 86 | momentum: 0.9 87 | clip_mode: norm 88 | layer_decay: .75 89 | 90 | loss: 91 | name: CrossEntropyLoss 92 | params: 93 | 94 | scheduler: 95 | sched: cosine 96 | lr: ${optimizer.params.lr} 97 | lr_noise_pct: 0.67 98 | lr_noise_std: 1.0 99 | lr_cycle_mul: 1.0 100 | lr_cycle_decay: 0.5 101 | lr_cycle_limit: 1 102 | lr_k_decay: 1.0 103 | warmup_lr: 1e-6 104 | min_lr: 1e-5 105 | epochs: ${train.max_epoches} 106 | warmup_epochs: 1 107 | cooldown_epochs: 5 #1 #5 108 | -------------------------------------------------------------------------------- /datasets/Image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data 5 | import random 6 | from collections import OrderedDict 7 | from common.utils import map_util 8 | from PIL import Image 9 | from datasets.utils import dataloader_util 10 | 11 | class Image_dataset(data.Dataset): 12 | def __init__(self, 13 | root, 14 | method='FaceForensics_c23', 15 | split='train', 16 | num_segments=8, 17 | transform=None, 18 | cutout=False, 19 | is_sbi=False, 20 | image_size=224): 21 | super().__init__() 22 | self.root = root 23 | self.dataset_info = [] 24 | self.method = method 25 | self.split = split 26 | self.num_segments = num_segments 27 | self.transform = transform 28 | self.image_size = image_size 29 | self.is_cutout = cutout 30 | self.is_sbi = is_sbi 31 | self.cutout = map_util.Cutout() 32 | self.parse_dataset_info() 33 | 34 | def get_sampled_idx(self, file_path): 35 | frame_path = os.path.join(file_path, 'frame') 36 | file_ext = '.png' 37 | file_names = [f for f in os.listdir(frame_path) if f.endswith(file_ext)] 38 | all_frame_idxs = np.array([int(f[:-len(file_ext)]) for f in file_names]) 39 | all_frame_idxs.sort() 40 | if len(all_frame_idxs) >= self.num_segments: 41 | step = len(all_frame_idxs) // self.num_segments 42 | sampled_frame_idxs = all_frame_idxs[::step][:self.num_segments] 43 | else: 44 | sampled_frame_idxs = [] 45 | idxs = dataloader_util.check_frame_len(len(all_frame_idxs), self.num_segments) 46 | sampled_frame_idxs = all_frame_idxs[idxs] 47 | sampled_frame_idxs.sort() 48 | return sampled_frame_idxs 49 | 50 | def parse_dataset_info(self): 51 | """Parse the video dataset information""" 52 | dataset_list, data_root = dataloader_util.get_data_list(self.method, self.root, self.split, only_real=False) 53 | print('Number of videos loaded:', len(dataset_list), '\nNumber of frames per video:', self.num_segments) 54 | self.all_list = [] 55 | error = 0 56 | for _, file_info in enumerate(dataset_list): # 3600 57 | file_path, video_label = file_info[0], file_info[1] 58 | file_path = data_root + file_path 59 | if os.path.isdir(file_path): 60 | sampled_frame_idx = 
self.get_sampled_idx(file_path) 61 | for _, frame_idx in enumerate(sampled_frame_idx): 62 | filename_frame = os.path.join(file_path, 'frame', f"{frame_idx}.png") 63 | video_labels = torch.tensor(video_label) 64 | self.all_list.append((filename_frame, video_labels)) 65 | else: 66 | self.all_list.append((file_path+'/0/0.png', video_label, None)) 67 | error = error+1 68 | print('Successfully loaded frames:', len(self.all_list)) 69 | print('Failed to load frames', str(error)) 70 | random.shuffle(self.all_list) 71 | 72 | def __getitem__(self, index): 73 | flag=True 74 | while flag: 75 | file_info = self.all_list[index] 76 | video_labels = file_info[1] 77 | all_frames = [] 78 | sampled_frame_idxs = [] 79 | filename_frame = file_info[0] 80 | frame = np.asarray(Image.open(filename_frame)) 81 | if self.transform is not None: 82 | tmp_imgs = {"image": frame} 83 | all_frames = self.transform(**tmp_imgs) 84 | all_frames = OrderedDict(sorted(all_frames.items(), key=lambda x: x[0])) 85 | all_frames = list(all_frames.values()) 86 | if self.is_cutout: 87 | all_frames = torch.stack(all_frames) 88 | process_imgs = self.cutout(all_frames) 89 | else: 90 | process_imgs = torch.stack(all_frames) 91 | for i in range(self.num_segments): 92 | sampled_frame_idxs.append(file_info[0]) 93 | flag=False 94 | return {"images":process_imgs, "labels":video_labels, "video_path":filename_frame, "sampled_frame_idxs":sampled_frame_idxs} 95 | def __len__(self): 96 | return len(self.all_list) 97 | -------------------------------------------------------------------------------- /datasets/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/datasets/common/__init__.py -------------------------------------------------------------------------------- /datasets/common/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the following main classes/functions: 3 | - parameters: (deprecated) 4 | - cli_utils: 5 | client parameters utils 6 | - logger_utils: 7 | logger utils 8 | - distribute_utils: 9 | tensor reduce and gather utils in distributed training 10 | - face_utils: 11 | face crop functions 12 | - misc: 13 | training misc utils 14 | - meters: 15 | training meters 16 | - metrics: 17 | calculate metrics 18 | - model_init: 19 | model weight initialization functions 20 | """ 21 | from .parameters import get_parameters 22 | from .cli_utils import get_params 23 | from .logger_utils import get_logger 24 | from .distribute_utils import reduce_tensor, gather_tensor 25 | from .face_utils import add_face_margin, get_face_box 26 | from .misc import set_seed, setup, init_exam_dir, init_wandb_workspace, save_test_results 27 | from .meters import AverageMeter, ProgressMeter 28 | from .metrics import find_best_threshold, cal_metrics 29 | from .model_init import * 30 | -------------------------------------------------------------------------------- /datasets/factory.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from omegaconf import OmegaConf 4 | import torch.utils.data as data 5 | 6 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..', '..', '..')) 7 | from common.data import create_base_transforms, create_base_dataloader,create_base_sbi_transforms 8 | 9 | from .Image_dataset import * 10 | from .Image_dataset_test import * 11 
| 12 | from torch.utils.data import DataLoader 13 | from prefetch_generator import BackgroundGenerator 14 | 15 | class DataLoaderX(DataLoader): 16 | 17 | def __iter__(self): 18 | return BackgroundGenerator(super().__iter__()) 19 | 20 | def get_dataloader(args, split): 21 | """Set dataloader. 22 | 23 | Args: 24 | args (object): Args load from get_params function. 25 | split (str): One of ['train', 'test'] 26 | """ 27 | transform = create_base_transforms(args.transform_params, split=split) 28 | dataset_cfg = getattr(args, split).dataset 29 | dataset_params = OmegaConf.to_container(dataset_cfg.params, resolve=True) 30 | dataset_params['transform'] = transform 31 | _dataset = eval(dataset_cfg.name)(**dataset_params) 32 | _dataloader = create_base_dataloader(args, _dataset, split=split) 33 | return _dataloader 34 | 35 | 36 | 37 | def get_final_dataloader(args, split): 38 | """Set dataloader. 39 | 40 | Args: 41 | args (object): Args load from get_params function. 42 | split (str): One of ['train', 'test'] 43 | """ 44 | transform = create_base_transforms(args.transform_params, split=split) 45 | 46 | dataset_cfg = getattr(args, 'final_test').dataset 47 | print('dataset:',dataset_cfg.params) 48 | dataset_params = OmegaConf.to_container(dataset_cfg.params, resolve=True) 49 | dataset_params['transform'] = transform 50 | _dataset = eval(dataset_cfg.name)(**dataset_params) 51 | _dataloader = create_base_dataloader(args, _dataset, split=split) 52 | return _dataloader 53 | 54 | def get_final_image_dataloader(args, split): 55 | """Set dataloader. 56 | 57 | Args: 58 | args (object): Args load from get_params function. 59 | split (str): One of ['train', 'test'] 60 | """ 61 | transform = create_base_sbi_transforms(args.transform_params, split=split) 62 | 63 | dataset_cfg = getattr(args, 'final_test').dataset 64 | dataset_params = OmegaConf.to_container(dataset_cfg.params, resolve=True) 65 | dataset_params['transform'] = transform 66 | 67 | _dataset = eval(dataset_cfg.name)(**dataset_params) 68 | 69 | _dataloader = create_base_dataloader(args, _dataset, split=split) 70 | 71 | return _dataloader 72 | 73 | def get_sbi_dataloader(args, split): 74 | """Set dataloader. 75 | 76 | Args: 77 | args (object): Args load from get_params function. 
78 | split (str): One of ['train', 'test'] 79 | """ 80 | transform = create_base_transforms(args.transform_params, split=split) 81 | dataset_cfg = getattr(args, split).dataset 82 | dataset_params = OmegaConf.to_container(dataset_cfg.params, resolve=True) 83 | dataset_params['transform'] = transform 84 | _dataset = eval(dataset_cfg.name)(**dataset_params) 85 | _dataloader = create_sbi_dataloader(args, _dataset, split=split) 86 | return _dataloader 87 | 88 | def create_sbi_dataloader(args, dataset, split): 89 | """Base data loader 90 | 91 | Args: 92 | args: Dataset config args 93 | split (string): Load "train", "val" or "test" 94 | 95 | Returns: 96 | [dataloader]: Corresponding Dataloader 97 | """ 98 | sampler = None 99 | if args.distributed: 100 | sampler = data.distributed.DistributedSampler(dataset) 101 | shuffle = True if sampler is None and split == 'train' else False 102 | batch_size = getattr(args, split).batch_size 103 | num_workers = args.num_workers if 'num_workers' in args else 8 104 | drop_last = False if split == 'test' else True 105 | dataloader = DataLoaderX(dataset, 106 | batch_size=batch_size//2, 107 | shuffle=shuffle, 108 | sampler=sampler, 109 | num_workers=num_workers, 110 | pin_memory=True, 111 | drop_last=drop_last, 112 | collate_fn=dataset.collate_fn,) 113 | return dataloader 114 | 115 | 116 | -------------------------------------------------------------------------------- /datasets/lib/blend.py: -------------------------------------------------------------------------------- 1 | # Created by: Kaede Shiohara 2 | # Yamasaki Lab at The University of Tokyo 3 | # shiohara@cvm.t.u-tokyo.ac.jp 4 | # Copyright (c) 2021 5 | # 3rd party softwares' licenses are noticed at https://github.com/mapooon/SelfBlendedImages/blob/master/LICENSE 6 | 7 | import cv2 8 | import numpy as np 9 | import scipy as sp 10 | from skimage.measure import label, regionprops 11 | import random 12 | from PIL import Image 13 | import sys 14 | 15 | 16 | 17 | def alpha_blend(source,target,mask): 18 | mask_blured = get_blend_mask(mask) 19 | img_blended=(mask_blured * source + (1 - mask_blured) * target) 20 | return img_blended,mask_blured 21 | 22 | def dynamic_blend(source,target,mask): 23 | mask_blured = get_blend_mask(mask) 24 | blend_list=[0.25,0.5,0.75,1,1,1] 25 | blend_ratio = blend_list[np.random.randint(len(blend_list))] 26 | mask_blured*=blend_ratio 27 | img_blended=(mask_blured * source + (1 - mask_blured) * target) 28 | return img_blended,mask_blured 29 | 30 | def get_blend_mask(mask): 31 | H,W=mask.shape 32 | size_h=np.random.randint(192,257) 33 | size_w=np.random.randint(192,257) 34 | mask=cv2.resize(mask,(size_w,size_h)) 35 | kernel_1=random.randrange(5,26,2) 36 | kernel_1=(kernel_1,kernel_1) 37 | kernel_2=random.randrange(5,26,2) 38 | kernel_2=(kernel_2,kernel_2) 39 | 40 | mask_blured = cv2.GaussianBlur(mask, kernel_1, 0) 41 | mask_blured = mask_blured/(mask_blured.max()) 42 | mask_blured[mask_blured<1]=0 43 | 44 | mask_blured = cv2.GaussianBlur(mask_blured, kernel_2, np.random.randint(5,46)) 45 | mask_blured = mask_blured/(mask_blured.max()) 46 | mask_blured = cv2.resize(mask_blured,(W,H)) 47 | return mask_blured.reshape((mask_blured.shape+(1,))) 48 | 49 | 50 | def get_alpha_blend_mask(mask): 51 | kernel_list=[(11,11),(9,9),(7,7),(5,5),(3,3)] 52 | blend_list=[0.25,0.5,0.75] 53 | kernel_idxs=random.choices(range(len(kernel_list)), k=2) 54 | blend_ratio = blend_list[random.sample(range(len(blend_list)), 1)[0]] 55 | mask_blured = cv2.GaussianBlur(mask, kernel_list[0], 0) 56 | # 
print(mask_blured.max()) 57 | mask_blured[mask_blured<1]=0 58 | mask_blured[mask_blured>0]=1 59 | # mask_blured = mask 60 | mask_blured = cv2.GaussianBlur(mask_blured, kernel_list[kernel_idxs[1]], 0) 61 | mask_blured = mask_blured/(mask_blured.max()) 62 | return mask_blured.reshape((mask_blured.shape+(1,))) 63 | 64 | -------------------------------------------------------------------------------- /datasets/lib/blend_sbv.py: -------------------------------------------------------------------------------- 1 | # Created by: Kaede Shiohara 2 | # Yamasaki Lab at The University of Tokyo 3 | # shiohara@cvm.t.u-tokyo.ac.jp 4 | # Copyright (c) 2021 5 | # 3rd party softwares' licenses are noticed at https://github.com/mapooon/SelfBlendedImages/blob/master/LICENSE 6 | 7 | import cv2 8 | import numpy as np 9 | import scipy as sp 10 | from skimage.measure import label, regionprops 11 | import random 12 | from PIL import Image 13 | import sys 14 | 15 | 16 | 17 | def alpha_blend(source,target,mask): 18 | mask_blured = get_blend_mask(mask) 19 | img_blended=(mask_blured * source + (1 - mask_blured) * target) 20 | return img_blended,mask_blured 21 | 22 | def dynamic_blend(source,target,mask): 23 | mask_blured = get_blend_mask(mask) 24 | blend_list=[0.25,0.5,0.75,1,1,1] 25 | blend_ratio = blend_list[np.random.randint(len(blend_list))] 26 | mask_blured*=blend_ratio 27 | img_blended=(mask_blured * source + (1 - mask_blured) * target) 28 | return img_blended,mask_blured 29 | 30 | def get_blend_mask(mask): 31 | H,W=mask.shape 32 | size_h=np.random.randint(192,257) 33 | size_w=np.random.randint(192,257) 34 | mask=cv2.resize(mask,(size_w,size_h)) 35 | kernel_1=random.randrange(5,26,2) 36 | kernel_1=(kernel_1,kernel_1) 37 | kernel_2=random.randrange(5,26,2) 38 | kernel_2=(kernel_2,kernel_2) 39 | 40 | mask_blured = cv2.GaussianBlur(mask, kernel_1, 0) 41 | mask_blured = mask_blured/(mask_blured.max()) 42 | mask_blured[mask_blured<1]=0 43 | 44 | mask_blured = cv2.GaussianBlur(mask_blured, kernel_2, np.random.randint(5,46)) 45 | mask_blured = mask_blured/(mask_blured.max()) 46 | mask_blured = cv2.resize(mask_blured,(W,H)) 47 | return mask_blured.reshape((mask_blured.shape+(1,))) 48 | 49 | 50 | def get_alpha_blend_mask(mask): 51 | kernel_list=[(11,11),(9,9),(7,7),(5,5),(3,3)] 52 | blend_list=[0.25,0.5,0.75] 53 | kernel_idxs=random.choices(range(len(kernel_list)), k=2) 54 | blend_ratio = blend_list[random.sample(range(len(blend_list)), 1)[0]] 55 | mask_blured = cv2.GaussianBlur(mask, kernel_list[0], 0) 56 | # print(mask_blured.max()) 57 | mask_blured[mask_blured<1]=0 58 | mask_blured[mask_blured>0]=1 59 | # mask_blured = mask 60 | mask_blured = cv2.GaussianBlur(mask_blured, kernel_list[kernel_idxs[1]], 0) 61 | mask_blured = mask_blured/(mask_blured.max()) 62 | return mask_blured.reshape((mask_blured.shape+(1,))) 63 | 64 | -------------------------------------------------------------------------------- /datasets/lib/ct/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from .detector import RetinaFace 3 | from .utils import * 4 | 5 | 6 | def assert_bounded(val, low, up): 7 | return val >= low and val < up 8 | 9 | 10 | def check_valid(face, w, h): 11 | box = face[0] 12 | if box[0] > box[2]: 13 | return False 14 | if box[1] > box[3]: 15 | return False 16 | for idx, bound in zip([0, 1, 2, 3], [w, h, w, h]): 17 | if not assert_bounded(box[idx], 0, bound): 18 | return False 19 | pts = face[1] 20 | for p in pts: 21 | for idx, bound in zip([0, 1], [w, h]): 22 | if not assert_bounded(p[idx], 0, bound): 23 | return False 24 | 
return True 25 | 26 | 27 | def post_detect(detect_results, scale, w, h): 28 | new_results = [] 29 | for frame_faces in detect_results: 30 | new_frame_faces = [] 31 | for box, ldm, score in frame_faces: 32 | box = box * scale 33 | ldm = ldm * scale 34 | face = (box, ldm, score) 35 | if check_valid(face, w=w, h=h): 36 | new_frame_faces.append(face) 37 | new_results.append(new_frame_faces) 38 | return new_results 39 | 40 | 41 | class FaceDetector(RetinaFace): 42 | def scale_detect(self, images): 43 | max_res = 1920 44 | h, w = images[0].shape[:2] 45 | if max(h, w) > max_res: 46 | init_scale = max(h, w) / max_res 47 | else: 48 | init_scale = 1 49 | resize_scale = 2 * init_scale 50 | resize_w = int(w / resize_scale) 51 | resize_h = int(h / resize_scale) 52 | detect_input = [cv2.resize(frame, (resize_w, resize_h)) for frame in images] 53 | detect_results = post_detect( 54 | self.detect(detect_input), scale=resize_scale, w=w, h=h, 55 | ) 56 | return detect_results 57 | -------------------------------------------------------------------------------- /datasets/lib/ct/detection/detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .alignment import load_net, batch_detect 7 | 8 | 9 | def get_project_dir(): 10 | current_path = os.path.abspath(os.path.join(__file__, "../")) 11 | return current_path 12 | 13 | 14 | def relative(path): 15 | path = os.path.join(get_project_dir(), path) 16 | return os.path.abspath(path) 17 | 18 | 19 | class RetinaFace: 20 | def __init__( 21 | self, gpu_id=-1, model_path=None, network="mobilenet", 22 | ): 23 | self.gpu_id = gpu_id 24 | self.device = ( 25 | torch.device("cpu") if gpu_id == -1 else torch.device("cuda", gpu_id) 26 | ) 27 | self.model = load_net(model_path, self.device, network) 28 | 29 | def detect(self, images): 30 | if isinstance(images, np.ndarray): 31 | if len(images.shape) == 3: 32 | return batch_detect(self.model, [images], self.device)[0] 33 | elif len(images.shape) == 4: 34 | return batch_detect(self.model, images, self.device) 35 | elif isinstance(images, list): 36 | return batch_detect(self.model, np.array(images), self.device) 37 | elif isinstance(images, torch.Tensor): 38 | if len(images.shape) == 3: 39 | return batch_detect(self.model, images.unsqueeze(0), self.device)[0] 40 | elif len(images.shape) == 4: 41 | return batch_detect(self.model, images, self.device) 42 | else: 43 | raise NotImplementedError() 44 | 45 | def __call__(self, images): 46 | return self.detect(images) 47 | -------------------------------------------------------------------------------- /datasets/lib/ct/detection/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from test_tools.utils import flatten 3 | import numpy as np 4 | 5 | 6 | def chunks(l, n, step=None): 7 | if step is None: 8 | step = n 9 | return [l[i : i + n] for i in range(0, len(l), step)] 10 | 11 | 12 | def sample_chunks(l, n, step=None): 13 | return [l[i : i + n] for i in range(0, len(l), step) if i + n <= len(l)] 14 | 15 | 16 | def grab_all_frames(path, max_size, cvt=False): 17 | capture = cv2.VideoCapture(path) 18 | ret = True 19 | frames = [] 20 | while ret: 21 | ret, frame = capture.read() 22 | if ret: 23 | if cvt: 24 | frame = frame[..., ::-1] 25 | frames.append(frame) 26 | if len(frames) == max_size: 27 | break 28 | capture.release() 29 | return frames 30 | 31 | 32 | def get_clips_uniform(path, count, clip_size): 33 | capture = 
cv2.VideoCapture(path) 34 | n_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) 35 | max_clip_available = n_frames + 1 - clip_size 36 | if count > max_clip_available: 37 | count = max_clip_available 38 | final_start = max_clip_available - 1 39 | start_indices = np.linspace(0, final_start, count, endpoint=True, dtype=np.int) 40 | all_clip_idx = [list(range(start, start + clip_size)) for start in start_indices] 41 | valid = set(flatten(all_clip_idx)) 42 | max_idx = max(valid) 43 | 44 | frames = {} 45 | for idx in range(max_idx + 1): 46 | # Get the next frame, but don't decode if we're not using it. 47 | ret = capture.grab() 48 | if not ret: 49 | continue 50 | 51 | if idx in valid: 52 | ret, frame = capture.retrieve() 53 | if not ret or frame is None: 54 | continue 55 | else: 56 | # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 57 | frames[idx] = frame 58 | 59 | capture.release() 60 | clips = [] 61 | for clip_idx in all_clip_idx: 62 | clip = [] 63 | flag = True 64 | for idx in clip_idx: 65 | if idx not in frames: 66 | flag = False 67 | break 68 | clip.append(frames[idx]) 69 | if flag: 70 | clips.append(clip) 71 | return clips 72 | 73 | 74 | def get_valid_faces(detect_results, max_count=10, thres=0.5, at_least=False): 75 | new_results = [] 76 | for i, faces in enumerate(detect_results): 77 | if len(faces) > max_count: 78 | faces = faces[:max_count] 79 | l = [] 80 | for j, face in enumerate(faces): 81 | if face[-1] < thres and not (j == 0 and at_least): 82 | continue 83 | box, lm, score = face 84 | box = box.astype(np.float) 85 | lm = lm.astype(np.float) 86 | l.append((box, lm, score)) 87 | new_results.append(l) 88 | return new_results 89 | 90 | 91 | def scale_box(box, scale_h, scale_w, h, w): 92 | x1, y1, x2, y2 = box.astype(np.int32) 93 | center_x = (x1 + x2) // 2 94 | center_y = (y1 + y2) // 2 95 | box_h = int((y2 - y1) * scale_h) 96 | box_w = int((x2 - x1) * scale_w) 97 | new_x1 = center_x - box_w // 2 98 | new_x2 = new_x1 + box_w 99 | new_y1 = center_y - box_h // 2 100 | new_y2 = new_y1 + box_h 101 | new_x1 = max(new_x1, 0) 102 | new_y1 = max(new_y1, 0) 103 | new_y2 = min(new_y2, h) 104 | new_x2 = min(new_x2, w) 105 | return new_x1, new_y1, new_x2, new_y2 106 | 107 | 108 | def get_bbox(detect_res): 109 | tmp_detect_res = get_valid_faces(detect_res, max_count=4, thres=0.5) 110 | all_face_bboxs = [] 111 | for faces in tmp_detect_res: 112 | all_face_bboxs.extend([face[0] for face in faces]) 113 | all_face_bboxs = np.array(all_face_bboxs).astype(np.int) 114 | x1 = all_face_bboxs[:, 0].min() 115 | x2 = all_face_bboxs[:, 2].max() 116 | y1 = all_face_bboxs[:, 1].min() 117 | y2 = all_face_bboxs[:, 3].max() 118 | 119 | return x1, y1, x2, y2 120 | 121 | 122 | def delta_detect_res(detect_res, x1, y1): 123 | diff = np.array([[x1, y1]]) 124 | new_detect_res = [] 125 | for faces in detect_res: 126 | f = [] 127 | for face in faces: 128 | box, lm, score = face 129 | box = box.astype(np.float) 130 | box[[0, 2]] -= x1 131 | box[[1, 3]] -= y1 132 | lm = lm.astype(np.float) - diff 133 | f.append((box, lm, score)) 134 | new_detect_res.append(f) 135 | return new_detect_res 136 | 137 | 138 | def pre_crop(clips, detect_res): 139 | box = np.array(get_bbox(detect_res)) 140 | w = box[2] - box[0] 141 | h = box[3] - box[1] 142 | x1, y1, x2, y2 = scale_box( 143 | box, 1.5, 1.2 if w > 2 * h else 1.5, clips[0].shape[0], clips[0].shape[1] 144 | ) 145 | clips = np.array(clips) 146 | return clips[:, y1:y2, x1:x2], delta_detect_res(detect_res, x1, y1) 147 | 
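The helpers in this file, together with FaceDetector from detection/__init__.py, form the usual per-video preprocessing chain: decode frames, run detection at a reduced resolution, keep only confident detections, and crop every frame to a shared face region. The sketch below illustrates that chain; the video path, GPU id, and threshold values are illustrative assumptions (and it presumes the module's own imports resolve in your environment), not code taken from this repository.

# Hypothetical usage sketch for the detection helpers above (paths, GPU id and thresholds assumed)
from datasets.lib.ct.detection import FaceDetector
from datasets.lib.ct.detection.utils import grab_all_frames, get_valid_faces, pre_crop

detector = FaceDetector(gpu_id=0)                                 # gpu_id=-1 would run on CPU
frames = grab_all_frames("example.mp4", max_size=32, cvt=True)    # decode up to 32 frames, BGR -> RGB
detect_res = detector.scale_detect(frames)                        # downscales very large frames before detecting
detect_res = get_valid_faces(detect_res, max_count=4, thres=0.5)  # drop low-confidence boxes
clips, detect_res = pre_crop(frames, detect_res)                  # crop all frames to one shared face region

pre_crop also shifts the surviving boxes and landmarks into the cropped coordinate frame, so downstream alignment (e.g. FasterCropAlignXRay) sees coordinates that match the cropped images.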
-------------------------------------------------------------------------------- /datasets/lib/ct/face_alignment/__init__.py: -------------------------------------------------------------------------------- 1 | from .predictor import LandmarkPredictor -------------------------------------------------------------------------------- /datasets/lib/ct/face_alignment/basenet.py: -------------------------------------------------------------------------------- 1 | # Backbone networks used for face landmark detection 2 | # Cunjian Chen (cunjian@msu.edu) 3 | 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False): 10 | super(ConvBlock, self).__init__() 11 | self.linear = linear 12 | if dw: 13 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False) 14 | else: 15 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False) 16 | self.bn = nn.BatchNorm2d(oup) 17 | if not linear: 18 | self.prelu = nn.PReLU(oup) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | x = self.bn(x) 23 | if self.linear: 24 | return x 25 | else: 26 | return self.prelu(x) 27 | 28 | 29 | # SE module 30 | # https://github.com/wujiyang/Face_Pytorch/blob/master/backbone/cbam.py 31 | class SEModule(nn.Module): 32 | """Squeeze and Excitation Module""" 33 | 34 | def __init__(self, channels, reduction): 35 | super(SEModule, self).__init__() 36 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 37 | self.fc1 = nn.Conv2d( 38 | channels, channels // reduction, kernel_size=1, padding=0, bias=False 39 | ) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.fc2 = nn.Conv2d( 42 | channels // reduction, channels, kernel_size=1, padding=0, bias=False 43 | ) 44 | self.sigmoid = nn.Sigmoid() 45 | 46 | def forward(self, x): 47 | input = x 48 | x = self.avg_pool(x) 49 | x = self.fc1(x) 50 | x = self.relu(x) 51 | x = self.fc2(x) 52 | x = self.sigmoid(x) 53 | 54 | return input * x 55 | 56 | 57 | # USE global depthwise convolution layer. Compatible with MobileNetV2 (224×224), MobileNetV2_ExternalData (224×224) 58 | class MobileNet_GDConv(nn.Module): 59 | def __init__(self, num_classes): 60 | super(MobileNet_GDConv, self).__init__() 61 | self.pretrain_net = models.mobilenet_v2(pretrained=False) 62 | self.base_net = nn.Sequential(*list(self.pretrain_net.children())[:-1]) 63 | self.linear7 = ConvBlock(1280, 1280, (7, 7), 1, 0, dw=True, linear=True) 64 | self.linear1 = ConvBlock(1280, num_classes, 1, 1, 0, linear=True) 65 | 66 | def forward(self, x): 67 | x = self.base_net(x) 68 | x = self.linear7(x) 69 | x = self.linear1(x) 70 | x = x.view(x.size(0), -1) 71 | return x 72 | 73 | 74 | # USE global depthwise convolution layer. 
Compatible with MobileNetV2 (56×56) 75 | class MobileNet_GDConv_56(nn.Module): 76 | def __init__(self, num_classes): 77 | super(MobileNet_GDConv_56, self).__init__() 78 | self.pretrain_net = models.mobilenet_v2(pretrained=False) 79 | self.base_net = nn.Sequential(*list(self.pretrain_net.children())[:-1]) 80 | self.linear7 = ConvBlock(1280, 1280, (2, 2), 1, 0, dw=True, linear=True) 81 | self.linear1 = ConvBlock(1280, num_classes, 1, 1, 0, linear=True) 82 | 83 | def forward(self, x): 84 | x = self.base_net(x) 85 | x = self.linear7(x) 86 | x = self.linear1(x) 87 | x = x.view(x.size(0), -1) 88 | return x 89 | 90 | 91 | # MobileNetV2 with SE; Compatible with MobileNetV2_SE (224×224) and MobileNetV2_SE_RE (224×224) 92 | class MobileNet_GDConv_SE(nn.Module): 93 | def __init__(self, num_classes): 94 | super(MobileNet_GDConv_SE, self).__init__() 95 | self.pretrain_net = models.mobilenet_v2(pretrained=True) 96 | self.base_net = nn.Sequential(*list(self.pretrain_net.children())[:-1]) 97 | self.linear7 = ConvBlock(1280, 1280, (7, 7), 1, 0, dw=True, linear=True) 98 | self.linear1 = ConvBlock(1280, num_classes, 1, 1, 0, linear=True) 99 | self.attention = SEModule(1280, 8) 100 | 101 | def forward(self, x): 102 | x = self.base_net(x) 103 | x = self.attention(x) 104 | x = self.linear7(x) 105 | x = self.linear1(x) 106 | x = x.view(x.size(0), -1) 107 | return x 108 | -------------------------------------------------------------------------------- /datasets/lib/ct/face_alignment/predictor.py: -------------------------------------------------------------------------------- 1 | # Face alignment demo 2 | # Uses MTCNN as face detector 3 | # Cunjian Chen (ccunjian@gmail.com) 4 | import torch 5 | import cv2 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | from .basenet import MobileNet_GDConv 9 | 10 | 11 | def get_device(gpu_id): 12 | if gpu_id > -1: 13 | return torch.device(f"cuda:{str(gpu_id)}") 14 | else: 15 | return torch.device("cpu") 16 | 17 | 18 | def load_model(file): 19 | model = MobileNet_GDConv(136) 20 | if file is not None: 21 | model.load_state_dict(torch.load(file, map_location="cpu")) 22 | else: 23 | url = "https://github.com/yinglinzheng/face_weights/releases/download/v1/mobilenet_224_model_best_gdconv_external.pth" 24 | model.load_state_dict(torch.utils.model_zoo.load_url(url)) 25 | return model 26 | 27 | 28 | # landmark of (5L, 2L) from [0,1] to real range 29 | def reproject(bbox, landmark): 30 | landmark_ = landmark.clone() 31 | x1, y1, x2, y2 = bbox 32 | w = x2 - x1 33 | h = y2 - y1 34 | landmark_[:, 0] *= w 35 | landmark_[:, 0] += x1 36 | landmark_[:, 1] *= h 37 | landmark_[:, 1] += y1 38 | return landmark_ 39 | 40 | 41 | def prepare_feed(img, face): 42 | height, width, _ = img.shape 43 | mean = np.asarray([0.485, 0.456, 0.406]) 44 | std = np.asarray([0.229, 0.224, 0.225]) 45 | out_size = 224 46 | x1, y1, x2, y2 = face[:4] 47 | 48 | w = x2 - x1 + 1 49 | h = y2 - y1 + 1 50 | size = int(min([w, h]) * 1.2) 51 | cx = x1 + w // 2 52 | cy = y1 + h // 2 53 | x1 = cx - size // 2 54 | x2 = x1 + size 55 | y1 = cy - size // 2 56 | y2 = y1 + size 57 | 58 | dx = max(0, -x1) 59 | dy = max(0, -y1) 60 | x1 = max(0, x1) 61 | y1 = max(0, y1) 62 | 63 | edx = max(0, x2 - width) 64 | edy = max(0, y2 - height) 65 | x2 = min(width, x2) 66 | y2 = min(height, y2) 67 | new_bbox = torch.Tensor([x1, y1, x2, y2]).int() 68 | x1, y1, x2, y2 = new_bbox 69 | cropped = img[y1:y2, x1:x2] 70 | if dx > 0 or dy > 0 or edx > 0 or edy > 0: 71 | cropped = cv2.copyMakeBorder( 72 | cropped, int(dy), int(edy), 
int(dx), int(edx), cv2.BORDER_CONSTANT, 0 73 | ) 74 | cropped_face = cv2.resize(cropped, (out_size, out_size)) 75 | 76 | if cropped_face.shape[0] <= 0 or cropped_face.shape[1] <= 0: 77 | return None 78 | test_face = cropped_face.copy() 79 | test_face = test_face / 255.0 80 | test_face = (test_face - mean) / std 81 | test_face = test_face.transpose((2, 0, 1)) 82 | data = torch.from_numpy(test_face).float() 83 | return dict(data=data, bbox=new_bbox) 84 | 85 | 86 | @torch.no_grad() 87 | def single_predict(model, feed, device): 88 | landmark = model(feed["data"].unsqueeze(0).to(device)).cpu() 89 | landmark = landmark.reshape(-1, 2) 90 | landmark = reproject(feed["bbox"], landmark) 91 | return landmark.numpy() 92 | 93 | 94 | @torch.no_grad() 95 | def batch_predict(model, feeds, device): 96 | if not isinstance(feeds, list): 97 | feeds = [feeds] 98 | # loader = DataLoader(FeedDataset(feeds), batch_size=50, shuffle=False) 99 | data = [] 100 | for feed in feeds: 101 | data.append(feed["data"].unsqueeze(0)) 102 | data = torch.cat(data, 0).to(device) 103 | results = [] 104 | 105 | landmarks = model(data).cpu() 106 | for landmark, feed in zip(landmarks, feeds): 107 | landmark = landmark.reshape(-1, 2) 108 | landmark = reproject(feed["bbox"], landmark) 109 | results.append(landmark.numpy()) 110 | return results 111 | 112 | 113 | @torch.no_grad() 114 | def batch_predict2(model, feeds, device, batch_size=None): 115 | if not isinstance(feeds, list): 116 | feeds = [feeds] 117 | if batch_size is None: 118 | batch_size = len(feeds) 119 | loader = DataLoader(feeds, batch_size=len(feeds), shuffle=False) 120 | results = [] 121 | for feed in loader: 122 | landmarks = model(feed["data"].to(device)).cpu() 123 | for landmark, bbox in zip(landmarks, feed["bbox"]): 124 | landmark = landmark.reshape(-1, 2) 125 | landmark = reproject(bbox, landmark) 126 | results.append(landmark.numpy()) 127 | return results 128 | 129 | 130 | class LandmarkPredictor: 131 | def __init__(self, gpu_id=0, file=None): 132 | self.device = get_device(gpu_id) 133 | self.model = load_model(file).to(self.device).eval() 134 | 135 | def __call__(self, feeds): 136 | results = batch_predict2(self.model, feeds, self.device) 137 | if not isinstance(feeds, list): 138 | results = results[0] 139 | return results 140 | 141 | @staticmethod 142 | def prepare_feed(img, face): 143 | return prepare_feed(img, face) 144 | -------------------------------------------------------------------------------- /datasets/lib/ct/face_alignment/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def drawLandmark_multiple(img, bbox, landmark): 5 | """ 6 | Input: 7 | - img: gray or RGB 8 | - bbox: type of BBox 9 | - landmark: reproject landmark of (5L, 2L) 10 | Output: 11 | - img marked with landmark and bbox 12 | """ 13 | x1, y1, x2, y2 = bbox 14 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) 15 | for x, y in landmark: 16 | cv2.circle(img, (int(x), int(y)), 2, (0, 255, 0), -1) 17 | return img 18 | -------------------------------------------------------------------------------- /datasets/lib/ct/faster_crop_align_xray.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from .warp_for_xray import ( 4 | estimiate_batch_transform, 5 | transform_landmarks, 6 | std_points_256, 7 | ) 8 | import numpy as np 9 | 10 | 11 | class FasterCropAlignXRay: 12 | """ 13 | 修正到统一坐标系,统一图像大小到标准尺寸 14 | """ 15 | 16 | def __init__(self, size=256): 17 | 
self.image_size = size 18 | self.std_points = std_points_256 * size / 256.0 19 | 20 | def __call__(self, landmarks68, five_landmarks, ori_boxes, images=None, jitter=False): 21 | # landmarks = [landmark[:4] for landmark in landmarks] 22 | # ori_boxes = np.array([ori_box for _, _, _, ori_box in landmarks]) 23 | # five_landmarks = np.array([ldm5 for _, ldm5, _, _ in landmarks]) 24 | # landmarks68 = np.array([ldm68 for _, _, ldm68, _ in landmarks]) 25 | # assert landmarks68.min() > 0 26 | 27 | 28 | 29 | left_top = ori_boxes[:, :2].min(0) 30 | 31 | right_bottom = ori_boxes[:, 2:].max(0) 32 | 33 | size = right_bottom - left_top 34 | 35 | w, h = size 36 | diff = ori_boxes[:, :2] - left_top[None, ...] 37 | 38 | new_five_landmarks = five_landmarks + diff[:, None, :] 39 | new_landmarks68 = landmarks68 + diff[:, None, :] 40 | 41 | landmark_for_estimiate = new_five_landmarks.copy() 42 | if jitter: 43 | landmark_for_estimiate += np.random.uniform( 44 | -4, 4, landmark_for_estimiate.shape 45 | ) 46 | 47 | tfm, trans = estimiate_batch_transform( 48 | landmark_for_estimiate, tgt_pts=self.std_points 49 | ) 50 | 51 | transformed_landmarks68 = np.array( 52 | [transform_landmarks(ldm68, trans) for ldm68 in new_landmarks68] 53 | ) 54 | 55 | transformed_landmarks5 = np.array( 56 | [transform_landmarks(ldm68, trans) for ldm68 in new_five_landmarks] 57 | ) 58 | 59 | if images is not None: 60 | transformed_images = [ 61 | self.process_sinlge(tfm, image, d, h, w) 62 | for image, d in zip(images, diff) 63 | ] # 拼接 func 的参数 64 | transformed_images = np.stack(transformed_images) 65 | return transformed_landmarks68,transformed_landmarks5, transformed_images 66 | else: 67 | return transformed_landmarks68,transformed_landmarks5 68 | 69 | def process_sinlge(self, tfm, image, d, h, w): 70 | assert isinstance(image, np.ndarray) 71 | new_image = np.zeros((h, w, 3), dtype=np.uint8) 72 | x, y = d 73 | ih, iw, _ = image.shape 74 | new_image[y : y + ih, x : x + iw] = image 75 | transformed_image = cv2.warpAffine( 76 | new_image, tfm, (self.image_size, self.image_size) 77 | ) 78 | return transformed_image 79 | -------------------------------------------------------------------------------- /datasets/lib/ct/operations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | from .tracking.sort import iou 7 | 8 | 9 | def face_iou(f1, f2): 10 | return iou(f1[0], f2[0]) 11 | 12 | 13 | def simple_tracking(batch_landmarks, index=0, thres=0.5): 14 | track = [] 15 | 16 | for i, faces in enumerate(batch_landmarks): 17 | if i == 0: 18 | if len(faces) <= index or faces[index][-1] < 0.8: 19 | return None 20 | if index != 0: 21 | for idx in range(index): 22 | if face_iou(faces[idx], faces[index]) > thres: 23 | return None 24 | track.append(faces[index]) 25 | else: 26 | last = track[i - 1] 27 | if len(faces) == 0: 28 | return None 29 | sorted_faces = sorted(faces, key=lambda x: face_iou(x, last), reverse=True) 30 | if face_iou(sorted_faces[0], last) < thres: 31 | return None 32 | track.append(sorted_faces[0]) 33 | return track 34 | 35 | 36 | def multiple_tracking(batch_landmarks): 37 | tracks = [] 38 | for i in range(len(batch_landmarks[0])): 39 | track = simple_tracking(batch_landmarks, index=i) 40 | if track is None: 41 | continue 42 | tracks.append(track) 43 | return tracks 44 | 45 | def find_longest(detect_res): 46 | fc = len(detect_res) 47 | tuples = [] 48 | start = 0 49 | end = 0 50 | previous_count = -1 51 | all_tracks = [] 52 | # start 
is inclusive, end is exclusive 53 | while start < (fc - 1): 54 | for end in range(start + 2, fc + 1): 55 | tracks = multiple_tracking(detect_res[start:end]) 56 | if (len(tracks) != previous_count and previous_count != -1) or len( 57 | tracks 58 | ) == 0: 59 | break 60 | previous_count = len(tracks) 61 | if end - start > 2: 62 | if end != fc: 63 | un_reach_end = end - 1 64 | else: 65 | un_reach_end = end 66 | sub_tracks = multiple_tracking(detect_res[start:un_reach_end]) 67 | if end == fc and len(sub_tracks) == 0: 68 | un_reach_end = end - 1 69 | sub_tracks = multiple_tracking(detect_res[start:un_reach_end]) 70 | if len(sub_tracks) > 0: 71 | tpl = (start, un_reach_end) 72 | tuples.append(tpl) 73 | all_tracks.append(sub_tracks[0]) 74 | else: 75 | raise NotImplementedError 76 | previous_count = -1 77 | end = un_reach_end 78 | start = end 79 | return tuples, all_tracks -------------------------------------------------------------------------------- /datasets/lib/ct/tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/datasets/lib/ct/tracking/__init__.py -------------------------------------------------------------------------------- /datasets/lib/ct/tracking/tracker.py: -------------------------------------------------------------------------------- 1 | from .sort import Sort 2 | import numpy as np 3 | 4 | 5 | def get_detections(faces): 6 | detections = [] 7 | for face in faces: 8 | x1, y1, x2, y2 = face[0] 9 | detections.append((x1, y1, x2, y2, face[-1])) 10 | return np.array(detections) 11 | 12 | 13 | def get_tracks(detect_results): 14 | tracks = {} 15 | mot_tracker = Sort() 16 | for faces in detect_results: 17 | detections = get_detections(faces) 18 | track_bbs_ids = mot_tracker.update(detections) 19 | for track in track_bbs_ids: # box out each detected face separately 20 | id = int(track[-1]) 21 | box = track[:4] 22 | if id in tracks: 23 | tracks[id].append(box) 24 | else: 25 | tracks[id] = [box] 26 | 27 | return [track for id, track in tracks.items() if len(track) == len(detect_results)] 28 | -------------------------------------------------------------------------------- /datasets/lib/ct/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def write_img(file, img): 5 | cv2.imwrite(file, img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 6 | def get_crop_box(shape, box, scale=0.5): 7 | height, width = shape 8 | box = np.rint(box).astype(np.int) 9 | new_box = box.reshape(2, 2) 10 | size = new_box[1] - new_box[0] 11 | diff = scale * size 12 | diff = diff[None, :] * np.array([-1, 1])[:, None] 13 | new_box = new_box + diff 14 | new_box[:, 0] = np.clip(new_box[:, 0], 0, width - 1) 15 | new_box[:, 1] = np.clip(new_box[:, 1], 0, height - 1) 16 | new_box = np.rint(new_box).astype(np.int) 17 | return new_box.reshape(-1) -------------------------------------------------------------------------------- /datasets/utils/blend.py: -------------------------------------------------------------------------------- 1 | # Created by: Kaede Shiohara 2 | # Yamasaki Lab at The University of Tokyo 3 | # shiohara@cvm.t.u-tokyo.ac.jp 4 | # Copyright (c) 2021 5 | # 3rd party softwares' licenses are noticed at https://github.com/mapooon/SelfBlendedImages/blob/master/LICENSE 6 | 7 | import cv2 8 | import numpy as np 9 | import scipy as sp 10 | from skimage.measure import label, regionprops 11 | import random 12 | from PIL import Image 13 | import sys 
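# (Added note) The helpers below come from the SelfBlendedImages augmentation credited above:
# dynamic_blend() composites a source face onto a target frame through a Gaussian-blurred mask
# whose peak opacity is sampled from {0.25, 0.5, 0.75, 1}, and get_blend_mask() randomizes the
# mask resolution and blur kernel sizes so the blending boundary varies from sample to sample.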
14 | 15 | 16 | 17 | def alpha_blend(source,target,mask): 18 | mask_blured = get_blend_mask(mask) 19 | img_blended=(mask_blured * source + (1 - mask_blured) * target) 20 | return img_blended,mask_blured 21 | 22 | def dynamic_blend(source,target,mask): 23 | mask_blured = get_blend_mask(mask) 24 | blend_list=[0.25,0.5,0.75,1,1,1] 25 | blend_ratio = blend_list[np.random.randint(len(blend_list))] 26 | mask_blured*=blend_ratio 27 | img_blended=(mask_blured * source + (1 - mask_blured) * target) 28 | return img_blended,mask_blured 29 | 30 | def get_blend_mask(mask): 31 | H,W=mask.shape 32 | size_h=np.random.randint(192,257) 33 | size_w=np.random.randint(192,257) 34 | mask=cv2.resize(mask,(size_w,size_h)) 35 | kernel_1=random.randrange(5,26,2) 36 | kernel_1=(kernel_1,kernel_1) 37 | kernel_2=random.randrange(5,26,2) 38 | kernel_2=(kernel_2,kernel_2) 39 | 40 | mask_blured = cv2.GaussianBlur(mask, kernel_1, 0) 41 | mask_blured = mask_blured/(mask_blured.max()) 42 | mask_blured[mask_blured<1]=0 43 | 44 | mask_blured = cv2.GaussianBlur(mask_blured, kernel_2, np.random.randint(5,46)) 45 | mask_blured = mask_blured/(mask_blured.max()) 46 | mask_blured = cv2.resize(mask_blured,(W,H)) 47 | return mask_blured.reshape((mask_blured.shape+(1,))) 48 | 49 | 50 | def get_alpha_blend_mask(mask): 51 | kernel_list=[(11,11),(9,9),(7,7),(5,5),(3,3)] 52 | blend_list=[0.25,0.5,0.75] 53 | kernel_idxs=random.choices(range(len(kernel_list)), k=2) 54 | blend_ratio = blend_list[random.sample(range(len(blend_list)), 1)[0]] 55 | mask_blured = cv2.GaussianBlur(mask, kernel_list[0], 0) 56 | # print(mask_blured.max()) 57 | mask_blured[mask_blured<1]=0 58 | mask_blured[mask_blured>0]=1 59 | # mask_blured = mask 60 | mask_blured = cv2.GaussianBlur(mask_blured, kernel_list[kernel_idxs[1]], 0) 61 | mask_blured = mask_blured/(mask_blured.max()) 62 | return mask_blured.reshape((mask_blured.shape+(1,))) 63 | 64 | -------------------------------------------------------------------------------- /datasets/utils/dataloader_util.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pylab import * 4 | import matplotlib.font_manager as fm # to create font 5 | import json 6 | import pandas as pd 7 | import math 8 | from glob import glob 9 | import random 10 | seed_value = 1234 11 | random.seed(seed_value) 12 | REAL_LABLE = 0 13 | FAKE_LABEL = 1 14 | 15 | dataset_root = { 16 | 'FaceForensics_c23_train': 'FF++/face_v1/', 17 | 'FaceForensics': 'FF_frames/face_v1/', 18 | 'Celeb-DF': 'Celeb-DF_frames/face_v1/', 19 | 'DFDC': 'DFDC_frames/face_v1/', 20 | 'FFIW':'../FFIW/',} 21 | 22 | def get_data_list(dataset_name, base_root, split, only_real=False): 23 | dataset_info = [] 24 | if "FaceForensics" in dataset_name: 25 | compress = dataset_name.split("_")[1] 26 | if split == 'train': 27 | root = os.path.join(base_root, dataset_root[dataset_name+"_train"]) 28 | else: 29 | root = os.path.join(base_root, dataset_root[dataset_name.split("_")[0]]) 30 | dataset_info = get_FF_list(root, split, compress=compress, only_real=only_real) 31 | elif dataset_name == 'Celeb-DF' and split=='test': 32 | root = os.path.join(base_root, dataset_root[dataset_name]) 33 | video_list_txt = os.path.join(root, 'List_of_testing_videos.txt') 34 | with open(video_list_txt) as f: 35 | for data in f: 36 | line=data.split() 37 | dataset_info.append((line[1][:-4],FAKE_LABEL-int(line[0]))) 38 | elif dataset_name == 'DFDC' and split=='test': 39 | root = os.path.join(base_root, dataset_root[dataset_name]) 40 | label=pd.read_csv(root+'labels.csv',delimiter=',') 41 | 
dataset_info = [(video_name[:-4], label) for video_name, label in zip(label['filename'].tolist(), label['label'].tolist())] 42 | root = root+'test_videos/' 43 | elif dataset_name == 'FFIW' and split=='test': 44 | root = '' 45 | real_root = os.path.join(dataset_root[dataset_name],'source') 46 | fake_root = os.path.join(dataset_root[dataset_name],'target') 47 | real_path = glob(real_root + '/*', recursive=True) 48 | fake_path = glob(fake_root + '/*', recursive=True) 49 | for i,path in enumerate(real_path): 50 | dataset_info.append((path, REAL_LABLE)) 51 | for i,path in enumerate(fake_path): 52 | dataset_info.append((path, FAKE_LABEL)) 53 | else: 54 | print('not support!', dataset_name) 55 | assert 0 56 | return dataset_info, root 57 | 58 | 59 | def get_FF_list(root, split, compress='c23', only_real=False): 60 | split_json_path = os.path.join(root, 'splits', f'{split}.json') 61 | json_data = json.load(open(split_json_path, 'r')) 62 | if only_real: 63 | real_names = [] 64 | for item in json_data: 65 | real_names.extend([item[0], item[1]]) 66 | real_video_dir = os.path.join('original_sequences', 'youtube', compress, 'videos') 67 | dataset_info = [[os.path.join(real_video_dir,x), REAL_LABLE] for x in real_names] 68 | else: 69 | real_names = [] 70 | fake_names = [] 71 | for item in json_data: 72 | real_names.extend([item[0], item[1]]) 73 | fake_names.extend([f'{item[0]}_{item[1]}', f'{item[1]}_{item[0]}']) 74 | real_video_dir = os.path.join('original_sequences', 'youtube', compress, 'videos') 75 | dataset_info = [[os.path.join(real_video_dir,x), 0] for x in real_names] 76 | ff_fake_types = ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures'] 77 | for method in ff_fake_types: 78 | fake_video_dir = os.path.join('manipulated_sequences', method, compress, 'videos') 79 | for x in fake_names: 80 | dataset_info.append((os.path.join(fake_video_dir,x),FAKE_LABEL)) 81 | return dataset_info 82 | 83 | 84 | def check_frame_len(video_len, num_segments): 85 | inner_index = list(range(video_len)) 86 | pad_length = math.ceil((num_segments-video_len)/2) 87 | post_module = inner_index[1:-1][::-1] + inner_index 88 | l_post = len(post_module) 89 | post_module = post_module * (pad_length // l_post + 1) 90 | post_module = post_module[:pad_length] 91 | assert len(post_module) == pad_length 92 | pre_module = inner_index + inner_index[1:-1][::-1] 93 | l_pre = len(post_module) 94 | pre_module = pre_module * (pad_length // l_pre + 1) 95 | pre_module = pre_module[-pad_length:] 96 | assert len(pre_module) == pad_length 97 | sampled_clip_idxs = pre_module + inner_index + post_module 98 | sampled_clip_idxs = sampled_clip_idxs[:num_segments] 99 | return sampled_clip_idxs 100 | -------------------------------------------------------------------------------- /datasets/utils/funcs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import numpy as np 4 | from PIL import Image 5 | from glob import glob 6 | import os 7 | import pandas as pd 8 | import albumentations as alb 9 | import cv2 10 | 11 | def load_json(path): 12 | d = {} 13 | with open(path, mode="r") as f: 14 | d = json.load(f) 15 | return d 16 | 17 | 18 | def IoUfrom2bboxes(boxA, boxB): 19 | # determine the (x, y)-coordinates of the intersection rectangle 20 | xA = max(boxA[0], boxB[0]) 21 | yA = max(boxA[1], boxB[1]) 22 | xB = min(boxA[2], boxB[2]) 23 | yB = min(boxA[3], boxB[3]) 24 | # compute the area of intersection rectangle 25 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 26 | # compute 
the area of both the prediction and ground-truth 27 | # rectangles 28 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 29 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 30 | # compute the intersection over union by taking the intersection 31 | # area and dividing it by the sum of prediction + ground-truth 32 | # areas - the interesection area 33 | iou = interArea / float(boxAArea + boxBArea - interArea) 34 | # return the intersection over union value 35 | return iou 36 | 37 | 38 | 39 | def crop_face(img,landmark=None,bbox=None,margin=False,crop_by_bbox=True,abs_coord=False,only_img=False,phase='train'): 40 | assert phase in ['train','val','test'] 41 | 42 | #crop face------------------------------------------ 43 | H,W=len(img),len(img[0]) 44 | 45 | assert landmark is not None or bbox is not None 46 | 47 | H,W=len(img),len(img[0]) 48 | 49 | if crop_by_bbox: 50 | x0,y0=bbox[0] 51 | x1,y1=bbox[1] 52 | w=x1-x0 53 | h=y1-y0 54 | w0_margin=w/4#0#np.random.rand()*(w/8) 55 | w1_margin=w/4 56 | h0_margin=h/4#0#np.random.rand()*(h/5) 57 | h1_margin=h/4 58 | else: 59 | x0,y0=landmark[:68,0].min(),landmark[:68,1].min() 60 | x1,y1=landmark[:68,0].max(),landmark[:68,1].max() 61 | w=x1-x0 62 | h=y1-y0 63 | w0_margin=w/8#0#np.random.rand()*(w/8) 64 | w1_margin=w/8 65 | h0_margin=h/2#0#np.random.rand()*(h/5) 66 | h1_margin=h/5 67 | 68 | 69 | 70 | if margin: 71 | w0_margin*=4 72 | w1_margin*=4 73 | h0_margin*=2 74 | h1_margin*=2 75 | elif phase=='train': 76 | w0_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 77 | w1_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 78 | h0_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 79 | h1_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 80 | else: 81 | w0_margin*=0.5 82 | w1_margin*=0.5 83 | h0_margin*=0.5 84 | h1_margin*=0.5 85 | 86 | y0_new=max(0,int(y0-h0_margin)) 87 | y1_new=min(H,int(y1+h1_margin)+1) 88 | x0_new=max(0,int(x0-w0_margin)) 89 | x1_new=min(W,int(x1+w1_margin)+1) 90 | 91 | img_cropped=img[y0_new:y1_new,x0_new:x1_new] 92 | if landmark is not None: 93 | landmark_cropped=np.zeros_like(landmark) 94 | for i,(p,q) in enumerate(landmark): 95 | landmark_cropped[i]=[p-x0_new,q-y0_new] 96 | else: 97 | landmark_cropped=None 98 | if bbox is not None: 99 | bbox_cropped=np.zeros_like(bbox) 100 | for i,(p,q) in enumerate(bbox): 101 | bbox_cropped[i]=[p-x0_new,q-y0_new] 102 | else: 103 | bbox_cropped=None 104 | 105 | if only_img: 106 | return img_cropped 107 | if abs_coord: 108 | return img_cropped,landmark_cropped,bbox_cropped,(y0-y0_new,x0-x0_new,y1_new-y1,x1_new-x1),y0_new,y1_new,x0_new,x1_new 109 | else: 110 | return img_cropped,landmark_cropped,bbox_cropped,(y0-y0_new,x0-x0_new,y1_new-y1,x1_new-x1) 111 | 112 | 113 | class RandomDownScale(alb.core.transforms_interface.ImageOnlyTransform): 114 | def apply(self,img,**params): 115 | return self.randomdownscale(img) 116 | 117 | def randomdownscale(self,img): 118 | keep_ratio=True 119 | keep_input_shape=True 120 | H,W,C=img.shape 121 | ratio_list=[2,4] 122 | r=ratio_list[np.random.randint(len(ratio_list))] 123 | img_ds=cv2.resize(img,(int(W/r),int(H/r)),interpolation=cv2.INTER_NEAREST) 124 | if keep_input_shape: 125 | img_ds=cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR) 126 | 127 | return img_ds -------------------------------------------------------------------------------- /datasets/utils/initialize.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | import sys 4 | import json 5 | import 
numpy as np 6 | from PIL import Image 7 | from glob import glob 8 | import os 9 | import pandas as pd 10 | 11 | 12 | def init_ff(phase,level='frame',n_frames=8): 13 | dataset_path='data/FaceForensics++/original_sequences/youtube/raw/frames/' 14 | 15 | 16 | image_list=[] 17 | label_list=[] 18 | 19 | 20 | 21 | folder_list = sorted(glob(dataset_path+'*')) 22 | filelist = [] 23 | list_dict = json.load(open(f'data/FaceForensics++/{phase}.json','r')) 24 | for i in list_dict: 25 | filelist+=i 26 | folder_list = [i for i in folder_list if os.path.basename(i)[:3] in filelist] 27 | 28 | if level =='video': 29 | label_list=[0]*len(folder_list) 30 | return folder_list,label_list 31 | for i in range(len(folder_list)): 32 | # images_temp=sorted([glob(folder_list[i]+'/*.png')[0]]) 33 | images_temp=sorted(glob(folder_list[i]+'/*.png')) 34 | if n_frames 1: 78 | labels,_ = torch.max(labels,1) 79 | y_outputs.extend(outputs) 80 | y_labels.extend(labels) 81 | y_idxes.extend(idxes) 82 | gather_y_outputs = gather_tensor(y_outputs, args.world_size, to_numpy=False) 83 | gather_y_labels = gather_tensor(y_labels, args.world_size, to_numpy=False) 84 | gather_y_idxes = gather_tensor(y_idxes, args.world_size, to_numpy=False) 85 | test_result_list = [] 86 | for i, idx in enumerate(gather_y_idxes): 87 | video_name = dataloader.dataset.all_list[idx][0][0] 88 | video_name_tmp = video_name.split("/") 89 | video_name = video_name[2:].replace('/'+video_name_tmp[-2]+'/'+video_name_tmp[-1], "").replace('/'+video_name_tmp[1]+"/", '') 90 | video_label = gather_y_labels[i].cpu().item() 91 | video_predict = gather_y_outputs[i].cpu().item() 92 | test_result_list.append([video_name, video_label, video_predict]) 93 | test_result_list = sorted(test_result_list, key=(lambda x:x[0])) 94 | result_dir = args.model.resume.replace('ckpt/','')[:-4] 95 | result_dir = os.path.join(result_dir, test_label) 96 | os.makedirs(result_dir, exist_ok=True) 97 | predict_file = result_dir+"/"+args.final_test.dataset.params.method+".csv" 98 | pd.DataFrame(test_result_list, columns=["video", "label", "predict"]).to_csv(predict_file, index=False) 99 | config_file = args.config 100 | config_file_name = os.path.basename(config_file) 101 | copyfile(config_file, os.path.join(result_dir, config_file_name)) 102 | auc, acc = video_evaluation.final_scores(result_file=predict_file) 103 | return auc, acc 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /face_process/custom_data_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_image_files(directory): 5 | image_files = [] 6 | 7 | for root, dirs, files in os.walk(directory): 8 | for file in files: 9 | if file.endswith(".jpg") or file.endswith(".png"): 10 | image_files.append(os.path.join(root, file)) 11 | 12 | return image_files 13 | 14 | 15 | def get_list(root=None): 16 | #Celeb-v 17 | root = '/HDD0/guozonghui/project/datasets/celeb-v/frames' 18 | celeb_v = get_image_files(root) 19 | print('celeb-v:', len(celeb_v)) 20 | #FFHQ 21 | root = '/SSD2/shiliang/work/data/FFHQ_256' 22 | ffhq = get_image_files(root) 23 | print('celeb-v:', len(ffhq)) 24 | 25 | if __name__=='__main__': 26 | get_list() -------------------------------------------------------------------------------- /face_process/extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | 
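# (Added note) extract_random_frames() below samples `num_frames` random frame indices with
# cv2.VideoCapture and writes each decoded frame to <output_dir>/<video_num>/<frame_num>.png;
# main() walks input_dir, applies it to every .mp4 it finds, and prints the filename of any
# video whose extraction raises an error.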
def extract_random_frames(video_num, video_path, output_dir, num_frames=1): 7 | cap = cv2.VideoCapture(video_path) 8 | 9 | # Get the total number of frames in the video 10 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 11 | 12 | # Generate five random frame numbers 13 | frame_nums = np.sort(np.random.randint(0, total_frames, num_frames)) 14 | 15 | for frame_num in frame_nums: 16 | # Set the position of the video file to the frame number we want to capture 17 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num) 18 | 19 | ret, frame = cap.read() 20 | 21 | # If the frame was read successfully, save it 22 | if ret: 23 | output_dir_ = os.path.join(output_dir, f'{video_num}') 24 | os.makedirs(output_dir_, exist_ok=True) 25 | output_path = os.path.join(output_dir_, f'{frame_num}.png') 26 | print(output_path) 27 | cv2.imwrite(output_path, frame) 28 | 29 | cap.release() 30 | 31 | def main(input_dir, output_dir): 32 | # Create the output directory if it doesn't already exist 33 | os.makedirs(output_dir, exist_ok=True) 34 | 35 | # Iterate over all files in the input directory 36 | print(len(os.listdir(input_dir))) 37 | video_num = 0 38 | for filename in tqdm(os.listdir(input_dir)): 39 | if filename.endswith(".mp4"): 40 | video_path = os.path.join(input_dir, filename) 41 | try: 42 | extract_random_frames(video_num, video_path, output_dir, num_frames=num_frames) 43 | video_num +=1 44 | except: 45 | print(filename) 46 | 47 | 48 | if __name__ == "__main__": 49 | num_frames = 8 50 | # methods = ['Celeb-DF', 'DFDC', 'DiffHead', 'FF', 'FFIW', 'hrfae', 'iplap', 'makeittalker', 'mobileswap', 'sadtalker', 'styleHEAT', 'VIPL', 'wav2lip'] 51 | methods = [ 'DFDC'] 52 | 53 | for method in methods: 54 | main(f'/SSD0/guozonghui/project/FFD/Real_video/Video/{method}', \ 55 | f'/SSD0/guozonghui/project/FFD/Real_video/Frame/{method}') -------------------------------------------------------------------------------- /face_process/lib/ct/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from .detector import RetinaFace 3 | from .utils import * 4 | 5 | 6 | def assert_bounded(val, low, up): 7 | return val >= low and val < up 8 | 9 | 10 | def check_valid(face, w, h): 11 | box = face[0] 12 | if box[0] > box[2]: 13 | return False 14 | if box[1] > box[3]: 15 | return False 16 | for idx, bound in zip([0, 1, 2, 3], [w, h, w, h]): 17 | if not assert_bounded(box[idx], 0, bound): 18 | return False 19 | pts = face[1] 20 | for p in pts: 21 | for idx, bound in zip([0, 1], [w, h]): 22 | if not assert_bounded(p[idx], 0, bound): 23 | return False 24 | return True 25 | 26 | 27 | def post_detect(detect_results, scale, w, h): 28 | new_results = [] 29 | for frame_faces in detect_results: 30 | new_frame_faces = [] 31 | for box, ldm, score in frame_faces: 32 | box = box * scale 33 | ldm = ldm * scale 34 | face = (box, ldm, score) 35 | if check_valid(face, w=w, h=h): 36 | new_frame_faces.append(face) 37 | new_results.append(new_frame_faces) 38 | return new_results 39 | 40 | 41 | class FaceDetector(RetinaFace): 42 | def scale_detect(self, images): 43 | max_res = 1920 44 | h, w = images[0].shape[:2] 45 | if max(h, w) > max_res: 46 | init_scale = max(h, w) / max_res 47 | else: 48 | init_scale = 1 49 | resize_scale = 2 * init_scale 50 | resize_w = int(w / resize_scale) 51 | resize_h = int(h / resize_scale) 52 | detect_input = [cv2.resize(frame, (resize_w, resize_h)) for frame in images] 53 | detect_results = post_detect( 54 | self.detect(detect_input), scale=resize_scale, 
w=w, h=h, 55 | ) 56 | return detect_results 57 | -------------------------------------------------------------------------------- /face_process/lib/ct/detection/detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .alignment import load_net, batch_detect 7 | 8 | 9 | def get_project_dir(): 10 | current_path = os.path.abspath(os.path.join(__file__, "../")) 11 | return current_path 12 | 13 | 14 | def relative(path): 15 | path = os.path.join(get_project_dir(), path) 16 | return os.path.abspath(path) 17 | 18 | 19 | class RetinaFace: 20 | def __init__( 21 | self, gpu_id=-1, model_path=None, network="mobilenet", 22 | ): 23 | self.gpu_id = gpu_id 24 | self.device = ( 25 | torch.device("cpu") if gpu_id == -1 else torch.device("cuda", gpu_id) 26 | ) 27 | self.model = load_net(model_path, self.device, network) 28 | 29 | def detect(self, images): 30 | if isinstance(images, np.ndarray): 31 | if len(images.shape) == 3: 32 | return batch_detect(self.model, [images], self.device)[0] 33 | elif len(images.shape) == 4: 34 | return batch_detect(self.model, images, self.device) 35 | elif isinstance(images, list): 36 | return batch_detect(self.model, np.array(images), self.device) 37 | elif isinstance(images, torch.Tensor): 38 | if len(images.shape) == 3: 39 | return batch_detect(self.model, images.unsqueeze(0), self.device)[0] 40 | elif len(images.shape) == 4: 41 | return batch_detect(self.model, images, self.device) 42 | else: 43 | raise NotImplementedError() 44 | 45 | def __call__(self, images): 46 | return self.detect(images) 47 | -------------------------------------------------------------------------------- /face_process/lib/ct/face_alignment/__init__.py: -------------------------------------------------------------------------------- 1 | from .predictor import LandmarkPredictor -------------------------------------------------------------------------------- /face_process/lib/ct/face_alignment/basenet.py: -------------------------------------------------------------------------------- 1 | # Backbone networks used for face landmark detection 2 | # Cunjian Chen (cunjian@msu.edu) 3 | 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False): 10 | super(ConvBlock, self).__init__() 11 | self.linear = linear 12 | if dw: 13 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False) 14 | else: 15 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False) 16 | self.bn = nn.BatchNorm2d(oup) 17 | if not linear: 18 | self.prelu = nn.PReLU(oup) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | x = self.bn(x) 23 | if self.linear: 24 | return x 25 | else: 26 | return self.prelu(x) 27 | 28 | 29 | # SE module 30 | # https://github.com/wujiyang/Face_Pytorch/blob/master/backbone/cbam.py 31 | class SEModule(nn.Module): 32 | """Squeeze and Excitation Module""" 33 | 34 | def __init__(self, channels, reduction): 35 | super(SEModule, self).__init__() 36 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 37 | self.fc1 = nn.Conv2d( 38 | channels, channels // reduction, kernel_size=1, padding=0, bias=False 39 | ) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.fc2 = nn.Conv2d( 42 | channels // reduction, channels, kernel_size=1, padding=0, bias=False 43 | ) 44 | self.sigmoid = nn.Sigmoid() 45 | 46 | def forward(self, x): 47 | input = x 48 | x = self.avg_pool(x) 49 | x = self.fc1(x) 50 | x = self.relu(x) 
51 | x = self.fc2(x) 52 | x = self.sigmoid(x) 53 | 54 | return input * x 55 | 56 | 57 | # USE global depthwise convolution layer. Compatible with MobileNetV2 (224×224), MobileNetV2_ExternalData (224×224) 58 | class MobileNet_GDConv(nn.Module): 59 | def __init__(self, num_classes): 60 | super(MobileNet_GDConv, self).__init__() 61 | self.pretrain_net = models.mobilenet_v2(pretrained=False) 62 | self.base_net = nn.Sequential(*list(self.pretrain_net.children())[:-1]) 63 | self.linear7 = ConvBlock(1280, 1280, (7, 7), 1, 0, dw=True, linear=True) 64 | self.linear1 = ConvBlock(1280, num_classes, 1, 1, 0, linear=True) 65 | 66 | def forward(self, x): 67 | x = self.base_net(x) 68 | x = self.linear7(x) 69 | x = self.linear1(x) 70 | x = x.view(x.size(0), -1) 71 | return x 72 | 73 | 74 | # USE global depthwise convolution layer. Compatible with MobileNetV2 (56×56) 75 | class MobileNet_GDConv_56(nn.Module): 76 | def __init__(self, num_classes): 77 | super(MobileNet_GDConv_56, self).__init__() 78 | self.pretrain_net = models.mobilenet_v2(pretrained=False) 79 | self.base_net = nn.Sequential(*list(self.pretrain_net.children())[:-1]) 80 | self.linear7 = ConvBlock(1280, 1280, (2, 2), 1, 0, dw=True, linear=True) 81 | self.linear1 = ConvBlock(1280, num_classes, 1, 1, 0, linear=True) 82 | 83 | def forward(self, x): 84 | x = self.base_net(x) 85 | x = self.linear7(x) 86 | x = self.linear1(x) 87 | x = x.view(x.size(0), -1) 88 | return x 89 | 90 | 91 | # MobileNetV2 with SE; Compatible with MobileNetV2_SE (224×224) and MobileNetV2_SE_RE (224×224) 92 | class MobileNet_GDConv_SE(nn.Module): 93 | def __init__(self, num_classes): 94 | super(MobileNet_GDConv_SE, self).__init__() 95 | self.pretrain_net = models.mobilenet_v2(pretrained=True) 96 | self.base_net = nn.Sequential(*list(self.pretrain_net.children())[:-1]) 97 | self.linear7 = ConvBlock(1280, 1280, (7, 7), 1, 0, dw=True, linear=True) 98 | self.linear1 = ConvBlock(1280, num_classes, 1, 1, 0, linear=True) 99 | self.attention = SEModule(1280, 8) 100 | 101 | def forward(self, x): 102 | x = self.base_net(x) 103 | x = self.attention(x) 104 | x = self.linear7(x) 105 | x = self.linear1(x) 106 | x = x.view(x.size(0), -1) 107 | return x 108 | -------------------------------------------------------------------------------- /face_process/lib/ct/face_alignment/predictor.py: -------------------------------------------------------------------------------- 1 | # Face alignment demo 2 | # Uses MTCNN as face detector 3 | # Cunjian Chen (ccunjian@gmail.com) 4 | import torch 5 | import cv2 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | from .basenet import MobileNet_GDConv 9 | 10 | 11 | def get_device(gpu_id): 12 | if gpu_id > -1: 13 | return torch.device(f"cuda:{str(gpu_id)}") 14 | else: 15 | return torch.device("cpu") 16 | 17 | 18 | def load_model(file): 19 | model = MobileNet_GDConv(136) 20 | if file is not None: 21 | model.load_state_dict(torch.load(file, map_location="cpu")) 22 | else: 23 | url = "https://github.com/yinglinzheng/face_weights/releases/download/v1/mobilenet_224_model_best_gdconv_external.pth" 24 | model.load_state_dict(torch.utils.model_zoo.load_url(url)) 25 | return model 26 | 27 | 28 | # landmark of (5L, 2L) from [0,1] to real range 29 | def reproject(bbox, landmark): 30 | landmark_ = landmark.clone() 31 | x1, y1, x2, y2 = bbox 32 | w = x2 - x1 33 | h = y2 - y1 34 | landmark_[:, 0] *= w 35 | landmark_[:, 0] += x1 36 | landmark_[:, 1] *= h 37 | landmark_[:, 1] += y1 38 | return landmark_ 39 | 40 | 41 | def prepare_feed(img, face): 42 | 
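    # prepare_feed crops a square region around the detected face (side = 1.2 x the smaller bbox side),
    # pads any out-of-frame area, resizes the crop to 224x224, normalizes it with ImageNet mean/std and
    # returns a dict with the CHW float tensor ("data") and the crop bbox ("bbox").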
height, width, _ = img.shape 43 | mean = np.asarray([0.485, 0.456, 0.406]) 44 | std = np.asarray([0.229, 0.224, 0.225]) 45 | out_size = 224 46 | x1, y1, x2, y2 = face[:4] 47 | 48 | w = x2 - x1 + 1 49 | h = y2 - y1 + 1 50 | size = int(min([w, h]) * 1.2) 51 | cx = x1 + w // 2 52 | cy = y1 + h // 2 53 | x1 = cx - size // 2 54 | x2 = x1 + size 55 | y1 = cy - size // 2 56 | y2 = y1 + size 57 | 58 | dx = max(0, -x1) 59 | dy = max(0, -y1) 60 | x1 = max(0, x1) 61 | y1 = max(0, y1) 62 | 63 | edx = max(0, x2 - width) 64 | edy = max(0, y2 - height) 65 | x2 = min(width, x2) 66 | y2 = min(height, y2) 67 | new_bbox = torch.Tensor([x1, y1, x2, y2]).int() 68 | x1, y1, x2, y2 = new_bbox 69 | cropped = img[y1:y2, x1:x2] 70 | if dx > 0 or dy > 0 or edx > 0 or edy > 0: 71 | cropped = cv2.copyMakeBorder( 72 | cropped, int(dy), int(edy), int(dx), int(edx), cv2.BORDER_CONSTANT, 0 73 | ) 74 | cropped_face = cv2.resize(cropped, (out_size, out_size)) 75 | 76 | if cropped_face.shape[0] <= 0 or cropped_face.shape[1] <= 0: 77 | return None 78 | test_face = cropped_face.copy() 79 | test_face = test_face / 255.0 80 | test_face = (test_face - mean) / std 81 | test_face = test_face.transpose((2, 0, 1)) 82 | data = torch.from_numpy(test_face).float() 83 | return dict(data=data, bbox=new_bbox) 84 | 85 | 86 | @torch.no_grad() 87 | def single_predict(model, feed, device): 88 | landmark = model(feed["data"].unsqueeze(0).to(device)).cpu() 89 | landmark = landmark.reshape(-1, 2) 90 | landmark = reproject(feed["bbox"], landmark) 91 | return landmark.numpy() 92 | 93 | 94 | @torch.no_grad() 95 | def batch_predict(model, feeds, device): 96 | if not isinstance(feeds, list): 97 | feeds = [feeds] 98 | # loader = DataLoader(FeedDataset(feeds), batch_size=50, shuffle=False) 99 | data = [] 100 | for feed in feeds: 101 | data.append(feed["data"].unsqueeze(0)) 102 | data = torch.cat(data, 0).to(device) 103 | results = [] 104 | 105 | landmarks = model(data).cpu() 106 | for landmark, feed in zip(landmarks, feeds): 107 | landmark = landmark.reshape(-1, 2) 108 | landmark = reproject(feed["bbox"], landmark) 109 | results.append(landmark.numpy()) 110 | return results 111 | 112 | 113 | @torch.no_grad() 114 | def batch_predict2(model, feeds, device, batch_size=None): 115 | if not isinstance(feeds, list): 116 | feeds = [feeds] 117 | if batch_size is None: 118 | batch_size = len(feeds) 119 | loader = DataLoader(feeds, batch_size=len(feeds), shuffle=False) 120 | results = [] 121 | for feed in loader: 122 | landmarks = model(feed["data"].to(device)).cpu() 123 | for landmark, bbox in zip(landmarks, feed["bbox"]): 124 | landmark = landmark.reshape(-1, 2) 125 | landmark = reproject(bbox, landmark) 126 | results.append(landmark.numpy()) 127 | return results 128 | 129 | 130 | class LandmarkPredictor: 131 | def __init__(self, gpu_id=0, file=None): 132 | self.device = get_device(gpu_id) 133 | self.model = load_model(file).to(self.device).eval() 134 | 135 | def __call__(self, feeds): 136 | results = batch_predict2(self.model, feeds, self.device) 137 | if not isinstance(feeds, list): 138 | results = results[0] 139 | return results 140 | 141 | @staticmethod 142 | def prepare_feed(img, face): 143 | return prepare_feed(img, face) 144 | -------------------------------------------------------------------------------- /face_process/lib/ct/face_alignment/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def drawLandmark_multiple(img, bbox, landmark): 5 | """ 6 | Input: 7 | - img: gray or RGB 8 | - 
bbox: type of BBox 9 | - landmark: reproject landmark of (5L, 2L) 10 | Output: 11 | - img marked with landmark and bbox 12 | """ 13 | x1, y1, x2, y2 = bbox 14 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) 15 | for x, y in landmark: 16 | cv2.circle(img, (int(x), int(y)), 2, (0, 255, 0), -1) 17 | return img 18 | -------------------------------------------------------------------------------- /face_process/lib/ct/operations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | from .tracking.sort import iou 7 | 8 | 9 | def face_iou(f1, f2): 10 | return iou(f1[0], f2[0]) 11 | 12 | # For FFIW an IoU threshold of 0.2 works well, but on DFDC 0.2 can fail to keep even the same face across frames; 0.5 works better 13 | def simple_tracking(batch_landmarks, index=0, thres=0.5): 14 | track = [] 15 | for i, faces in enumerate(batch_landmarks): 16 | if i == 0: 17 | if len(faces) <= index or faces[index][-1] < 0.8: 18 | return None 19 | if index != 0: 20 | for idx in range(index): 21 | if face_iou(faces[idx], faces[index]) > thres: 22 | return None 23 | track.append(faces[index]) 24 | else: 25 | last = track[i - 1] 26 | if len(faces) == 0: 27 | return None 28 | sorted_faces = sorted(faces, key=lambda x: face_iou(x, last), reverse=True) 29 | # print(face_iou(sorted_faces[0], last)) 30 | if face_iou(sorted_faces[0], last) < thres: 31 | return None 32 | track.append(sorted_faces[0]) 33 | return track 34 | 35 | 36 | def multiple_tracking(batch_landmarks): 37 | tracks = [] 38 | for i in range(len(batch_landmarks[0])): 39 | # print(i) 40 | track = simple_tracking(batch_landmarks, index=i) 41 | if track is None: 42 | continue 43 | tracks.append(track) 44 | return tracks 45 | 46 | def find_longest(detect_res): 47 | fc = len(detect_res) 48 | tuples = [] 49 | start = 0 50 | end = 0 51 | previous_count = -1 52 | all_tracks = [] 53 | # start is inclusive, end is exclusive 54 | while start < (fc - 1): 55 | for end in range(start + 2, fc + 1): 56 | tracks = multiple_tracking(detect_res[start:end]) 57 | if (len(tracks) != previous_count and previous_count != -1) or len( 58 | tracks 59 | ) == 0: 60 | break 61 | previous_count = len(tracks) 62 | if end - start > 2: 63 | if end != fc: 64 | un_reach_end = end - 1 65 | else: 66 | un_reach_end = end 67 | sub_tracks = multiple_tracking(detect_res[start:un_reach_end]) 68 | if end == fc and len(sub_tracks) == 0: 69 | un_reach_end = end - 1 70 | sub_tracks = multiple_tracking(detect_res[start:un_reach_end]) 71 | if len(sub_tracks) > 0: 72 | tpl = (start, un_reach_end) 73 | tuples.append(tpl) 74 | all_tracks.append(sub_tracks[0]) 75 | else: 76 | raise NotImplementedError 77 | previous_count = -1 78 | end = un_reach_end 79 | start = end 80 | return tuples, all_tracks -------------------------------------------------------------------------------- /face_process/lib/ct/tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/face_process/lib/ct/tracking/__init__.py -------------------------------------------------------------------------------- /face_process/lib/ct/tracking/tracker.py: -------------------------------------------------------------------------------- 1 | from .sort import Sort 2 | import numpy as np 3 | 4 | 5 | def get_detections(faces): 6 | detections = [] 7 | for face in faces: 8 | x1, y1, x2, y2 = face[0] 9 | detections.append((x1, y1, x2, y2, face[-1])) 10 | return np.array(detections) 11 | 12 |
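# Illustrative usage sketch: `detect_results` is assumed to be a list of per-frame face
# detections in the (box, landmarks, score) format returned by the detector above, e.g.
#   tracks = get_tracks(detect_results)
# get_tracks below keeps only identities that are detected in every frame.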
13 | def get_tracks(detect_results): 14 | tracks = {} 15 | mot_tracker = Sort() 16 | for faces in detect_results: 17 | detections = get_detections(faces) 18 | track_bbs_ids = mot_tracker.update(detections) 19 | for track in track_bbs_ids: # box each detected face individually 20 | id = int(track[-1]) 21 | box = track[:4] 22 | if id in tracks: 23 | tracks[id].append(box) 24 | else: 25 | tracks[id] = [box] 26 | 27 | return [track for id, track in tracks.items() if len(track) == len(detect_results)] 28 | -------------------------------------------------------------------------------- /face_process/lib/ct/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def write_img(file, img): 5 | cv2.imwrite(file, img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 6 | -------------------------------------------------------------------------------- /face_process/lib/dfdc_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from glob import glob 4 | from pathlib import Path 5 | 6 | 7 | def get_original_video_paths(root_dir, basename=False): 8 | originals = set() 9 | originals_v = set() 10 | for json_path in glob(os.path.join(root_dir, "*/metadata.json")): 11 | dir = Path(json_path).parent 12 | with open(json_path, "r") as f: 13 | metadata = json.load(f) 14 | for k, v in metadata.items(): 15 | original = v.get("original", None) 16 | if v["label"] == "REAL": 17 | original = k 18 | originals_v.add(original) 19 | originals.add(os.path.join(dir, original)) 20 | originals = list(originals) 21 | originals_v = list(originals_v) 22 | return originals_v if basename else originals 23 | 24 | 25 | def get_original_with_fakes(root_dir): 26 | pairs = [] 27 | for json_path in glob(os.path.join(root_dir, "*/metadata.json")): 28 | with open(json_path, "r") as f: 29 | metadata = json.load(f) 30 | for k, v in metadata.items(): 31 | original = v.get("original", None) 32 | if v["label"] == "FAKE": 33 | pairs.append((original[:-4], k[:-4] )) 34 | 35 | return pairs 36 | 37 | 38 | def get_originals_and_fakes(root_dir): 39 | originals = [] 40 | fakes = [] 41 | for json_path in glob(os.path.join(root_dir, "metadata.json")): 42 | with open(json_path, "r") as f: 43 | metadata = json.load(f) 44 | # i = 0 45 | for k, v in metadata.items(): 46 | if v["label"] == "FAKE": 47 | fakes.append([k[:-4], 1, v["split"], v["original"][:-4]]) 48 | else: 49 | originals.append([k[:-4], 0, v["split"], "none" ]) 50 | # i += 1 51 | # if i > 1: 52 | # break 53 | return originals, fakes 54 | -------------------------------------------------------------------------------- /face_process/lib/shape_predictor_81_face_landmarks.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/face_process/lib/shape_predictor_81_face_landmarks.dat -------------------------------------------------------------------------------- /face_process/lib/xray/faster_crop_align_xray.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from lib.xray.warp_for_xray import ( 4 | estimiate_batch_transform, 5 | transform_landmarks, 6 | std_points_256, 7 | ) 8 | import numpy as np 9 | 10 | 11 | class FasterCropAlignXRay: 12 | """ 13 | Align all frames to a common coordinate system and resize images to the standard size. 14 | """ 15 | 16 | def __init__(self, size=256): 17 | self.image_size = size 18 | self.std_points = std_points_256 * size / 256.0 19 | 20 |
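    # Illustrative usage sketch (variable names are assumptions): `frame_landmarks` holds one
    # (box, ldm5, ldm68, ori_box) tuple per frame, as produced by the face extraction scripts:
    #   aligner = FasterCropAlignXRay(size=256)
    #   ldm68, aligned_frames, ldm5 = aligner(frame_landmarks, images=frames)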
def __call__(self, landmarks, images=None, jitter=False): 21 | landmarks = [landmark[:4] for landmark in landmarks] 22 | ori_boxes = np.array([ori_box for _, _, _, ori_box in landmarks]) 23 | five_landmarks = np.array([ldm5 for _, ldm5, _, _ in landmarks]) 24 | landmarks68 = np.array([ldm68 for _, _, ldm68, _ in landmarks]) 25 | # assert landmarks68.min() > 0 26 | 27 | left_top = ori_boxes[:, :2].min(0) 28 | 29 | right_bottom = ori_boxes[:, 2:].max(0) 30 | 31 | size = right_bottom - left_top 32 | 33 | w, h = size 34 | 35 | diff = ori_boxes[:, :2] - left_top[None, ...] 36 | 37 | new_five_landmarks = five_landmarks + diff[:, None, :] 38 | new_landmarks68 = landmarks68 + diff[:, None, :] 39 | 40 | landmark_for_estimiate = new_five_landmarks.copy() 41 | if jitter: 42 | landmark_for_estimiate += np.random.uniform( 43 | -4, 4, landmark_for_estimiate.shape 44 | ) 45 | 46 | tfm, trans = estimiate_batch_transform( 47 | landmark_for_estimiate, tgt_pts=self.std_points 48 | ) 49 | 50 | transformed_landmarks68 = np.array( 51 | [transform_landmarks(ldm68, trans) for ldm68 in new_landmarks68] 52 | ) 53 | transformed_landmarks5 = np.array( 54 | [transform_landmarks(ldm68, trans) for ldm68 in new_five_landmarks] 55 | ) 56 | 57 | if images is not None: 58 | transformed_images = [ 59 | self.process_sinlge(tfm, image, d, h, w) 60 | for image, d in zip(images, diff) 61 | ] # assemble the per-frame arguments for the warp helper 62 | transformed_images = np.stack(transformed_images) 63 | return transformed_landmarks68, transformed_images,transformed_landmarks5 64 | else: 65 | return transformed_landmarks68 66 | 67 | def process_sinlge(self, tfm, image, d, h, w): 68 | assert isinstance(image, np.ndarray) 69 | new_image = np.zeros((h, w, 3), dtype=np.uint8) 70 | x, y = d 71 | ih, iw, _ = image.shape 72 | new_image[y : y + ih, x : x + iw] = image 73 | transformed_image = cv2.warpAffine( 74 | new_image, tfm, (self.image_size, self.image_size) 75 | ) 76 | return transformed_image 77 | -------------------------------------------------------------------------------- /face_process/real_face_process_Frame.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from face_process.lib.ct.detection import FaceDetector 3 | import cv2 4 | from face_process.lib.utils import flatten,partition 5 | from tqdm import tqdm 6 | import face_process.face_utils as face_utils 7 | import os 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 9 | detector = FaceDetector(0) 10 | 11 | def split_array(array, segment_size): 12 | segmented_array = [] 13 | for i in range(0, len(array), segment_size): 14 | segment = array[i:i+segment_size] 15 | segmented_array.append(segment) 16 | return segmented_array 17 | 18 | def process(frame): 19 | frame=cv2.imread(frame) 20 | frames = [] 21 | frames.append(frame) 22 | detect_res = flatten( 23 | [detector.detect(item) for item in partition(frames, 1)] 24 | ) 25 | detect_res = get_valid_faces(detect_res, thres=0.5) 26 | for faces, frame in zip(detect_res, frames): 27 | if len(faces) > 0: 28 | bbox, lm5, score = faces[0] 29 | frame, landmark, bbox=face_utils.crop_aligned(frame,lm5,landmarks_68=None,bboxes=bbox,aligned_image_size=224) 30 | bbox = np.array([[bbox[0],bbox[1]],[bbox[2],bbox[3]]]) 31 | frame_croped = crop_face_sbi(frame,bbox,margin=False,crop_by_bbox=True,abs_coord=True,phase='test') 32 | frame_croped = cv2.resize(frame_croped,(224,224),interpolation=cv2.INTER_LINEAR) 33 | return frame_croped 34 | 35 | def get_valid_faces(detect_results, max_count=10, thres=0.5, at_least=False): 36 |
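    # Keep at most `max_count` faces per frame and drop detections whose confidence is below
    # `thres`; when `at_least` is set, the top-scoring face is kept regardless of its score.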
new_results = [] 37 | for i, faces in enumerate(detect_results): 38 | # faces = sorted(faces, key=lambda x: bbox_range(x[0]), reverse=True) 39 | # print(len(faces)) 40 | # assert 0 41 | if len(faces) > max_count: 42 | faces = faces[:max_count] 43 | l = [] 44 | for j, face in enumerate(faces): 45 | if face[-1] < thres and not (j == 0 and at_least): 46 | continue 47 | box, lm, score = face 48 | box = box.astype(np.float64) 49 | lm = lm.astype(np.float64) 50 | l.append((box, lm, score)) 51 | new_results.append(l) 52 | return new_results 53 | 54 | 55 | def crop_face_sbi(img,bbox=None,margin=False,crop_by_bbox=True,abs_coord=False,only_img=False,phase='train'): 56 | assert phase in ['train','val','test'] 57 | 58 | #crop face------------------------------------------ 59 | H,W=len(img),len(img[0]) 60 | 61 | if crop_by_bbox: 62 | x0,y0=bbox[0] 63 | x1,y1=bbox[1] 64 | w=x1-x0 65 | h=y1-y0 66 | w0_margin=w/4#0#np.random.rand()*(w/8) 67 | w1_margin=w/4 68 | h0_margin=h/4#0#np.random.rand()*(h/5) 69 | h1_margin=h/4 70 | 71 | 72 | 73 | 74 | if margin: 75 | w0_margin*=4 76 | w1_margin*=4 77 | h0_margin*=2 78 | h1_margin*=2 79 | elif phase=='train': 80 | w0_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 81 | w1_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 82 | h0_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 83 | h1_margin*=(np.random.rand()*0.6+0.2)#np.random.rand() 84 | else: 85 | w0_margin*=0.5 86 | w1_margin*=0.5 87 | h0_margin*=0.5 88 | h1_margin*=0.5 89 | 90 | y0_new=max(0,int(y0-h0_margin)) 91 | y1_new=min(H,int(y1+h1_margin)+1) 92 | x0_new=max(0,int(x0-w0_margin)) 93 | x1_new=min(W,int(x1+w1_margin)+1) 94 | 95 | img_cropped=img[y0_new:y1_new,x0_new:x1_new] 96 | 97 | return img_cropped 98 | 99 | 100 | -------------------------------------------------------------------------------- /face_process/sample_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | 5 | def sample_images(source_path, target_path, num_samples): 6 | # Collect all image files under the source path 7 | image_files = [f for f in os.listdir(source_path) if os.path.isfile(os.path.join(source_path, f)) 8 | and f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.mp4'))] 9 | # image_files = get_DFDC(0) 10 | # print(image_files) 11 | # Randomly sample the specified number of image files 12 | sampled_images = random.sample(image_files, min(num_samples, len(image_files))) 13 | i = 0 14 | # Copy the sampled image files to the target path 15 | for image in sampled_images: 16 | source_image_path = os.path.join(source_path, image) 17 | # target_image_path = os.path.join(target_path, image) 18 | target_image_path = os.path.join(target_path, str(1000+i)+'.mp4') 19 | 20 | shutil.copyfile(source_image_path, target_image_path) 21 | print(f"Copying file: {image} to {target_path}") 22 | i +=1 23 | def find_mp4_files(root_dir='/SSD0/guozonghui/project/datasets/rawdata/FF++/manipulated_sequences/DeepFakeDetection/c23'): 24 | mp4_files = [] 25 | for root, dirs, files in os.walk(root_dir): 26 | for file in files: 27 | if file.endswith(".mp4"): 28 | mp4_files.append(os.path.join(root, file)) 29 | print(len(mp4_files)) 30 | return mp4_files 31 | 32 | 33 | def sample_images_v2(source_path, target_path, num_samples): 34 | image_files = find_mp4_files() 35 | # print(len(image_files)) 36 | # assert 0 37 | sampled_images = random.sample(image_files, min(num_samples, len(image_files))) 38 | for i, image in enumerate(sampled_images): 39 | target_image_path = os.path.join(target_path, str(1000+i)+'.mp4') 40 | print(f"Copying file: {image} to {target_path}") 41 |
shutil.copyfile(image, target_image_path) 42 | 43 | 44 | import pandas as pd 45 | def get_DFDC(choose_label=1): 46 | file_list = [] 47 | root = '/SSD0/guozonghui/project/datasets/rawdata/DFDC/test_videos/' 48 | label=pd.read_csv('/SSD0/guozonghui/project/datasets/rawdata/DFDC/labels.csv',delimiter=',') 49 | dataset_info = [(video_name[:-4], label) for video_name, label in zip(label['filename'].tolist(), label['label'].tolist())] 50 | filtered_video_names = [video_name for video_name, label in dataset_info if label == choose_label] 51 | for i in range(len(filtered_video_names)): 52 | file_list.append(filtered_video_names[i]+'.mp4') 53 | return file_list 54 | # dfdc_real = get_DFDC(0) 55 | # dfdc_fake = get_DFDC(1) 56 | 57 | import os 58 | 59 | 60 | 61 | 62 | 63 | 64 | # # # Specify the source path, target path, and number of images to sample 65 | source_path = "/SSD0/guozonghui/project/datasets/rawdata/FFIW10K-v1-release/source/train" 66 | target_path = "/SSD0/guozonghui/project/FFD/ffd_video_data/real/FFIW" 67 | num_samples = 2500 68 | 69 | # # Run image sampling and copying 70 | sample_images(source_path, target_path, num_samples) 71 | # sample_images_v2(source_path, target_path, num_samples) 72 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import models.custom_ssl as custom_ssl 2 | import torch 3 | from PIL import Image 4 | import albumentations as alb 5 | from albumentations.pytorch.transforms import ToTensorV2 6 | import numpy as np 7 | import csv 8 | import glob 9 | from tqdm import tqdm 10 | import os 11 | from face_process.real_face_process_Frame import process 12 | 13 | def get_model(): 14 | model = custom_ssl.BEiT_v2(pretrained=False) 15 | ckpt_load_path = 'checkpoints/Final_DDBF_BEiT_v2/ckpt/Final-mainbranch.tar' 16 | checkpoint = torch.load(ckpt_load_path, map_location='cpu') 17 | if 'state_dict' in checkpoint: 18 | sd = checkpoint['state_dict'] 19 | else: 20 | sd = checkpoint 21 | new_state_dict = {} 22 | for k, v in sd.items(): 23 | if k.startswith('module.'): 24 | k = k.replace('module.', '') 25 | new_state_dict[k] = v 26 | msg = model.load_state_dict(new_state_dict,strict=False) 27 | print('sdload', msg) 28 | 29 | return model 30 | 31 | def Inference_Img(model, path, transfrom): 32 | img = process(path) 33 | img = np.asarray(img) 34 | tmp_imgs = {"image": img} 35 | input_tensor = transfrom(**tmp_imgs) 36 | input_tensor = input_tensor['image'].cuda().unsqueeze(0) 37 | input_tensor = input_tensor.unsqueeze(1) 38 | output = model(input_tensor).squeeze(1) 39 | pred = torch.nn.functional.softmax(output, dim=1)[:,1] 40 | return pred 41 | 42 | 43 | 44 | if __name__ == "__main__": 45 | model = get_model() 46 | model = model.cuda() 47 | model.eval() 48 | additional_targets = {} 49 | base_transform = alb.Compose([ 50 | alb.Resize(224, 224), 51 | alb.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 52 | ToTensorV2(), 53 | ], additional_targets=additional_targets) 54 | 55 | csv_file = 'inference_results.csv' 56 | with open(csv_file, 'w', newline='') as file: 57 | writer = csv.writer(file) 58 | writer.writerow(["img_name", "y_pred"]) 59 | while True: 60 | path = input("Enter the path of the image (e.g., 'test.jpg') or 'exit' to stop: ") 61 | if path.lower() == 'exit': 62 | break 63 | elif os.path.splitext(path)[1].lower() =='.jpg' or os.path.splitext(path)[1].lower() =='.png': 64 | pred = Inference_Img(model, path, base_transform) 65 | writer.writerow([path, pred.cpu().detach().numpy()[0]]) 66 |
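                # `pred` is the softmax probability of the fake class (index 1)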
print(path, 'Fake Score:', pred.cpu().detach().numpy()[0]) 67 | else: 68 | print('Inference all images in:', path) 69 | paths = glob.glob(path + '/*.jpg', recursive=True)+glob.glob(path + '/*.png', recursive=True) 70 | for path in tqdm(paths): 71 | pred = Inference_Img(model, path, base_transform) 72 | writer.writerow([path, pred.cpu().detach().numpy()[0]]) 73 | print('result saved in :', csv_file) 74 | break 75 | -------------------------------------------------------------------------------- /inference_results.csv: -------------------------------------------------------------------------------- 1 | img_name,y_pred 2 | test_img/Fake_ori.png,0.99992335 3 | test_img/Fake2.png,0.9996152 4 | test_img/Fake1.png,0.9998584 5 | test_img/Fake3.png,0.9994948 6 | test_img/Real2.png,0.06445194 7 | test_img/Real1.png,0.006921941 8 | test_img/Fake4.png,0.9999298 9 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom import * 2 | from .custom_sl import * 3 | from .custom_ssl import * 4 | 5 | 6 | -------------------------------------------------------------------------------- /models/custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.lib.BEiT_v2 import modeling_finetune 5 | 6 | 7 | class DDBF_BEiT_v2(nn.Module): 8 | def __init__(self, num_class=2,**kwargs): 9 | super(DDBF_BEiT_v2,self).__init__() 10 | self.num_class = num_class 11 | feature_dim = 768 12 | self.backbone_model = modeling_finetune.beit_base_patch16_224(**BEiT_Config) 13 | self.backbone_model2 = modeling_finetune.beit_base_patch16_224(**BEiT_Config) 14 | self.backbone_model.head = nn.Identity() 15 | self.backbone_model2.head = nn.Identity() 16 | pretrain_path = kwargs['pretrained_path'] 17 | checkpoint = torch.load(pretrain_path, map_location='cpu') 18 | checkpoint_model = checkpoint['model'] # module 19 | msg1 = self.backbone_model.load_state_dict(checkpoint_model, strict=False) 20 | msg2 = self.backbone_model2.load_state_dict(checkpoint_model, strict=False) 21 | print('load info', msg1) 22 | print('load info', msg2) 23 | self.norm_for_cls = nn.LayerNorm(feature_dim*1) 24 | self.header = nn.Sequential(nn.Linear(feature_dim*1, num_class)) 25 | def forward(self, x): 26 | B, T, C, H, W = x.size() 27 | x = x.flatten(0,1) 28 | clstoken = self.backbone_model(x) 29 | clstoken2 = self.backbone_model2(x) 30 | pearson_loss = pearson_correlation_loss(clstoken.detach(), clstoken2) 31 | pred1 = self.header(self.norm_for_cls(clstoken)) 32 | pred2 = self.header(self.norm_for_cls(clstoken2)) 33 | pred1_un = uncertainty(pred1) 34 | pred2_un = uncertainty(pred2) 35 | weight = (-10)*torch.cat([pred1_un,pred2_un], dim=-1) 36 | w = torch.nn.functional.softmax(weight, dim=-1) 37 | weighted_output = torch.matmul(w.unsqueeze(1), torch.stack((clstoken, clstoken2), dim=1)).squeeze(1) 38 | pred = self.header(self.norm_for_cls(weighted_output)) 39 | output = pred.view(B, T, -1) 40 | output1 = pred1.view(B, T, -1) 41 | output2 = pred2.view(B, T, -1) 42 | return output, output1, output2, pearson_loss 43 | 44 | BEiT_Config = { 45 | 'drop_path_rate': 0.1, 46 | 'use_mean_pooling': False, 47 | 'init_values': 0.1, 48 | 'qkv_bias': True, 49 | 'use_abs_pos_emb': False, 50 | 'use_rel_pos_bias': True, 51 | 'use_shared_rel_pos_bias': False 52 | } 53 | 54 | 55 | def pearson_correlation_loss(x, y): 56 
| mean_x = torch.mean(x, dim=-1,keepdim=True) #[32,1] 57 | mean_y = torch.mean(y, dim=-1,keepdim=True) #[32,1] 58 | xm = x.sub(mean_x) 59 | ym = y.sub(mean_y) 60 | x_u = xm.unsqueeze(1) #[B,1,768] 61 | y_u = ym.unsqueeze(2) #[B,768,1] 62 | Conv_xy = torch.bmm(x_u, y_u).squeeze()/x.size(-1) #[B] 63 | std_x = torch.std(x, dim=-1) #[B] 64 | std_y = torch.std(y, dim=-1) #[B] 65 | pearson_cor = Conv_xy/(std_x*std_y) 66 | pearson_cor = pearson_cor.mean() 67 | r = torch.clamp(pearson_cor, -1.0, 1.0) 68 | return r 69 | 70 | 71 | def exp_evidence(y): 72 | return torch.exp(torch.clamp(y, -10, 10)) 73 | 74 | 75 | def uncertainty(output): 76 | evidence = exp_evidence(output) 77 | alpha = evidence + 1 78 | uncertainty = 2 / torch.sum(alpha, dim=1, keepdim=True) 79 | return uncertainty 80 | -------------------------------------------------------------------------------- /models/lib/BEiT_v3/modeling_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3 4 | # Copyright (c) 2023 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # --------------------------------------------------------' 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 12 | 13 | from torchscale.model.BEiT3 import BEiT3 14 | from torchscale.architecture.config import EncoderConfig 15 | 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | def _get_base_config( 22 | img_size=224, patch_size=16, drop_path_rate=0, 23 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 24 | ): 25 | return EncoderConfig( 26 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True, 27 | layernorm_embedding=False, normalize_output=True, no_output_layer=True, 28 | drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12, 29 | encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12, 30 | checkpoint_activations=checkpoint_activations, 31 | ) 32 | 33 | 34 | def _get_large_config( 35 | img_size=224, patch_size=16, drop_path_rate=0, 36 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 37 | ): 38 | return EncoderConfig( 39 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True, 40 | layernorm_embedding=False, normalize_output=True, no_output_layer=True, 41 | drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16, 42 | encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24, 43 | checkpoint_activations=checkpoint_activations, 44 | ) 45 | 46 | 47 | class BEiT3Wrapper(nn.Module): 48 | def __init__(self, args, **kwargs): 49 | super().__init__() 50 | self.args = args 51 | self.beit3 = BEiT3(args) 52 | self.apply(self._init_weights) 53 | 54 | def fix_init_weight(self): 55 | def rescale(param, layer_id): 56 | param.div_(math.sqrt(2.0 * layer_id)) 57 | 58 | for layer_id, layer in enumerate(self.blocks): 59 | rescale(layer.attn.proj.weight.data, layer_id + 1) 60 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 61 | 62 | def get_num_layers(self): 63 | return self.beit3.encoder.num_layers 64 | 65 | @torch.jit.ignore 66 | def 
no_weight_decay(self): 67 | return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'} 68 | 69 | def _init_weights(self, m): 70 | if isinstance(m, nn.Linear): 71 | trunc_normal_(m.weight, std=.02) 72 | if isinstance(m, nn.Linear) and m.bias is not None: 73 | nn.init.constant_(m.bias, 0) 74 | elif isinstance(m, nn.LayerNorm): 75 | nn.init.constant_(m.bias, 0) 76 | nn.init.constant_(m.weight, 1.0) 77 | -------------------------------------------------------------------------------- /models/lib/MAE/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /models/lib/MAE/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /models/lib/MAE/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /models/lib/MAE/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /models/lib/MAE/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /models/lib/MAE/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float64) 57 | omega /= embed_dim / 2. 58 | omega = 1.
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /models/lib/MoCoV3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/lib/MoCoV3/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class MoCo(nn.Module): 12 | """ 13 | Build a MoCo model with a base encoder, a momentum encoder, and two MLPs 14 | https://arxiv.org/abs/1911.05722 15 | """ 16 | def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0): 17 | """ 18 | dim: feature dimension (default: 256) 19 | mlp_dim: hidden dimension in MLPs (default: 4096) 20 | T: softmax temperature (default: 1.0) 21 | """ 22 | super(MoCo, self).__init__() 23 | 24 | self.T = T 25 | 26 | # build encoders 27 | # self.base_encoder = base_encoder(num_classes=mlp_dim) 28 | # self.momentum_encoder = base_encoder(num_classes=mlp_dim) 29 | 30 | self.base_encoder = base_encoder() 31 | self.momentum_encoder = base_encoder() 32 | 33 | self._build_projector_and_predictor_mlps(dim, mlp_dim) 34 | 35 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 36 | param_m.data.copy_(param_b.data) # initialize 37 | param_m.requires_grad = False # not update by gradient 38 | 39 | def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 40 | mlp = [] 41 | for l in range(num_layers): 42 | dim1 = input_dim if l == 0 else mlp_dim 43 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 44 | 45 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 46 | 47 | if l < num_layers - 1: 48 | mlp.append(nn.BatchNorm1d(dim2)) 49 | mlp.append(nn.ReLU(inplace=True)) 50 | elif last_bn: 51 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 52 | # for simplicity, we further removed gamma in BN 53 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 54 | 55 | return nn.Sequential(*mlp) 56 | 57 | def _build_projector_and_predictor_mlps(self, dim, mlp_dim): 58 | pass 59 | 60 | @torch.no_grad() 61 | def _update_momentum_encoder(self, m): 62 | """Momentum update of the momentum encoder""" 63 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 64 | param_m.data = param_m.data * m + param_b.data * (1. 
- m) 65 | 66 | def contrastive_loss(self, q, k): 67 | # normalize 68 | q = nn.functional.normalize(q, dim=1) 69 | k = nn.functional.normalize(k, dim=1) 70 | # gather all targets 71 | k = concat_all_gather(k) 72 | # Einstein sum is more intuitive 73 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T 74 | N = logits.shape[0] # batch size per GPU 75 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() 76 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) 77 | 78 | def forward(self, x1, x2, m): 79 | """ 80 | Input: 81 | x1: first views of images 82 | x2: second views of images 83 | m: moco momentum 84 | Output: 85 | loss 86 | """ 87 | 88 | # compute features 89 | q1 = self.predictor(self.base_encoder(x1)) 90 | q2 = self.predictor(self.base_encoder(x2)) 91 | 92 | with torch.no_grad(): # no gradient 93 | self._update_momentum_encoder(m) # update the momentum encoder 94 | 95 | # compute momentum features as targets 96 | k1 = self.momentum_encoder(x1) 97 | k2 = self.momentum_encoder(x2) 98 | 99 | return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1) 100 | 101 | 102 | class MoCo_ResNet(MoCo): 103 | def _build_projector_and_predictor_mlps(self, dim, mlp_dim): 104 | hidden_dim = self.base_encoder.fc.weight.shape[1] 105 | del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer 106 | 107 | # projectors 108 | self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) 109 | self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) 110 | 111 | # predictor 112 | self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False) 113 | 114 | 115 | class MoCo_ViT(MoCo): 116 | def _build_projector_and_predictor_mlps(self, dim, mlp_dim): 117 | hidden_dim = self.base_encoder.head.weight.shape[1] 118 | del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer 119 | 120 | # projectors 121 | self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) 122 | self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) 123 | 124 | # predictor 125 | self.predictor = self._build_mlp(2, dim, mlp_dim, dim) 126 | 127 | 128 | # utils 129 | @torch.no_grad() 130 | def concat_all_gather(tensor): 131 | """ 132 | Performs all_gather operation on the provided tensors. 133 | *** Warning ***: torch.distributed.all_gather has no gradient. 134 | """ 135 | tensors_gather = [torch.ones_like(tensor) 136 | for _ in range(torch.distributed.get_world_size())] 137 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 138 | 139 | output = torch.cat(tensors_gather, dim=0) 140 | return output 141 | -------------------------------------------------------------------------------- /models/lib/MoCoV3/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from PIL import Image, ImageFilter, ImageOps 8 | import math 9 | import random 10 | import torchvision.transforms.functional as tf 11 | 12 | 13 | class TwoCropsTransform: 14 | """Take two random crops of one image""" 15 | 16 | def __init__(self, base_transform1, base_transform2): 17 | self.base_transform1 = base_transform1 18 | self.base_transform2 = base_transform2 19 | 20 | def __call__(self, x): 21 | im1 = self.base_transform1(x) 22 | im2 = self.base_transform2(x) 23 | return [im1, im2] 24 | 25 | 26 | class GaussianBlur(object): 27 | """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" 28 | 29 | def __init__(self, sigma=[.1, 2.]): 30 | self.sigma = sigma 31 | 32 | def __call__(self, x): 33 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 34 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 35 | return x 36 | 37 | 38 | class Solarize(object): 39 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 40 | 41 | def __call__(self, x): 42 | return ImageOps.solarize(x) -------------------------------------------------------------------------------- /models/lib/MoCoV3/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | class LARS(torch.optim.Optimizer): 11 | """ 12 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 13 | """ 14 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 15 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 16 | super().__init__(params, defaults) 17 | 18 | @torch.no_grad() 19 | def step(self): 20 | for g in self.param_groups: 21 | for p in g['params']: 22 | dp = p.grad 23 | 24 | if dp is None: 25 | continue 26 | 27 | if p.ndim > 1: # if not normalization gamma/beta or bias 28 | dp = dp.add(p, alpha=g['weight_decay']) 29 | param_norm = torch.norm(p) 30 | update_norm = torch.norm(dp) 31 | one = torch.ones_like(param_norm) 32 | q = torch.where(param_norm > 0., 33 | torch.where(update_norm > 0, 34 | (g['trust_coefficient'] * param_norm / update_norm), one), 35 | one) 36 | dp = dp.mul(q) 37 | 38 | param_state = self.state[p] 39 | if 'mu' not in param_state: 40 | param_state['mu'] = torch.zeros_like(p) 41 | mu = param_state['mu'] 42 | mu.mul_(g['momentum']).add_(dp) 43 | p.add_(mu, alpha=-g['lr']) 44 | -------------------------------------------------------------------------------- /models/lib/SIMMIM/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /models/lib/SIMMIM/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | from .swin_transformer import build_swin 10 | from .vision_transformer import build_vit 11 | from .simmim import build_simmim 12 | 13 | 14 | def build_model(config, is_pretrain=True): 15 | if is_pretrain: 16 | 
model = build_simmim(config) 17 | else: 18 | model_type = config.MODEL.TYPE 19 | if model_type == 'swin': 20 | model = build_swin(config) 21 | elif model_type == 'vit': 22 | model = build_vit(config) 23 | else: 24 | raise NotImplementedError(f"Unknown fine-tune model: {model_type}") 25 | 26 | return model 27 | -------------------------------------------------------------------------------- /models/lib/SIMMIM/main_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 20 | from timm.utils import accuracy, AverageMeter 21 | 22 | from config import get_config 23 | from models import build_model 24 | 25 | 26 | 27 | def parse_option(): 28 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 29 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 30 | parser.add_argument( 31 | "--opts", 32 | help="Modify config options by adding 'KEY VALUE' pairs. ", 33 | default=None, 34 | nargs='+', 35 | ) 36 | 37 | # easy config modification 38 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 39 | parser.add_argument('--data-path', type=str, help='path to dataset') 40 | parser.add_argument('--pretrained', type=str, help='path to pre-trained model') 41 | parser.add_argument('--resume', help='resume from checkpoint') 42 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 43 | parser.add_argument('--use-checkpoint', action='store_true', 44 | help="whether to use gradient checkpointing to save memory") 45 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 46 | help='mixed precision opt level, if O0, no amp is used') 47 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 48 | help='root of output folder, the full path is // (default: output)') 49 | parser.add_argument('--tag', help='tag of experiment') 50 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 51 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 52 | 53 | # distributed training 54 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 55 | 56 | args = parser.parse_args() 57 | 58 | config = get_config(args) 59 | 60 | return args, config 61 | -------------------------------------------------------------------------------- /models/lib/SIMMIM/simmim_finetune__vit_base__img224__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | DATA: 14 | IMG_SIZE: 224 15 | TRAIN: 16 | EPOCHS: 100 17 | WARMUP_EPOCHS: 20 18 | BASE_LR: 1.25e-3 
19 | WARMUP_LR: 2.5e-7 20 | MIN_LR: 2.5e-7 21 | WEIGHT_DECAY: 0.05 22 | LAYER_DECAY: 0.65 23 | PRINT_FREQ: 100 24 | SAVE_FREQ: 5 25 | TAG: simmim_finetune__vit_base__img224__800ep 26 | -------------------------------------------------------------------------------- /models/lib/obow/ResNet50_OBoW_full.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters. 2 | model: 3 | alpha: 0.99 4 | alpha_cosine: True 5 | feature_extractor_arch: "resnet50" 6 | feature_extractor_opts: 7 | global_pooling: True 8 | # Use two feature levels for BoW: "block3" (aka conv4 of ResNet) and "block4" 9 | # (aka conv5 of ResNet). 10 | bow_levels: ["block3", "block4"] 11 | bow_extractor_opts: 12 | inv_delta: 15 13 | num_words: 8192 14 | bow_predictor_opts: 15 | kappa: 8 16 | # (Optional) on-line learning of a linear classifier on top of teacher 17 | # features for monitoring purposes. 18 | num_classes: 1000 19 | 20 | # Optimization parameters. 21 | optim: 22 | optim_type: "sgd" 23 | momentum: 0.9 24 | weight_decay: 0.0001 25 | nesterov: False 26 | num_epochs: 200 27 | lr: 0.03 28 | end_lr: 0.00003 29 | lr_schedule_type: "cos_warmup" 30 | warmup_epochs: 10 31 | permanent: 10 # save a permanent checkpoint every 10 epochs. 32 | 33 | # Data parameters: 34 | data: 35 | dataset_name: "ImageNet" 36 | batch_size: 256 37 | epoch_size: 38 | subset: 39 | cjitter: [0.4, 0.4, 0.4, 0.1] 40 | cjitter_p: 0.8 41 | gray_p: 0.2 42 | gaussian_blur: [0.1, 2.0] 43 | gaussian_blur_p: 0.5 44 | num_img_crops: 2 # 2 crops of size 160x160. 45 | image_crop_size: 160 46 | image_crop_range: [0.08, 0.6] 47 | num_img_patches: 5 # 5 patches of size 96x96. 48 | img_patch_preresize: 256 49 | img_patch_preresize_range: [0.6, 1.0] 50 | img_patch_size: 96 51 | img_patch_jitter: 24 52 | only_patches: False 53 | -------------------------------------------------------------------------------- /models/lib/obow/__init__.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | 4 | project_root = pathlib.Path(__file__).resolve().parents[1] 5 | -------------------------------------------------------------------------------- /models/lib/obow/fewshot.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import models.lib.obow.utils as utils 6 | 7 | 8 | def preprocess_5D_features(features, global_pooling): 9 | meta_batch_size, num_examples, channels, height, width = features.size() 10 | features = features.view( 11 | meta_batch_size * num_examples, channels, height, width) 12 | 13 | if global_pooling: 14 | features = utils.global_pooling(features, "avg") 15 | 16 | features = features.view(meta_batch_size, num_examples, -1) 17 | 18 | return features 19 | 20 | 21 | def average_train_features(features_train, labels_train): 22 | labels_train_transposed = labels_train.transpose(1,2) 23 | weight_novel = torch.bmm(labels_train_transposed, features_train) 24 | weight_novel = weight_novel.div( 25 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as( 26 | weight_novel)) 27 | 28 | return weight_novel 29 | 30 | 31 | def few_shot_classifier_with_prototypes( 32 | features_test, features_train, labels_train, 33 | scale_cls=10.0, global_pooling=True): 34 | 35 | #******* Generate classification weights for the novel categories ****** 36 | if features_train.dim() == 5: 37 | features_train = preprocess_5D_features(features_train, 
global_pooling) 38 | features_test = preprocess_5D_features(features_test, global_pooling) 39 | 40 | assert features_train.dim() == 3 41 | assert features_test.dim() == 3 42 | 43 | meta_batch_size = features_train.size(0) 44 | num_novel = labels_train.size(2) 45 | features_train = F.normalize(features_train, p=2, dim=2) 46 | prototypes = average_train_features(features_train, labels_train) 47 | prototypes = prototypes.view(meta_batch_size, num_novel, -1) 48 | #*********************************************************************** 49 | features_test = F.normalize(features_test, p=2, dim=2) 50 | prototypes = F.normalize(prototypes, p=2, dim=2) 51 | scores = scale_cls * torch.bmm(features_test, prototypes.transpose(1,2)) 52 | 53 | return scores 54 | 55 | 56 | def few_shot_feature_classification( 57 | classifier, features_test, features_train, labels_train_1hot, labels_test): 58 | 59 | scores = few_shot_classifier_with_prototypes( 60 | features_test=features_test, 61 | features_train=features_train, 62 | labels_train=labels_train_1hot) 63 | 64 | assert scores.dim() == 3 65 | 66 | scores = scores.view(scores.size(0) * scores.size(1), -1) 67 | labels_test = labels_test.view(-1) 68 | assert scores.size(0) == labels_test.size(0) 69 | 70 | loss = F.cross_entropy(scores, labels_test) 71 | 72 | with torch.no_grad(): 73 | accuracy = utils.accuracy(scores, labels_test, topk=(1,)) 74 | 75 | return scores, loss, accuracy 76 | 77 | 78 | @torch.no_grad() 79 | def fewshot_classification( 80 | feature_extractor, 81 | images_train, 82 | labels_train, 83 | labels_train_1hot, 84 | images_test, 85 | labels_test, 86 | feature_levels): 87 | assert images_train.dim() == 5 88 | assert images_test.dim() == 5 89 | assert images_train.size(0) == images_test.size(0) 90 | assert images_train.size(2) == images_test.size(2) 91 | assert images_train.size(3) == images_test.size(3) 92 | assert images_train.size(4) == images_test.size(4) 93 | assert labels_train.dim() == 2 94 | assert labels_test.dim() == 2 95 | assert labels_train.size(0) == labels_test.size(0) 96 | assert labels_train.size(0) == images_train.size(0) 97 | assert (feature_levels is None) or isinstance(feature_levels, (list, tuple)) 98 | meta_batch_size = images_train.size(0) 99 | 100 | images_train = utils.convert_from_5d_to_4d(images_train) 101 | images_test = utils.convert_from_5d_to_4d(images_test) 102 | labels_test = labels_test.view(-1) 103 | batch_size_train = images_train.size(0) 104 | images = torch.cat([images_train, images_test], dim=0) 105 | 106 | # Extract features from the train and test images. 
107 | features = feature_extractor(images, feature_levels) 108 | if isinstance(features, torch.Tensor): 109 | features = [features,] 110 | 111 | labels_test =labels_test.view(-1) 112 | 113 | loss, accuracy = [], [] 114 | for i, features_i in enumerate(features): 115 | features_train = features_i[:batch_size_train] 116 | features_test = features_i[batch_size_train:] 117 | features_train = utils.add_dimension(features_train, meta_batch_size) 118 | features_test = utils.add_dimension(features_test, meta_batch_size) 119 | 120 | scores = few_shot_classifier_with_prototypes( 121 | features_test, features_train, labels_train_1hot, 122 | scale_cls=10.0, global_pooling=True) 123 | 124 | scores = scores.view(scores.size(0) * scores.size(1), -1) 125 | assert scores.size(0) == labels_test.size(0) 126 | loss.append(F.cross_entropy(scores, labels_test)) 127 | with torch.no_grad(): 128 | accuracy.append(utils.accuracy(scores, labels_test, topk=(1,))[0]) 129 | 130 | loss = torch.stack(loss, dim=0) 131 | accuracy = torch.stack(accuracy, dim=0) 132 | 133 | return loss, accuracy 134 | -------------------------------------------------------------------------------- /pretrained_weight/pretrain_weight.txt: -------------------------------------------------------------------------------- 1 | Download weights (BEiT-1k-Face-55w.tar) from https://pan.baidu.com/s/1EOgJeE4Gb4TAaxvSkhK4lw (code:fr6r) and put it here. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | albumentations==1.3.0 3 | antlr4-python3-runtime==4.9.3 4 | cachetools==5.5.0 5 | certifi==2024.8.30 6 | charset-normalizer==3.4.0 7 | click==8.1.7 8 | contourpy==1.3.0 9 | cycler==0.12.1 10 | decord==0.6.0 11 | docker-pycreds==0.4.0 12 | dominate==2.9.1 13 | easydict==1.13 14 | einops==0.5.0 15 | filelock==3.16.1 16 | fonttools==4.54.1 17 | fsspec==2024.10.0 18 | ftfy==6.1.1 19 | gitdb==4.0.11 20 | GitPython==3.1.43 21 | google-auth==2.35.0 22 | google-auth-oauthlib==0.4.6 23 | grpcio==1.67.0 24 | huggingface-hub==0.26.0 25 | hydra-core==1.3.2 26 | idna==3.10 27 | imageio==2.36.0 28 | imgaug==0.4.0 29 | iopath==0.1.9 30 | joblib==1.4.2 31 | kiwisolver==1.4.7 32 | lazy_loader==0.4 33 | Markdown==3.7 34 | MarkupSafe==3.0.2 35 | matplotlib==3.9.2 36 | networkx==3.4.1 37 | ninja==1.11.1 38 | numpy==1.23.3 39 | oauthlib==3.2.2 40 | omegaconf==2.3.0 41 | opencv-python==4.10.0.84 42 | opencv-python-headless==4.10.0.84 43 | packaging==24.1 44 | pandas==1.5.1 45 | pathtools==0.1.2 46 | Pillow==9.2.0 47 | portalocker==2.10.1 48 | prefetch_generator==1.0.3 49 | promise==2.3 50 | protobuf==3.19.6 51 | psutil==6.1.0 52 | pyasn1==0.6.1 53 | pyasn1_modules==0.4.1 54 | pyparsing==3.2.0 55 | python-dateutil==2.9.0 56 | pytz==2024.2 57 | PyYAML==6.0 58 | qudida==0.0.4 59 | requests==2.32.3 60 | requests-oauthlib==2.0.0 61 | rsa==4.9 62 | safetensors==0.4.5 63 | scikit-image==0.24.0 64 | scikit-learn==1.5.2 65 | scipy==1.9.3 66 | sentry-sdk==2.17.0 67 | setproctitle==1.3.3 68 | shapely==2.0.6 69 | shortuuid==1.0.13 70 | six==1.16.0 71 | smmap==5.0.1 72 | tensorboard==2.10.1 73 | tensorboard-data-server==0.6.1 74 | tensorboard-plugin-wit==1.8.1 75 | tensorboardX==2.5.1 76 | threadpoolctl==3.5.0 77 | tifffile==2024.9.20 78 | timm==0.9.2 79 | tqdm==4.66.5 80 | typing_extensions==4.12.2 81 | urllib3==2.2.3 82 | wandb==0.13.5 83 | wcwidth==0.2.13 84 | Werkzeug==3.0.4 85 | 
-------------------------------------------------------------------------------- /test_img/Fake1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake1.png -------------------------------------------------------------------------------- /test_img/Fake1_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake1_aligned.png -------------------------------------------------------------------------------- /test_img/Fake2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake2.png -------------------------------------------------------------------------------- /test_img/Fake2_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake2_aligned.png -------------------------------------------------------------------------------- /test_img/Fake3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake3.png -------------------------------------------------------------------------------- /test_img/Fake3_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake3_aligned.png -------------------------------------------------------------------------------- /test_img/Fake4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake4.png -------------------------------------------------------------------------------- /test_img/Fake4_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake4_aligned.png -------------------------------------------------------------------------------- /test_img/Fake_ori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake_ori.png -------------------------------------------------------------------------------- /test_img/Fake_ori_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Fake_ori_aligned.png -------------------------------------------------------------------------------- /test_img/Real1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Real1.png -------------------------------------------------------------------------------- /test_img/Real1_aligned.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Real1_aligned.png -------------------------------------------------------------------------------- /test_img/Real2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Real2.png -------------------------------------------------------------------------------- /test_img/Real2_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenglab/FFDBackbone/0675f03d88a09b35210d797afb5e702318ad8424/test_img/Real2_aligned.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from utils import * 5 | import torch.nn as nn 6 | from common import losses 7 | from common.utils import * 8 | import torch.distributed as dist 9 | from models import * 10 | from datasets import * 11 | from shutil import copyfile 12 | from datasets.factory import get_dataloader 13 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..', '..')) 14 | from timm.optim import create_optimizer_v2, optimizer_kwargs 15 | from timm.scheduler import create_scheduler 16 | torch.autograd.set_detect_anomaly(True) 17 | from engine_finetune import train_one_epoch, test_one_epoch 18 | args = get_params() 19 | setup(args) 20 | init_exam_dir(args) 21 | ########################### 22 | # main logic for training # 23 | ########################### 24 | def main(): 25 | # use distributed training with nccl backend 26 | args.local_rank = int(os.environ.get('LOCAL_RANK', 0)) 27 | dist.init_process_group(backend='nccl', init_method="env://") 28 | torch.cuda.set_device(args.local_rank) 29 | args.world_size = dist.get_world_size() 30 | # set logger 31 | logger = get_logger(str(args.local_rank), console=args.local_rank==0, log_path=os.path.join(args.exam_dir, f'train_{args.local_rank}.log')) 32 | train_dataloader = get_dataloader(args, 'train') 33 | test_dataloader = get_dataloader(args, 'test') 34 | args.model.params.local_rank = args.local_rank 35 | model = eval(args.model.name)(**args.model.params) 36 | if args.local_rank == 0: 37 | file_name = os.path.join(args.exam_dir, 'Model.txt') 38 | with open(file_name, 'wt') as opt_file: 39 | opt_file.write(str(model)) 40 | opt_file.write('\n') 41 | config_file = args.config 42 | config_file_name = os.path.basename(config_file) 43 | copyfile(config_file, os.path.join(args.exam_dir, config_file_name)) 44 | model.cuda(args.local_rank) 45 | model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 46 | 47 | criterion = losses.__dict__[args.loss.name](**(args.loss.params if getattr(args.loss, "params", None) else {})).cuda(args.local_rank) 48 | # set optimizer 49 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args.optimizer.params)) 50 | print(optimizer) 51 | global_step = 1 52 | start_epoch = 1 53 | # resume model for a given checkpoint file 54 | if args.model.resume: 55 | logger.info(f'resume from {args.model.resume}') 56 | checkpoint = torch.load(args.model.resume, map_location='cpu') 57 | if 'state_dict' in checkpoint: 58 | sd = checkpoint['state_dict'] 59 | if (not getattr(args.model, 
'only_resume_model', False)): 60 | if 'optimizer' in checkpoint: 61 | optimizer.load_state_dict(checkpoint['optimizer']) 62 | if 'global_step' in checkpoint: 63 | global_step = checkpoint['global_step'] 64 | if 'epoch' in checkpoint: 65 | start_epoch = checkpoint['epoch'] + 1 66 | else: 67 | sd = checkpoint 68 | model.load_state_dict(sd, strict=True) 69 | 70 | lr_scheduler, num_epochs = create_scheduler(args.scheduler, optimizer) 71 | args.train.max_epoches = num_epochs 72 | if lr_scheduler is not None and start_epoch > 1: 73 | lr_scheduler.step(start_epoch-1) 74 | # Training loops 75 | for epoch in range(start_epoch, num_epochs+1): 76 | train_dataloader.sampler.set_epoch(epoch) 77 | train_one_epoch(train_dataloader, model, criterion, optimizer, epoch, global_step, args, logger,lr_scheduler) 78 | global_step += len(train_dataloader) 79 | test_one_epoch(test_dataloader, model, criterion, optimizer, epoch, global_step, args, logger,lr_scheduler) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /train_dualbranch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from utils import * 5 | import torch.nn as nn 6 | from common import losses 7 | from common.utils import * 8 | import torch.distributed as dist 9 | from models import * 10 | from datasets import * 11 | from shutil import copyfile 12 | from datasets.factory import get_dataloader 13 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..', '..')) 14 | from timm.optim import create_optimizer_v2, optimizer_kwargs 15 | from timm.scheduler import create_scheduler 16 | torch.autograd.set_detect_anomaly(True) 17 | from engine_finetune import train_dual_one_epoch, test_dual_one_epoch 18 | args = get_params() 19 | setup(args) 20 | init_exam_dir(args) 21 | ########################### 22 | # main logic for training # 23 | ########################### 24 | def main(): 25 | # use distributed training with nccl backend 26 | args.local_rank = int(os.environ.get('LOCAL_RANK', 0)) 27 | dist.init_process_group(backend='nccl', init_method="env://") 28 | torch.cuda.set_device(args.local_rank) 29 | args.world_size = dist.get_world_size() 30 | # set logger 31 | logger = get_logger(str(args.local_rank), console=args.local_rank==0, log_path=os.path.join(args.exam_dir, f'train_{args.local_rank}.log')) 32 | train_dataloader = get_dataloader(args, 'train') 33 | test_dataloader = get_dataloader(args, 'test') 34 | args.model.params.local_rank = args.local_rank 35 | model = eval(args.model.name)(**args.model.params) 36 | if args.local_rank == 0: 37 | file_name = os.path.join(args.exam_dir, 'Model.txt') 38 | with open(file_name, 'wt') as opt_file: 39 | opt_file.write(str(model)) 40 | opt_file.write('\n') 41 | config_file = args.config 42 | config_file_name = os.path.basename(config_file) 43 | copyfile(config_file, os.path.join(args.exam_dir, config_file_name)) 44 | model.cuda(args.local_rank) 45 | model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 46 | 47 | criterion_e = losses.__dict__[args.loss.name](**(args.loss.params if getattr(args.loss, "params", None) else {})).cuda(args.local_rank) 48 | criterion_b = losses.__dict__[args.loss2.name](**(args.loss2.params if getattr(args.loss2, "params", None) else {})).cuda(args.local_rank) 49 | 50 | # set optimizer 51 | optimizer = create_optimizer_v2(model, 
**optimizer_kwargs(cfg=args.optimizer.params)) 52 | print(optimizer) 53 | global_step = 1 54 | start_epoch = 1 55 | # resume model for a given checkpoint file 56 | if args.model.resume: 57 | logger.info(f'resume from {args.model.resume}') 58 | checkpoint = torch.load(args.model.resume, map_location='cpu') 59 | if 'state_dict' in checkpoint: 60 | sd = checkpoint['state_dict'] 61 | if (not getattr(args.model, 'only_resume_model', False)): 62 | if 'optimizer' in checkpoint: 63 | optimizer.load_state_dict(checkpoint['optimizer']) 64 | if 'global_step' in checkpoint: 65 | global_step = checkpoint['global_step'] 66 | if 'epoch' in checkpoint: 67 | start_epoch = checkpoint['epoch'] + 1 68 | else: 69 | sd = checkpoint 70 | model.load_state_dict(sd, strict=True) 71 | 72 | lr_scheduler, num_epochs = create_scheduler(args.scheduler, optimizer) 73 | args.train.max_epoches = num_epochs 74 | if lr_scheduler is not None and start_epoch > 1: 75 | lr_scheduler.step(start_epoch-1) 76 | # Training loops 77 | for epoch in range(start_epoch, num_epochs+1): 78 | train_dataloader.sampler.set_epoch(epoch) 79 | train_dual_one_epoch(train_dataloader, model, criterion_b, criterion_e, optimizer, epoch, global_step, args, logger, lr_scheduler, num_epochs) 80 | global_step += len(train_dataloader) 81 | test_dual_one_epoch(test_dataloader, model, criterion_b, criterion_e, optimizer, epoch, global_step, args, logger, lr_scheduler, num_epochs) 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils for STIL 2 | """ 3 | from .lr_utils import lr_tuner 4 | from .metrics import compute_metrics, compute_image_metrics 5 | from .ckpt_process import * 6 | -------------------------------------------------------------------------------- /utils/aucloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from abc import abstractmethod 4 | 5 | """ 6 | Implementation of 7 | "Zhiyong Yang, Qianqian Xu, Shilong Bao, Xiaochun Cao and Qingming Huang. 8 | Learning with Multiclass AUC: Theory and Algorithms. T-PAMI, 2021." 9 | """ 10 | 11 | class AUCLoss(nn.Module): 12 | 13 | ''' 14 | args: 15 | num_classes: number of classes (must be specified) 16 | 17 | gamma: safe margin in pairwise loss (default=1.0) 18 | 19 | transform: manner to compute the multi-class AUROC metric, either 'ovo' or 'ova' (default 'ovo', as in our paper) 20 | 21 | ''' 22 | def __init__(self, 23 | num_classes, 24 | gamma=1, 25 | transform='ovo', **kwargs): 26 | super(AUCLoss, self).__init__() 27 | 28 | if transform != 'ovo' and transform != 'ova': 29 | raise Exception("transform should be either ova or ovo") 30 | self.num_classes = num_classes 31 | self.gamma = gamma 32 | self.transform = transform 33 | 34 | if kwargs: 35 | self.__dict__.update(kwargs) 36 | 37 | def _check_input(self, pred, target): 38 | # if not (pred.max() <= 1 and pred.min() >= 0): 39 | # print(pred) 40 | assert pred.max() <= 1 and pred.min() >= 0 41 | assert target.min() >= 0 42 | assert pred.shape[0] == target.shape[0] 43 | 44 | def forward(self, pred, target, **kwargs): 45 | ''' 46 | args: 47 | pred: score of samples residing in [0,1]. 48 | For example, with respect to binary classification tasks, pred = torch.sigmoid(...) 49 | otherwise pred = torch.softmax(...) 50 | 51 | target: index of classes. In particular, w.r.t.
binary classification tasks, we regard y=1 as pos. instances. 52 | 53 | ''' 54 | self._check_input(pred, target) 55 | 56 | if self.num_classes == 2: 57 | Y = target.float() 58 | numPos = torch.sum(Y.eq(1)) 59 | numNeg = torch.sum(Y.eq(0)) 60 | Di = 1.0 / numPos / numNeg 61 | return self.calLossPerCLass(pred.squeeze(1), Y, Di, numPos) 62 | else: 63 | if self.transform == 'ovo': 64 | factor = self.num_classes * (self.num_classes - 1) 65 | else: 66 | factor = 1 67 | 68 | Y = torch.stack( 69 | [target.eq(i).float() for i in range(self.num_classes)], 70 | 1).squeeze() 71 | 72 | N = Y.sum(0) 73 | D = 1 / N[target.squeeze().long()] 74 | 75 | loss = torch.Tensor([0.]).cuda() 76 | if self.transform == 'ova': 77 | ones_vec = torch.ones_like(D).cuda() 78 | 79 | for i in range(self.num_classes): 80 | if self.transform == 'ovo': 81 | Di = D / N[i] 82 | else: 83 | fac = torch.tensor([1.0]).cuda() / (N[i] * (N.sum() - N[i])) 84 | Di = fac * ones_vec 85 | Yi, predi = Y[:, i], pred[:, i] 86 | loss += self.calLossPerCLass(predi, Yi, Di, N[i]) 87 | 88 | return loss / factor 89 | 90 | def calLossPerCLass(self, predi, Yi, Di, Ni): 91 | 92 | return self.calLossPerCLassNaive(predi, Yi, Di, Ni) 93 | 94 | @abstractmethod 95 | def calLossPerCLassNaive(self, predi, Yi, Di, Ni): 96 | pass 97 | 98 | 99 | class SquareAUCLoss(AUCLoss): 100 | def __init__(self, num_classes, gamma=1, transform='ovo', **kwargs): 101 | super(SquareAUCLoss, self).__init__(num_classes, gamma, transform) 102 | 103 | # self.num_classes = num_classes 104 | # self.gamma = gamma 105 | 106 | if kwargs is not None: 107 | self.__dict__.update(kwargs) 108 | 109 | def calLossPerCLassNaive(self, predi, Yi, Di, Ni): 110 | diff = predi - self.gamma * Yi 111 | nD = Di.mul(1 - Yi) 112 | fac = (self.num_classes - 113 | 1) if self.transform == 'ovo' else torch.tensor(1.0).cuda() 114 | S = Ni * nD + (fac * Yi / Ni) 115 | diff = diff.reshape((-1, )) 116 | S = S.reshape((-1, )) 117 | A = diff.mul(S).dot(diff) 118 | nD= nD.reshape((-1, )) 119 | Yi= Yi.reshape((-1, )) 120 | B = diff.dot(nD) * Yi.dot(diff) 121 | return 0.5 * A - B 122 | 123 | class HingeAUCLoss(AUCLoss): 124 | def __init__(self, num_classes, gamma=1, transform='ovo', **kwargs): 125 | super(HingeAUCLoss, self).__init__(num_classes, gamma, transform) 126 | 127 | if kwargs is not None: 128 | self.__dict__.update(kwargs) 129 | 130 | def calLossPerCLassNaive(self, predi, Yi, Di, Ni): 131 | fac = 1 if self.transform == 'ova' else (self.num_classes - 1) 132 | delta1 = (fac / Ni) * Yi * predi 133 | delta2 = Di * (1 - Yi) * predi 134 | return fac * self.gamma - delta1.sum() + delta2.sum() 135 | 136 | 137 | class ExpAUCLoss(AUCLoss): 138 | def __init__(self, num_classes, gamma=1, transform='ovo', **kwargs): 139 | super(ExpAUCLoss, self).__init__(num_classes,gamma, transform) 140 | 141 | if kwargs is not None: 142 | self.__dict__.update(kwargs) 143 | 144 | def calLossPerCLassNaive(self, predi, Yi, Di, Ni): 145 | C1 = Yi * torch.exp(-self.gamma * predi) 146 | C2 = (1 - Yi) * torch.exp(self.gamma * predi) 147 | C2 = Di * C2 148 | return C1.sum() * C2.sum() -------------------------------------------------------------------------------- /utils/ckpt_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | 5 | 6 | def convert_weight(pth_path, checkpoint_save_dir): 7 | checkpoint = torch.load(pth_path, map_location='cpu') 8 | sd = checkpoint['state_dict'] 9 | 10 | for k in list(sd.keys()): 11 | if 
k.startswith('module.backbone_model2'): 12 | del sd[k] 13 | checkpoint_m = OrderedDict() 14 | checkpoint_m['state_dict'] = sd 15 | torch.save(checkpoint_m, checkpoint_save_dir) -------------------------------------------------------------------------------- /utils/lr_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | import omegaconf 5 | 6 | 7 | def to_list(inp): 8 | if inp is None: 9 | return None 10 | if isinstance(inp, omegaconf.listconfig.ListConfig): 11 | inp = omegaconf.OmegaConf.to_container(inp, resolve=True) 12 | return inp 13 | 14 | 15 | def lr_tuner(lr_init, optimizer, epoch_size, tune_dict, global_step=1, use_warmup=False, warmup_epochs=1, lr_min=1e-6): 16 | """A simple learning rate tuning strategy. 17 | Using tune_dict to tune the learning rate. 18 | e.g.: 19 | tune_dict: 20 | key: strategy name, value: strategy params 21 | (1) piecewise: 22 | { 23 | "decay_steps": [100, 200], 24 | "decay_epochs": [1, 2], 25 | "decay_rates": [0.1, 0.2], 26 | } 27 | (2) exponential: 28 | { 29 | "decay_step": 100, 30 | "decay_epoch": 1, 31 | "decay_rate": 0.9, 32 | "staircase": True, 33 | } 34 | 35 | Args: 36 | lr_init (float): Initial learning rate. 37 | optimizer (torch.optimizer): Torch optimizer. 38 | epoch_size (int): How many steps in one epoch. 39 | tune_dict (dict): The dict specifying the learning rate tuning strategy. 40 | global_step (int, optional): Global step. Defaults to 1. 41 | use_warmup (bool, optional): Set to True to use the warmup strategy. Defaults to False. 42 | warmup_epochs (int, optional): How many epochs to apply warmup. Defaults to 1. 43 | lr_min (float, optional): Minimal learning rate. Defaults to 1e-6. 44 | 45 | Returns: 46 | float: The tuned learning rate.
47 | """ 48 | if use_warmup and global_step <= epoch_size * warmup_epochs: 49 | lr_start = 0 50 | if global_step == 1: 51 | print(">>> Using warmup strategy!") 52 | # lr = global_step / epoch_size * warmup_epochs * lr_init 53 | alpha = (lr_init - lr_start) / (epoch_size * warmup_epochs) 54 | lr = global_step * alpha + lr_start 55 | else: 56 | if use_warmup: 57 | new_step = global_step - epoch_size * warmup_epochs 58 | else: 59 | new_step = global_step 60 | 61 | tune_dict = dict(tune_dict) 62 | decay_strategy_name = tune_dict['name'] 63 | 64 | if decay_strategy_name == "piecewise": 65 | decay_steps = to_list(tune_dict.get("decay_steps")) 66 | decay_epochs = to_list(tune_dict.get("decay_epochs")) 67 | decay_rates = to_list(tune_dict.get("decay_rates")) 68 | 69 | if decay_steps and decay_epochs: 70 | raise ValueError( 71 | "decay_steps and decay_epochs in tune_dict of lr_tuner are both set, only one of them can be set" 72 | ) 73 | if decay_epochs: 74 | decay_steps = [epoch_size * x for x in decay_epochs] 75 | 76 | decay_cnt = np.sum(new_step > np.asarray(decay_steps)) 77 | if decay_cnt == 0: 78 | decay_mult = 1.0 79 | else: 80 | decay_mult = decay_rates[decay_cnt - 1] 81 | 82 | lr = max(lr_init * decay_mult, lr_min) 83 | 84 | elif decay_strategy_name == "exponential": 85 | decay_step = tune_dict.get("decay_step") 86 | decay_epoch = tune_dict.get("decay_epoch") 87 | decay_rate = tune_dict.get("decay_rate") 88 | staircase = tune_dict.get("staircase") 89 | 90 | if decay_step and decay_epoch: 91 | raise ValueError( 92 | "decay_step and decay_epoch in tune_dict of lr_tuner are both set, only one of them can be set" 93 | ) 94 | if decay_epoch: 95 | decay_step = epoch_size * decay_epoch 96 | 97 | decay_index = new_step // decay_step if staircase else new_step / decay_step 98 | lr = max(lr_init * math.pow(decay_rate, decay_index), lr_min) 99 | elif decay_strategy_name == "linear": 100 | decay_epoch = tune_dict.get("decay_epochs") 101 | epochs = tune_dict.get("epochs") 102 | if decay_epoch: 103 | decay_step = epoch_size * decay_epoch 104 | alpha = 1.0 - max(0, global_step - decay_step) / float(epoch_size*epochs - decay_step) 105 | lr = lr_init*alpha 106 | else: 107 | raise NotImplementedError("decay_strategy_name {} is not supported".format(decay_strategy_name)) 108 | 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] = lr 111 | 112 | return lr 113 | 114 | from torch.optim.lr_scheduler import _LRScheduler 115 | 116 | class LinearDecayLR(_LRScheduler): 117 | def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1): 118 | self.start_decay=start_decay 119 | self.n_epoch=n_epoch 120 | super(LinearDecayLR, self).__init__(optimizer, last_epoch) 121 | 122 | def get_lr(self): 123 | last_epoch = self.last_epoch 124 | n_epoch=self.n_epoch 125 | b_lr=self.base_lrs[0] 126 | start_decay=self.start_decay 127 | if last_epoch>start_decay: 128 | lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay) 129 | else: 130 | lr=b_lr 131 | return [lr] 132 | 133 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def compute_metrics(model_outputs, labels): 6 | """ 7 | Compute the accuracy metrics. 
8 | """ 9 | if len(model_outputs.shape) > 2: 10 | return compute_image_metrics(model_outputs, labels) 11 | real_probs = F.softmax(model_outputs, dim=1)[:, 0] 12 | bin_preds = (real_probs <= 0.5).int() 13 | bin_labels = (labels != 0).int() 14 | 15 | real_cnt = (bin_labels == 0).sum() 16 | fake_cnt = (bin_labels == 1).sum() 17 | 18 | acc = (bin_preds == bin_labels).float().mean() 19 | 20 | real_acc = (bin_preds == bin_labels)[torch.where(bin_labels == 0)].sum() / (real_cnt + 1e-12) 21 | fake_acc = (bin_preds == bin_labels)[torch.where(bin_labels == 1)].sum() / (fake_cnt + 1e-12) 22 | 23 | return acc.item(), real_acc.item(), fake_acc.item(), real_cnt.item(), fake_cnt.item() 24 | 25 | def compute_image_metrics(model_outputs, labels): 26 | """ 27 | Compute the accuracy metrics. 28 | """ 29 | real_probs = F.softmax(model_outputs, dim=2)[:,:, 0] 30 | real_probs = torch.mean(real_probs, dim=1) 31 | bin_preds = (real_probs <= 0.5).int() 32 | bin_labels = (labels != 0).int() 33 | 34 | real_cnt = (bin_labels == 0).sum() 35 | fake_cnt = (bin_labels == 1).sum() 36 | 37 | acc = (bin_preds == bin_labels).float().mean() 38 | 39 | real_acc = (bin_preds == bin_labels)[torch.where(bin_labels == 0)].sum() / (real_cnt + 1e-12) 40 | fake_acc = (bin_preds == bin_labels)[torch.where(bin_labels == 1)].sum() / (fake_cnt + 1e-12) 41 | 42 | return acc.item(), real_acc.item(), fake_acc.item(), real_cnt.item(), fake_cnt.item() 43 | 44 | -------------------------------------------------------------------------------- /utils/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) # np.float is deprecated (removed in NumPy 1.24); the builtin float is equivalent 57 | omega /= embed_dim / 2. 58 | omega = 1.
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | --------------------------------------------------------------------------------
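A brief usage sketch (not a file in this repository): the 2D sine-cosine position embedding in utils/pos_embed.py is generated once for a fixed patch grid and then copied into a vision transformer's positional-embedding table. The snippet assumes a ViT-Base-style setup (224x224 input, 16x16 patches, 768-dim embeddings) and that the repository root is on the Python path; the variable names are illustrative only.

from utils.pos_embed import get_2d_sincos_pos_embed

# A 224x224 input with 16x16 patches gives a 14x14 grid of patch tokens;
# cls_token=True prepends an all-zero row for the [CLS] token.
pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768) == (1 + 14 * 14, 768)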