├── README.md
├── config
│   ├── test.yaml
│   └── train.yaml
├── dataset
│   ├── __init__.py
│   ├── abstract_dataset.py
│   └── albu.py
├── figures
│   ├── archi.png
│   └── structure.png
├── install.sh
├── logger.py
├── model
│   ├── adapters
│   │   └── adapter.py
│   ├── attn.py
│   ├── clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── clip.py
│   │   ├── model.py
│   │   └── simple_tokenizer.py
│   ├── ds.py
│   └── layer.py
├── test.py
├── train.py
└── trainer
    ├── __init__.py
    ├── base_trainer.py
    ├── metrics
    │   ├── __init__.py
    │   ├── base_metrics_class.py
    │   ├── registry.py
    │   └── utils.py
    └── trainer.py

/README.md:
--------------------------------------------------------------------------------
1 | # Forensics Adapter: Adapting CLIP for Generalizable Face Forgery Detection (CVPR 2025)
2 | **👥 Authors: Xinjie Cui, [Yuezun Li](https://yuezunli.github.io/) (corresponding author), Ao Luo, Jiaran Zhou, Junyu Dong**
3 |
4 |
5 |
6 |
7 | ![Pipeline of the proposed Forensics Adapter.](https://github.com/OUC-VAS/ForensicsAdapter/blob/main/figures/archi.png)
8 |
9 |
10 | ---
11 |
12 | ## 📚 Resources
13 | | **Section** | **Content** |
14 | |--------------------|-----------------------------------------------------------------------------|
15 | | 📄 **Paper** | [arXiv Preprint](https://arxiv.org/abs/2411.19715) |
16 | | ⚖️ **Model Weights** | [Google Drive](https://drive.google.com/file/d/1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm/view?usp=sharing) \| [Baidu Netdisk](https://pan.baidu.com/s/10bEjEvhUlm4WVhDM1vWtig?pwd=9pbc) |
17 |
18 | ---
19 |
20 |
21 | ## 📊 Benchmark Comparison
22 | ## 🖼️ Frame-Level Comparison
23 | 🏆 **Champion Method Alert**: Our approach establishes a new state-of-the-art on all frame-level benchmarks!
24 |
25 |
26 | | Method | Venue | CDF-v1 | CDF-v2 | DFDC | DFDCP | DFD | Avg. 📈 |
27 | |----------------|------------|---------|---------|-------|--------|------|-------|
28 | | SPSL | CVPR'21 | 0.815 | 0.765 | 0.704 | 0.741 | 0.812 | 0.767 |
29 | | SRM | CVPR'21 | 0.793 | 0.755 | 0.700 | 0.741 | 0.812 | 0.760 |
30 | | RECCE | CVPR'22 | 0.768 | 0.732 | 0.713 | 0.734 | 0.812 | 0.752 |
31 | | SBI | CVPR'22 | - | 0.813 | - | 0.799 | 0.774 | - |
32 | | UCF | ICCV'23 | 0.779 | 0.753 | 0.719 | 0.759 | 0.807 | 0.763 |
33 | | ED | AAAI'24 | 0.818 | 0.864 | 0.721 | 0.851 | - | - |
34 | | LSDA | CVPR'24 | 0.867 | 0.830 | 0.736 | 0.815 | 0.880 | 0.826 |
35 | | CFM | TIFS'24 | - | 0.828 | - | 0.758 | 0.915 | - |
36 | | **Ours** | **CVPR'25** | 🥇 **0.914** | 🥇 **0.900** | 🥇 **0.843** | 🥇 **0.890** | 🥇 **0.933** | 🥇 **0.896** |
37 |
38 | ## 🎥 Video-Level Comparison
39 |
40 | 🏆 **Champion Method Alert**: Our approach achieves new SOTA performance across all video-level benchmarks!
41 | 42 | | Method | Venue | CDF-v2 | DFDC | DFDCP | 43 | |--------------------|------------|---------|-------|--------| 44 | | TALL | ICCV'23 | 0.908 | 0.768 | - | 45 | | AltFreezing | CVPR'23 | 0.895 | - | - | 46 | | SeeABLE | ICCV'23 | 0.873 | 0.759 | 0.863 | 47 | | IID | CVPR'23 | 0.838 | - | 0.812 | 48 | | TALL++ | IJCV'24 | 0.920 | 0.785 | - | 49 | | SAM | CVPR'24 | 0.890 | - | - | 50 | | SBI | CVPR'22 | 0.932 | 0.724 | 0.862 | 51 | | CADDM | CVPR'23 | 0.939 | 0.739 | - | 52 | | SFDG | CVPR'23 | 0.758 | 0.736 | - | 53 | | LAA-NET | CVPR'24 | 0.954 | - | 0.869 | 54 | | LSDA | CVPR'24 | 0.911 | 0.770 | - | 55 | | CFM | TIFS'24 | 0.897 | - | 0.802 | 56 | | **Ours** | **CVPR'25** | 🥇 **0.957** | 🥇 **0.872** | 🥇 **0.929** | 57 | 58 | --- 59 | 60 | 61 | ## 🚀 Start 62 | 63 | - [⏳ Environment Setup](#-environment-setup) 64 | - [📂 Dataset](#-dataset) 65 | - [🏋️ Training](#-training) 66 | - [🧪 Testing](#-testing) 67 | - [📝 Citation](#-citation) 68 | 69 | 70 | ## ⏳ Environment Setup 71 | Ensure your environment meets the following requirements: 72 | 73 | - 🐍 Python 3.9 74 | - 🔥 PyTorch 1.11 75 | - 🚀 CUDA 11.3 76 | 77 | Install dependencies: 78 | 79 | ```bash 80 | git clone https://github.com/OUC-VAS/ForensicsAdapter.git 81 | cd ForensicsAdapter 82 | conda create -n FA python=3.9 83 | conda activate FA 84 | sh install.sh 85 | ``` 86 | 87 | ## 📂 Dataset 88 | 89 | We use multiple datasets for training and evaluation: 90 | 91 | - FF++ 92 | - DFDC 93 | - DFDCP 94 | - DFD 95 | - CD1/CD2 96 | 97 | The dataset downloading and processing procedures can be referred to the implementation provided in [DeepfakeBench](https://github.com/SCLBD/DeepfakeBench) . 98 | 99 | 100 | ## 🏋️ Training 101 | Make sure to modify the relevant configurations in the train.yaml file before training. 102 | 103 | Start training with the following command: 104 | 105 | ```bash 106 | python train.py 107 | ``` 108 | 109 | ## 🧪 Testing 110 | Make sure to modify the relevant configurations in the test.yaml file before testing. 
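The entries that most often need adapting are the log/output paths, the dataset JSON folder, the evaluation datasets, and the device/batch settings. A minimal sanity-check sketch, assuming it is run from the repository root and that `config/test.yaml` keeps the keys shown in this repo:

```python
# Print the test.yaml entries that usually have to be adapted to the local setup.
import yaml

with open("config/test.yaml", "r") as f:
    cfg = yaml.safe_load(f)

for key in ["log_dir", "dataset_json_folder", "test_dataset", "device", "test_batchSize"]:
    print(f"{key}: {cfg.get(key)}")
```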
111 | 112 | To test the model, you can directly load our pre-trained weights and run a command like the following: 113 | 114 | ```bash 115 | python /data/cuixinjie/FA/test.py 116 | ``` 117 | ## 📝 Citation 118 | 119 | If our work is useful for your research, please cite it as follows: 120 | 121 | ```bibtex 122 | @InProceedings{Cui_2025_CVPR, 123 | author={Cui, Xinjie and Li, Yuezun and Luo, Ao and Zhou, Jiaran and Dong, Junyu}, 124 | title={Forensics Adapter: Adapting CLIP for Generalizable Face Forgery Detection}, 125 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 126 | year={2025} 127 | } 128 | -------------------------------------------------------------------------------- /config/test.yaml: -------------------------------------------------------------------------------- 1 | # log dir 2 | log_dir: /data/cuixinjie/FA/logs/test 3 | #log_dir: /data/cuixinjie/DsClip_V2/log 4 | lmdb: False 5 | mode: text 6 | dry_run: false 7 | model_name: 'ds' 8 | inter: 'none' 9 | task_target: " " 10 | save_avg: True 11 | #queue_size: 2048 12 | #random_k: 256 13 | loss_rate: "10 * loss1 + 200 * loss_mse + 20* loss_intra + 10 *loss_inter" 14 | vit_name: 'vit_tiny_patch16_224' 15 | train_set: 'ori' 16 | num_quires: 128 17 | fusion_map: {0: 0, 1: 1, 2: 8, 3: 15} 18 | clip_model_name: "ViT-L/14" 19 | #clip_model_name: "ViT-B/16" 20 | 21 | #clip_model_name: "data/cuixinjie/ViT-L-14.pt" 22 | 23 | device: 'cuda:0' 24 | mlp_dim: 256 25 | mlp_out_dim: 128 26 | head_num: 16 # for clip_model_name: "ViT-L/14" 27 | # dataset 28 | all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] 29 | train_dataset: [FaceForensics++] 30 | #test_dataset: [FaceForensics++] 31 | test_dataset: [Celeb-DF-v1 ,Celeb-DF-v2, DFDCP, DFDC, DeepFakeDetection] 32 | 33 | #test_dataset: [FaceForensics++] 34 | 35 | 36 | #test_dataset: [FF-F2F] 37 | 38 | #dataset_json_folder: '/media/ouc/新加卷/DS_CLIP_V2/dataset/dataset_json' 39 | dataset_json_folder: '/data/cuixinjie/dataset/dataset_json' 40 | 41 | compression: c23 # compression-level for videos 42 | train_batchSize: 20 # training batch size 43 | test_batchSize: 64 # test batch size 44 | workers: 8 # number of data loading workers 45 | frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing 46 | resolution: 256 # resolution of output image to network 47 | with_mask: true # whether to include mask information in the input 48 | with_xray: true 49 | with_patch_labels: true 50 | with_landmark: false # whether to include facial landmark information in the input 51 | # label settings 52 | label_dict: 53 | # DFD 54 | DFD_fake: 1 55 | DFD_real: 0 56 | # FF++ + FaceShifter(FF-real+FF-FH) 57 | FF-SH: 1 58 | FF-F2F: 1 59 | FF-DF: 1 60 | FF-FS: 1 61 | FF-NT: 1 62 | FF-FH: 1 63 | FF-real: 0 64 | # CelebDF 65 | CelebDFv1_real: 0 66 | CelebDFv1_fake: 1 67 | CelebDFv2_real: 0 68 | CelebDFv2_fake: 1 69 | # DFDCP 70 | DFDCP_Real: 0 71 | DFDCP_FakeA: 1 72 | DFDCP_FakeB: 1 73 | # DFDC 74 | DFDC_Fake: 1 75 | DFDC_Real: 0 76 | # DeeperForensics-1.0 77 | DF_fake: 1 78 | DF_real: 0 79 | # UADFV 80 | UADFV_Fake: 1 81 | UADFV_Real: 0 82 | 83 | 84 | 85 | # data augmentation 86 | use_data_augmentation: false # Add this flag to enable/disable data augmentation 87 | data_aug: 88 | flip_prob: 0.5 89 | rotate_prob: 0.5 90 | rotate_limit: [-10, 10] 91 | blur_prob: 0.5 92 | blur_limit: [3, 7] 93 | brightness_prob: 0.5 94 | 
brightness_limit: [-0.1, 0.1] 95 | contrast_limit: [-0.1, 0.1] 96 | quality_lower: 40 97 | quality_upper: 100 98 | 99 | # mean and std for normalization 100 | mean: [0.48145466, 0.4578275, 0.40821073] 101 | std: [0.26862954, 0.26130258, 0.27577711] 102 | 103 | # optimizer config 104 | optimizer: 105 | # choose between 'adam' and 'sgd' 106 | type: adam 107 | adam: 108 | lr: 0.0002 # learning rate 109 | beta1: 0.9 # beta1 for Adam optimizer 110 | beta2: 0.999 # beta2 for Adam optimizer 111 | eps: 0.00000001 # epsilon for Adam optimizer 112 | weight_decay: 0.0005 # weight decay for regularization 113 | amsgrad: false 114 | sgd: 115 | lr: 0.0002 # learning rate 116 | momentum: 0.9 # momentum for SGD optimizer 117 | weight_decay: 0.0005 # weight decay for regularization 118 | 119 | # training config 120 | lr_scheduler: null # learning rate scheduler 121 | nEpochs: 60 # number of epochs to train for 122 | start_epoch: 0 # manual epoch number (useful for restarts) 123 | save_epoch: 2 # interval epochs for saving models 124 | rec_iter: 100 # interval iterations for recording 125 | #logdir: /media/ouc/新加卷/DS_CLIP_V2/log # folder to output images and logs 126 | logdir: /data/cuixinjie/FA/log/Ablation # folder to output images and logs 127 | 128 | 129 | manualSeed: 1020 # manual seed for random number generation 130 | save_ckpt: true # whether to save checkpoiccnt 131 | save_feat: false 132 | 133 | # metric 134 | metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) 135 | 136 | # cuda 137 | ngpu: 1 # number of GPUs to use 138 | cuda: true # whether to use CUDA acceleration 139 | cudnn: true # whether to use CuDNN for convolution operations 140 | -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | # log dir 2 | log_dir: /data/cuixinjie/FA/logs/new 3 | #log_dir: /data/cuixinjie/DsClip_V2/log 4 | lmdb: False 5 | mode: train 6 | dry_run: false 7 | model_name: 'ds' 8 | inter: 'none' 9 | task_target: " " 10 | save_avg: True 11 | #queue_size: 2048 12 | #random_k: 256 13 | loss_rate: "10 * loss1 + 200 * loss_mse + 20* loss_intra + 10 *loss_inter" 14 | vit_name: 'vit_tiny_patch16_224' 15 | train_set: 'ori' 16 | num_quires: 128 17 | fusion_map: {1: 1, 2: 8, 3: 15} 18 | clip_model_name: "ViT-L/14" 19 | #clip_model_name: "ViT-B/16" 20 | 21 | #clip_model_name: "data/cuixinjie/ViT-L-14.pt" 22 | 23 | device: 'cuda:0' 24 | mlp_dim: 256 25 | mlp_out_dim: 128 26 | head_num: 16 # for clip_model_name: "ViT-L/14" 27 | # dataset 28 | all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] 29 | train_dataset: [FaceForensics++] 30 | #test_dataset: [FaceForensics++] 31 | test_dataset: [FaceForensics++] 32 | 33 | #test_dataset: [FaceForensics++] 34 | 35 | 36 | #test_dataset: [FF-F2F] 37 | 38 | #dataset_json_folder: '/media/ouc/新加卷/DS_CLIP_V2/dataset/dataset_json' 39 | dataset_json_folder: '/data/cuixinjie/dataset/dataset_json' 40 | 41 | compression: c23 # compression-level for videos 42 | train_batchSize: 20 # training batch size 43 | test_batchSize: 64 # test batch size 44 | workers: 8 # number of data loading workers 45 | frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing 46 | resolution: 256 # resolution of output image to network 47 | with_mask: true # whether to include mask information in the input 48 | with_xray: true 49 | 
with_patch_labels: true 50 | with_landmark: false # whether to include facial landmark information in the input 51 | # label settings 52 | label_dict: 53 | # DFD 54 | DFD_fake: 1 55 | DFD_real: 0 56 | # FF++ + FaceShifter(FF-real+FF-FH) 57 | FF-SH: 1 58 | FF-F2F: 1 59 | FF-DF: 1 60 | FF-FS: 1 61 | FF-NT: 1 62 | FF-FH: 1 63 | FF-real: 0 64 | # CelebDF 65 | CelebDFv1_real: 0 66 | CelebDFv1_fake: 1 67 | CelebDFv2_real: 0 68 | CelebDFv2_fake: 1 69 | # DFDCP 70 | DFDCP_Real: 0 71 | DFDCP_FakeA: 1 72 | DFDCP_FakeB: 1 73 | # DFDC 74 | DFDC_Fake: 1 75 | DFDC_Real: 0 76 | # DeeperForensics-1.0 77 | DF_fake: 1 78 | DF_real: 0 79 | # UADFV 80 | UADFV_Fake: 1 81 | UADFV_Real: 0 82 | 83 | 84 | 85 | # data augmentation 86 | use_data_augmentation: false # Add this flag to enable/disable data augmentation 87 | data_aug: 88 | flip_prob: 0.5 89 | rotate_prob: 0.5 90 | rotate_limit: [-10, 10] 91 | blur_prob: 0.5 92 | blur_limit: [3, 7] 93 | brightness_prob: 0.5 94 | brightness_limit: [-0.1, 0.1] 95 | contrast_limit: [-0.1, 0.1] 96 | quality_lower: 40 97 | quality_upper: 100 98 | 99 | # mean and std for normalization 100 | mean: [0.48145466, 0.4578275, 0.40821073] 101 | std: [0.26862954, 0.26130258, 0.27577711] 102 | 103 | # optimizer config 104 | optimizer: 105 | # choose between 'adam' and 'sgd' 106 | type: adam 107 | adam: 108 | lr: 0.0002 # learning rate 109 | beta1: 0.9 # beta1 for Adam optimizer 110 | beta2: 0.999 # beta2 for Adam optimizer 111 | eps: 0.00000001 # epsilon for Adam optimizer 112 | weight_decay: 0.0005 # weight decay for regularization 113 | amsgrad: false 114 | sgd: 115 | lr: 0.0002 # learning rate 116 | momentum: 0.9 # momentum for SGD optimizer 117 | weight_decay: 0.0005 # weight decay for regularization 118 | 119 | # training config 120 | lr_scheduler: null # learning rate scheduler 121 | nEpochs: 60 # number of epochs to train for 122 | start_epoch: 0 # manual epoch number (useful for restarts) 123 | save_epoch: 2 # interval epochs for saving models 124 | rec_iter: 100 # interval iterations for recording 125 | #logdir: /media/ouc/新加卷/DS_CLIP_V2/log # folder to output images and logs 126 | 127 | manualSeed: 1020 # manual seed for random number generation 128 | save_ckpt: true # whether to save checkpoiccnt 129 | save_feat: false 130 | 131 | # metric 132 | metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) 133 | 134 | # cuda 135 | ngpu: 1 # number of GPUs to use 136 | cuda: true # whether to use CUDA acceleration 137 | cudnn: true # whether to use CuDNN for convolution operations 138 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .albu import IsotropicResize -------------------------------------------------------------------------------- /dataset/abstract_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import yaml 7 | import json 8 | import numpy as np 9 | from copy import deepcopy 10 | import cv2 11 | import random 12 | from PIL import Image 13 | import torch 14 | from torch.utils import data 15 | from torchvision import transforms as T 16 | import albumentations as A 17 | from dataset.albu import IsotropicResize 18 | 19 | 20 | def get_boundary(mask): 21 | if len(mask.shape) == 3: 22 | mask = mask[:, :, 0] 23 | mask = cv2.GaussianBlur(mask, (3, 3), 0) 24 | boundary = mask / 255. 25 | boundary = 4 * boundary * (1. 
- boundary) 26 | return boundary 27 | 28 | 29 | def split_images_by_patch(mask_image, patch_size, mode='normal', need_boundary=False): 30 | if mode == 'resize': 31 | mask_image = cv2.resize(mask_image, (224, 224)) 32 | 33 | _, b_image = cv2.threshold(mask_image, 40, 255, cv2.THRESH_BINARY) 34 | height, width = b_image.shape 35 | 36 | num_patches_h = height // patch_size 37 | num_patches_w = width // patch_size 38 | labels = [] 39 | if_boundaries = [] 40 | 41 | for i in range(num_patches_h): 42 | for j in range(num_patches_w): 43 | start_y = i * patch_size 44 | start_x = j * patch_size 45 | patch = b_image[start_y:start_y + patch_size, start_x:start_x + patch_size] 46 | white_pixels = cv2.findNonZero(patch).shape[0] if cv2.findNonZero( 47 | patch) is not None else 0 48 | black_pixels = cv2.findNonZero(cv2.bitwise_not(patch)).shape[0] if cv2.findNonZero( 49 | cv2.bitwise_not(patch)) is not None else 0 50 | total_pixels = black_pixels + white_pixels 51 | label = 1 if (white_pixels / total_pixels > 0.1 ) else 0 # 0 real 1 fake 52 | 53 | labels.append(label) 54 | if need_boundary: 55 | if_boundary = 1 if (white_pixels / total_pixels > 0.1 and white_pixels != total_pixels) else 0 # 0 real 1 fake 56 | if_boundaries.append(if_boundary) 57 | if need_boundary: 58 | return labels, if_boundaries 59 | else: 60 | return labels 61 | 62 | 63 | class DeepfakeAbstractBaseDataset(data.Dataset): 64 | """ 65 | Abstract base class for all deepfake datasets. 66 | """ 67 | 68 | def __init__(self, config=None, mode='train'): 69 | """Initializes the dataset object. 70 | 71 | Args: 72 | config (dict): A dictionary containing configuration parameters. 73 | mode (str): A string indicating the mode (train or test). 74 | 75 | Raises: 76 | NotImplementedError: If mode is not train or test. 77 | """ 78 | 79 | # Set the configuration and mode 80 | self.config = config 81 | self.mode = mode 82 | self.compression = config['compression'] 83 | self.frame_num = config['frame_num'][mode] 84 | 85 | # Dataset dictionary 86 | self.image_list = [] 87 | self.label_list = [] 88 | 89 | # Set the dataset dictionary based on the mode 90 | if mode == 'train': 91 | dataset_list = config['train_dataset'] 92 | # Training data should be collected together for training 93 | image_list, label_list = [], [] 94 | for one_data in dataset_list: 95 | tmp_image, tmp_label = self.collect_img_and_label_for_one_dataset(one_data) 96 | image_list.extend(tmp_image) 97 | label_list.extend(tmp_label) 98 | elif mode == 'test': 99 | one_data = config['test_dataset'] 100 | # Test dataset should be evaluated separately. So collect only one dataset each time 101 | image_list, label_list = self.collect_img_and_label_for_one_dataset(one_data) 102 | else: 103 | raise NotImplementedError('Only train and test modes are supported.') 104 | 105 | assert len(image_list) != 0 and len(label_list) != 0, f"Collect nothing for {mode} mode!" 
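        # Keep the collected frame paths and labels; __getitem__ reads them back through self.data_dict below.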
106 | self.image_list, self.label_list = image_list, label_list 107 | 108 | # Create a dictionary containing the image and label lists 109 | self.data_dict = { 110 | 'image': self.image_list, 111 | 'label': self.label_list, 112 | } 113 | 114 | self.transform = self.init_data_aug_method() 115 | 116 | def init_data_aug_method(self): 117 | trans = A.Compose([ 118 | A.HorizontalFlip(p=self.config['data_aug']['flip_prob']), 119 | A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']), 120 | A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']), 121 | A.OneOf([ 122 | IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, 123 | interpolation_up=cv2.INTER_CUBIC), 124 | IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, 125 | interpolation_up=cv2.INTER_LINEAR), 126 | IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, 127 | interpolation_up=cv2.INTER_LINEAR), 128 | ], p=1), 129 | A.OneOf([ 130 | A.RandomBrightnessContrast(brightness_limit=self.config['data_aug']['brightness_limit'], 131 | contrast_limit=self.config['data_aug']['contrast_limit']), 132 | A.FancyPCA(), 133 | A.HueSaturationValue() 134 | ], p=0.5), 135 | A.ImageCompression(quality_lower=self.config['data_aug']['quality_lower'], 136 | quality_upper=self.config['data_aug']['quality_upper'], p=0.5) 137 | ], 138 | keypoint_params=A.KeypointParams(format='xy') if self.config['with_landmark'] else None 139 | ) 140 | return trans 141 | 142 | def collect_img_and_label_for_one_dataset(self, dataset_name: str): 143 | """Collects image and label lists. 144 | 145 | Args: 146 | dataset_name (str): A list containing one dataset information. e.g., 'FF-F2F' 147 | 148 | Returns: 149 | list: A list of image paths. 150 | list: A list of labels. 151 | 152 | Raises: 153 | ValueError: If image paths or labels are not found. 154 | NotImplementedError: If the dataset is not implemented yet. 155 | """ 156 | # Initialize the label and frame path lists 157 | label_list = [] 158 | frame_path_list = [] 159 | 160 | # Try to get the dataset information from the JSON file 161 | try: 162 | with open(os.path.join(self.config['dataset_json_folder'], dataset_name + '.json'), 'r') as f: 163 | dataset_info = json.load(f) 164 | except Exception as e: 165 | print(e) 166 | raise ValueError(f'dataset {dataset_name} not exist!') 167 | 168 | # If JSON file exists, do the following data collection 169 | # FIXME: ugly, need to be modified here. 
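        # A '_c40' suffix in the requested dataset name selects the heavily compressed split: map it back to
        # the base dataset name and remember the compression level in `cp`, which is used below to pick the
        # matching entry of the JSON file.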
170 | cp = None 171 | if dataset_name == 'FaceForensics++_c40': 172 | dataset_name = 'FaceForensics++' 173 | cp = 'c40' 174 | elif dataset_name == 'FF-DF_c40': 175 | dataset_name = 'FF-DF' 176 | cp = 'c40' 177 | elif dataset_name == 'FF-F2F_c40': 178 | dataset_name = 'FF-F2F' 179 | cp = 'c40' 180 | elif dataset_name == 'FF-FS_c40': 181 | dataset_name = 'FF-FS' 182 | cp = 'c40' 183 | elif dataset_name == 'FF-NT_c40': 184 | dataset_name = 'FF-NT' 185 | cp = 'c40' 186 | # Get the information for the current dataset 187 | for label in dataset_info[dataset_name]: 188 | sub_dataset_info = dataset_info[dataset_name][label][self.mode] 189 | # Special case for FaceForensics++ and DeepFakeDetection, choose the compression type 190 | if cp == None and dataset_name in ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++', 191 | 'DeepFakeDetection', 'FaceShifter']: 192 | sub_dataset_info = sub_dataset_info[self.compression] 193 | elif cp == 'c40' and dataset_name in ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++', 194 | 'DeepFakeDetection', 'FaceShifter']: 195 | sub_dataset_info = sub_dataset_info['c40'] 196 | # Iterate over the videos in the dataset 197 | # {"001": {"label": "FF-real", "frames": ["../dataset/FaceFo 198 | # video_name: '001' video_info:{"label": "FF-real", "frames": ["../dataset/FaceFo} 199 | for video_name, video_info in sub_dataset_info.items(): 200 | # Get the label and frame paths for the current video 201 | if video_info['label'] not in self.config['label_dict']: 202 | raise ValueError(f'Label {video_info["label"]} is not found in the configuration file.') 203 | label = self.config['label_dict'][video_info['label']] 204 | frame_paths = video_info['frames'] 205 | 206 | # Select self.frame_num frames evenly distributed throughout the video 207 | total_frames = len(frame_paths) 208 | 209 | if self.frame_num < total_frames: 210 | step = total_frames // self.frame_num 211 | selected_frames = [frame_paths[i] for i in range(0, total_frames, step)][:self.frame_num] 212 | # Append the label and frame paths to the lists according the number of frames 213 | label_list.extend([label] * len(selected_frames)) 214 | frame_path_list.extend(selected_frames) 215 | else: 216 | label_list.extend([label] * total_frames) 217 | frame_path_list.extend(frame_paths) 218 | 219 | # Shuffle the label and frame path lists in the same order 220 | shuffled = list(zip(label_list, frame_path_list)) 221 | random.shuffle(shuffled) 222 | label_list, frame_path_list = zip(*shuffled) 223 | 224 | return frame_path_list, label_list 225 | 226 | def load_rgb(self, file_path): 227 | """ 228 | Load an RGB image from a file path and resize it to a specified resolution. 229 | 230 | Args: 231 | file_path: A string indicating the path to the image file. 232 | 233 | Returns: 234 | An Image object containing the loaded and resized image. 235 | 236 | Raises: 237 | ValueError: If the loaded image is None. 238 | """ 239 | size = self.config['resolution'] 240 | assert os.path.exists(file_path), f"{file_path} does not exist" 241 | img = cv2.imread(file_path) 242 | if img is None: 243 | raise ValueError('Loaded image is None: {}'.format(file_path)) 244 | 245 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 246 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 247 | return Image.fromarray(np.array(img, dtype=np.uint8)) 248 | 249 | def load_mask(self, file_path): 250 | """ 251 | Load a binary mask image from a file path and resize it to a specified resolution. 
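        If the file path is None, the file does not exist, or the image cannot be decoded,
        an all-zero mask of the configured resolution is returned instead of raising an error.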
252 | 253 | Args: 254 | file_path: A string indicating the path to the mask file. 255 | 256 | Returns: 257 | A numpy array containing the loaded and resized mask. 258 | 259 | Raises: 260 | None. 261 | """ 262 | size = self.config['resolution'] 263 | if file_path is None: 264 | return np.zeros((size, size, 1)) 265 | if os.path.exists(file_path): 266 | mask = cv2.imread(file_path, 0) 267 | if mask is None: 268 | mask = np.zeros((size, size)) 269 | mask = cv2.resize(mask, (size, size)) 270 | mask = np.expand_dims(mask, axis=2) 271 | return np.float32(mask) 272 | else: 273 | return np.zeros((size, size, 1)) 274 | 275 | def load_landmark(self, file_path): 276 | """ 277 | Load 2D facial landmarks from a file path. 278 | 279 | Args: 280 | file_path: A string indicating the path to the landmark file. 281 | 282 | Returns: 283 | A numpy array containing the loaded landmarks. 284 | 285 | Raises: 286 | None. 287 | """ 288 | if file_path is None: 289 | return np.zeros((81, 2)) 290 | if os.path.exists(file_path): 291 | landmark = np.load(file_path) 292 | return np.float32(landmark) 293 | else: 294 | return np.zeros((81, 2)) 295 | 296 | def to_tensor(self, img): 297 | """ 298 | Convert an image to a PyTorch tensor. 299 | """ 300 | return T.ToTensor()(img) 301 | 302 | def normalize(self, img): 303 | """ 304 | Normalize an image. 305 | """ 306 | mean = self.config['mean'] 307 | std = self.config['std'] 308 | normalize = T.Normalize(mean=mean, std=std) 309 | return normalize(img) 310 | 311 | def data_aug(self, img, landmark=None, mask=None): 312 | """ 313 | Apply data augmentation to an image, landmark, and mask. 314 | 315 | Args: 316 | img: An Image object containing the image to be augmented. 317 | landmark: A numpy array containing the 2D facial landmarks to be augmented. 318 | mask: A numpy array containing the binary mask to be augmented. 319 | 320 | Returns: 321 | The augmented image, landmark, and mask. 322 | """ 323 | 324 | # Create a dictionary of arguments 325 | kwargs = {'image': img} 326 | 327 | # Check if the landmark and mask are not None 328 | if landmark is not None: 329 | kwargs['keypoints'] = landmark 330 | kwargs['keypoint_params'] = A.KeypointParams(format='xy') 331 | if mask is not None: 332 | kwargs['mask'] = mask 333 | 334 | # Apply data augmentation 335 | transformed = self.transform(**kwargs) 336 | 337 | # Get the augmented image, landmark, and mask 338 | augmented_img = transformed['image'] 339 | augmented_landmark = transformed.get('keypoints') 340 | augmented_mask = transformed.get('mask') 341 | 342 | # Convert the augmented landmark to a numpy array 343 | if augmented_landmark is not None: 344 | augmented_landmark = np.array(augmented_landmark) 345 | 346 | return augmented_img, augmented_landmark, augmented_mask 347 | 348 | def __getitem__(self, index): 349 | """ 350 | Returns the data point at the given index. 351 | 352 | Args: 353 | index (int): The index of the data point. 354 | 355 | Returns: 356 | A tuple containing the image tensor, the label tensor, the landmark tensor, 357 | and the mask tensor. 
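            In this implementation the returned tuple additionally carries the face X-ray map,
            the patch-level labels, the CLIP patch labels, and the patch boundary indicators.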
358 | """ 359 | # Get the image paths and label 360 | image_path = self.data_dict['image'][index] 361 | label = self.data_dict['label'][index] 362 | 363 | # Get the mask and landmark paths 364 | mask_path = image_path.replace('frames', 'masks') # Use .png for mask 365 | landmark_path = image_path.replace('frames', 'landmarks').replace('.png', '.npy') # Use .npy for landmark 366 | 367 | # Load the image 368 | try: 369 | image = self.load_rgb(image_path) 370 | except Exception as e: 371 | # Skip this image and return the first one 372 | print(f"Error loading image at index {index}: {e}") 373 | return self.__getitem__(0) 374 | image = np.array(image) # Convert to numpy array for data augmentation 375 | 376 | # Load mask and landmark (if needed) 377 | if self.config['with_mask']: 378 | mask = self.load_mask(mask_path) 379 | else: 380 | mask = None 381 | if self.config['with_landmark']: 382 | landmarks = self.load_landmark(landmark_path) 383 | else: 384 | landmarks = None 385 | 386 | # Do Data Augmentation 387 | if self.mode == 'train' and self.config['use_data_augmentation']: 388 | image_trans, landmarks_trans, mask_trans = self.data_aug(image, landmarks, mask) 389 | else: 390 | image_trans, landmarks_trans, mask_trans, xray_trans, patch_label_trans,clip_patch_label_trans = deepcopy(image), deepcopy( 391 | landmarks), deepcopy(mask), deepcopy(mask), deepcopy(mask), deepcopy(mask) 392 | 393 | # To tensor and normalize 394 | image_trans = self.normalize(self.to_tensor(image_trans)) 395 | #image_trans = self.to_tensor(image_trans) 396 | if self.config['with_landmark']: 397 | landmarks_trans = torch.from_numpy(landmarks) 398 | 399 | if self.config['with_xray']: 400 | boundary = get_boundary(xray_trans) 401 | #boundary = get_mask(xray_trans) 402 | boundary = torch.from_numpy(boundary) 403 | boundary = boundary.unsqueeze(2).permute(2, 0, 1) 404 | 405 | if label == 0: # real 406 | xray_trans = torch.zeros_like(boundary) 407 | else: # fake 408 | xray_trans = boundary 409 | else: 410 | xray_trans = None 411 | 412 | if self.config['with_mask']: 413 | mask_trans = torch.from_numpy(mask_trans) 414 | if self.config['with_patch_labels']: 415 | patch_label_trans, if_boundaries = split_images_by_patch(patch_label_trans.squeeze(), 16, need_boundary=True) 416 | patch_label_trans = torch.tensor(patch_label_trans) 417 | if_boundaries_trans = torch.tensor(if_boundaries) 418 | 419 | clip_patch_label_trans = split_images_by_patch(clip_patch_label_trans.squeeze(), 14, mode='resize') 420 | clip_patch_label_trans = torch.tensor(clip_patch_label_trans) 421 | 422 | else: 423 | patch_label_trans = None 424 | clip_patch_label_trans =None 425 | if_boundaries_trans = None 426 | 427 | 428 | return image_trans, label, landmarks_trans, mask_trans, xray_trans, patch_label_trans,clip_patch_label_trans, if_boundaries_trans 429 | 430 | @staticmethod 431 | def collate_fn(batch): 432 | """ 433 | Collate a batch of data points. 434 | 435 | Args: 436 | batch (list): A list of tuples containing the image tensor, the label tensor, 437 | the landmark tensor, and the mask tensor. 438 | 439 | Returns: 440 | A tuple containing the image tensor, the label tensor, the landmark tensor, 441 | and the mask tensor. 
442 | """ 443 | # Separate the image, label, landmark, and mask tensors 444 | images, labels, landmarks, masks, xrays, patch_labels, clip_patch_labels, if_boundaries = zip(*batch) 445 | 446 | # Stack the image, label, landmark, and mask tensors 447 | images = torch.stack(images, dim=0) 448 | labels = torch.LongTensor(labels) 449 | 450 | # Special case for landmarks and masks if they are None 451 | if landmarks[0] is not None: 452 | landmarks = torch.stack(landmarks, dim=0) 453 | else: 454 | landmarks = None 455 | 456 | if masks[0] is not None: 457 | masks = torch.stack(masks, dim=0) 458 | else: 459 | masks = None 460 | 461 | if xrays[0] is not None: 462 | xrays = torch.stack(xrays, dim=0) 463 | else: 464 | xrays = None 465 | 466 | if patch_labels[0] is not None: 467 | patch_labels = torch.stack(patch_labels, dim=0) 468 | clip_patch_labels = torch.stack(clip_patch_labels, dim=0) 469 | if_boundaries = torch.stack(if_boundaries, dim=0) 470 | 471 | else: 472 | patch_labels = None 473 | clip_patch_labels = None 474 | if_boundaries = None 475 | 476 | # Create a dictionary of the tensors 477 | data_dict = {} 478 | data_dict['image'] = images 479 | data_dict['label'] = labels 480 | data_dict['landmark'] = landmarks 481 | data_dict['mask'] = masks 482 | data_dict['xray'] = xrays 483 | data_dict['patch_label'] = patch_labels 484 | data_dict['clip_patch_label'] = clip_patch_labels 485 | data_dict['if_boundary'] = if_boundaries 486 | 487 | 488 | return data_dict 489 | 490 | def __len__(self): 491 | """ 492 | Return the length of the dataset. 493 | 494 | Args: 495 | None. 496 | 497 | Returns: 498 | An integer indicating the length of the dataset. 499 | 500 | Raises: 501 | AssertionError: If the number of images and labels in the dataset are not equal. 502 | """ 503 | assert len(self.image_list) == len(self.label_list), 'Number of images and labels are not equal' 504 | return len(self.image_list) 505 | 506 | 507 | if __name__ == "__main__": 508 | 509 | with open('', 'r') as f: 510 | config = yaml.safe_load(f) 511 | train_set = DeepfakeAbstractBaseDataset( 512 | config=config, 513 | mode='train', 514 | ) 515 | train_data_loader = \ 516 | torch.utils.data.DataLoader( 517 | dataset=train_set, 518 | batch_size=config['train_batchSize'], 519 | shuffle=True, 520 | num_workers=int(config['workers']), 521 | collate_fn=train_set.collate_fn, 522 | ) 523 | from tqdm import tqdm 524 | 525 | import matplotlib.pyplot as plt 526 | 527 | torch.set_printoptions(threshold=torch.inf) 528 | batch = next(iter(train_data_loader)) 529 | masks = batch['mask'] 530 | for index, mask in enumerate(masks): 531 | mask = mask.squeeze().numpy() 532 | if batch['label'][index] == 1: 533 | plt.imshow(mask, cmap='gray') 534 | plt.axis('off') 535 | plt.show() 536 | -------------------------------------------------------------------------------- /dataset/albu.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | from albumentations import DualTransform, ImageOnlyTransform 6 | from albumentations.augmentations.crops.functional import crop 7 | 8 | 9 | def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC): 10 | h, w = img.shape[:2] 11 | if max(w, h) == size: 12 | return img 13 | if w > h: 14 | scale = size / w 15 | h = h * scale 16 | w = size 17 | else: 18 | scale = size / h 19 | w = w * scale 20 | h = size 21 | interpolation = interpolation_up if scale > 1 else interpolation_down 22 | resized 
= cv2.resize(img, (int(w), int(h)), interpolation=interpolation) 23 | return resized 24 | 25 | 26 | class IsotropicResize(DualTransform): 27 | def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, 28 | always_apply=False, p=1): 29 | super(IsotropicResize, self).__init__(always_apply, p) 30 | self.max_side = max_side 31 | self.interpolation_down = interpolation_down 32 | self.interpolation_up = interpolation_up 33 | 34 | def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params): 35 | return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down, 36 | interpolation_up=interpolation_up) 37 | 38 | def apply_to_mask(self, img, **params): 39 | return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params) 40 | 41 | def get_transform_init_args_names(self): 42 | return ("max_side", "interpolation_down", "interpolation_up") 43 | 44 | 45 | class Resize4xAndBack(ImageOnlyTransform): 46 | def __init__(self, always_apply=False, p=0.5): 47 | super(Resize4xAndBack, self).__init__(always_apply, p) 48 | 49 | def apply(self, img, **params): 50 | h, w = img.shape[:2] 51 | scale = random.choice([2, 4]) 52 | img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA) 53 | img = cv2.resize(img, (w, h), 54 | interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST])) 55 | return img 56 | 57 | 58 | class RandomSizedCropNonEmptyMaskIfExists(DualTransform): 59 | 60 | def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5): 61 | super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p) 62 | 63 | self.min_max_height = min_max_height 64 | self.w2h_ratio = w2h_ratio 65 | 66 | def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params): 67 | cropped = crop(img, x_min, y_min, x_max, y_max) 68 | return cropped 69 | 70 | @property 71 | def targets_as_params(self): 72 | return ["mask"] 73 | 74 | def get_params_dependent_on_targets(self, params): 75 | mask = params["mask"] 76 | mask_height, mask_width = mask.shape[:2] 77 | crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1])) 78 | w2h_ratio = random.uniform(*self.w2h_ratio) 79 | crop_width = min(int(crop_height * w2h_ratio), mask_width - 1) 80 | if mask.sum() == 0: 81 | x_min = random.randint(0, mask_width - crop_width + 1) 82 | y_min = random.randint(0, mask_height - crop_height + 1) 83 | else: 84 | mask = mask.sum(axis=-1) if mask.ndim == 3 else mask 85 | non_zero_yx = np.argwhere(mask) 86 | y, x = random.choice(non_zero_yx) 87 | x_min = x - random.randint(0, crop_width - 1) 88 | y_min = y - random.randint(0, crop_height - 1) 89 | x_min = np.clip(x_min, 0, mask_width - crop_width) 90 | y_min = np.clip(y_min, 0, mask_height - crop_height) 91 | 92 | x_max = x_min + crop_height 93 | y_max = y_min + crop_width 94 | y_max = min(mask_height, y_max) 95 | x_max = min(mask_width, x_max) 96 | return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max} 97 | 98 | def get_transform_init_args_names(self): 99 | return "min_max_height", "height", "width", "w2h_ratio" -------------------------------------------------------------------------------- /figures/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OUC-VAS/ForensicsAdapter/b510b1db9fa57531e269ff4cfb6654444607592a/figures/archi.png 
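Referring back to `dataset/albu.py` above: `IsotropicResize` plugs into an `albumentations` pipeline like any built-in transform, which is how `dataset/abstract_dataset.py` uses it. A minimal usage sketch, assuming the repository root is on `PYTHONPATH` (the image path and target size are placeholders):

```python
import cv2
import albumentations as A
from dataset.albu import IsotropicResize

# Scale the longer side to 256 while preserving the aspect ratio; the interpolation
# mode is chosen depending on whether the image is shrunk or enlarged.
transform = A.Compose([
    IsotropicResize(max_side=256, interpolation_down=cv2.INTER_AREA,
                    interpolation_up=cv2.INTER_CUBIC),
])

img = cv2.imread("face.png")            # hypothetical input image
out = transform(image=img)["image"]     # longer side of `out` is now 256
```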
-------------------------------------------------------------------------------- /figures/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OUC-VAS/ForensicsAdapter/b510b1db9fa57531e269ff4cfb6654444607592a/figures/structure.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install numpy==1.21.5 4 | pip install pandas==1.4.2 5 | pip install Pillow==9.0.1 6 | pip install dlib==19.24.0 7 | pip install imageio==2.9.0 8 | pip install imgaug==0.4.0 9 | pip install tqdm==4.61.0 10 | pip install scipy==1.7.3 11 | pip install seaborn==0.11.2 12 | pip install pyyaml==6.0 13 | pip install imutils==0.5.4 14 | pip install opencv-python==4.6.0.66 15 | pip install scikit-image==0.19.2 16 | pip install scikit-learn==1.0.2 17 | pip install albumentations==1.1.0 18 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 19 | pip install efficientnet-pytorch==0.7.1 20 | pip install timm==0.6.12 21 | pip install segmentation-models-pytorch==0.3.2 22 | pip install torchtoolbox==0.1.8.2 23 | pip install tensorboard==2.10.1 24 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch.distributed as dist 5 | 6 | class RankFilter(logging.Filter): 7 | def __init__(self, rank): 8 | super().__init__() 9 | self.rank = rank 10 | 11 | def filter(self, record): 12 | return dist.get_rank() == self.rank 13 | 14 | def create_logger(log_path): 15 | # Create log path 16 | if os.path.isdir(os.path.dirname(log_path)): 17 | os.makedirs(os.path.dirname(log_path), exist_ok=True) 18 | 19 | # Create logger object 20 | logger = logging.getLogger() 21 | logger.setLevel(logging.INFO) 22 | # Create file handler and set the formatter 23 | fh = logging.FileHandler(log_path) 24 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 25 | fh.setFormatter(formatter) 26 | 27 | # Add the file handler to the logger 28 | logger.addHandler(fh) 29 | 30 | # Add a stream handler to print to console 31 | sh = logging.StreamHandler() 32 | sh.setLevel(logging.INFO) # Set logging level for stream handler 33 | sh.setFormatter(formatter) 34 | logger.addHandler(sh) 35 | 36 | return logger -------------------------------------------------------------------------------- /model/adapters/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm import create_model 4 | from torch.nn import functional as F 5 | 6 | from model.layer import Fusion, MLP, PatchEmbed, VT_LN 7 | from functools import partial 8 | 9 | 10 | class Mask_Decoder(nn.Module): 11 | def __init__(self, in_dim, mlp_dim=512, out_dim=256, mlp_num_layers=3, head_num=16): 12 | super().__init__() 13 | self.in_dim = in_dim 14 | self.mlp_dim = mlp_dim 15 | self.out_dim = out_dim 16 | self.head_num = head_num 17 | dense_affine_func = partial(nn.Conv2d, kernel_size=1) 18 | self.query_mlp = MLP(in_dim, mlp_dim, out_dim, mlp_num_layers) # L R L R L 19 | self.xray_mlp = MLP(in_dim, mlp_dim, out_dim, mlp_num_layers, affine_func=dense_affine_func) 20 | self.attn_mlp = MLP(in_dim, mlp_dim, out_dim * self.head_num, 
mlp_num_layers, affine_func=dense_affine_func) 21 | self.bias_scaling = nn.Linear(1, 1) 22 | 23 | def forward(self, query, x): 24 | # query (N,QL,D) x (N D H W) 25 | query = self.query_mlp(query) 26 | xray = self.xray_mlp(x) 27 | attn = self.attn_mlp(x) 28 | patch_x = x.reshape(x.shape[0], x.shape[1], -1) #(N D L) 29 | patch_x = patch_x.permute(0, 2, 1) #(N L D) 30 | xray_pred = torch.einsum('NQD,NDhw->NQhw', query, xray) 31 | n, d, h, w = xray.shape 32 | attn = attn.reshape(n, self.head_num, d, h, w) # (N Head*D,h,w)->(N Head D h w) 33 | attn_bias = torch.einsum('NQD,NHDhw->NHQhw', query, attn) 34 | attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1) 35 | return xray_pred, attn_bias 36 | 37 | 38 | class Adapter(nn.Module): 39 | def __init__(self, vit_name, num_quires, fusion_map, mlp_dim, mlp_out_dim, head_num, device): 40 | super().__init__() 41 | self.device = device 42 | self.vit_model = create_model(vit_name, 43 | pretrained=False, 44 | fc_norm=False, 45 | num_classes=0, 46 | embed_layer=PatchEmbed, 47 | ).to(device=self.device) 48 | 49 | if self.vit_model.cls_token is not None: 50 | self.vit_model.pos_embed = nn.Parameter(self.vit_model.pos_embed[:, 1:, ...]) # 去掉cls的位置 51 | del self.vit_model.cls_token 52 | self.vit_model.cls_token = None 53 | del self.vit_model.norm 54 | self.vit_model.norm = nn.Identity() 55 | self.num_quires = num_quires 56 | self.num_features = self.vit_model.num_features 57 | self.query_embed = nn.Parameter(torch.zeros(1, self.num_quires, self.num_features)) # (1,Q_L,D) 58 | self.query_pos_embed = nn.Parameter(torch.zeros(1, self.num_quires, self.num_features)) 59 | self.fusion_map = fusion_map 60 | nn.init.normal_(self.query_embed, std=0.02) 61 | nn.init.normal_(self.query_pos_embed, std=0.02) 62 | self.mask_decoder = Mask_Decoder(in_dim=self.num_features, mlp_dim=mlp_dim, out_dim=mlp_out_dim, 63 | mlp_num_layers=3, head_num=head_num) 64 | self.ln_pre = VT_LN(self.num_features) 65 | self.patch_conv = nn.Conv2d(in_channels=3, out_channels=self.num_features, kernel_size=16, stride=16, 66 | bias=True) 67 | 68 | def fuse(self, block_idx, x, clip_features, spatial_shape): 69 | if block_idx in self.fusion_map.keys(): 70 | clip_layer = self.fusion_map[block_idx] 71 | adapter_layer = block_idx 72 | clip_dim = clip_features[clip_layer].shape[2] # clip features NLD 73 | 74 | fusion = Fusion(clip_dim, self.num_features).to(self.device) 75 | L = spatial_shape[0] * spatial_shape[1] 76 | x = torch.cat( 77 | [ 78 | x[:, :-L, ...], # query 79 | # fuse vision token x(N,a_L,D) clip_f[i] (N,c_L,D) 80 | fusion(x[:, -L:, ...], clip_features[clip_layer], spatial_shape) 81 | ], 82 | dim=1 83 | ) 84 | return x 85 | 86 | def intra_contra(self, block_idx, x, patch_labels, image_labels, spatial_shape): 87 | loss = 0 88 | if block_idx == 4 or block_idx == 5 or block_idx == 6 : 89 | L = spatial_shape[0] * spatial_shape[1] # 14 * 14 = 196 90 | 91 | embeddings = x[:, -L:, ...] 
# (N 196 192) (NLD) 92 | embeddings = nn.functional.normalize(embeddings, dim=-1) 93 | fake_index = (image_labels == 1).nonzero(as_tuple=True)[0] 94 | fake_embeddings = embeddings[fake_index] # f_N L D 95 | fake_nums = len(fake_index) 96 | 97 | ff_patch_labels = patch_labels[fake_index] # f_N L 98 | 99 | 100 | fr_patch_labels = torch.logical_not(ff_patch_labels) 101 | 102 | f_fake_part = fake_embeddings * ff_patch_labels.unsqueeze(-1) # f_N L D * f_N L 1 -> f_N L D 103 | f_real_part = fake_embeddings * fr_patch_labels.unsqueeze(-1) # f_N L D * f_N L 1 -> f_N L D 104 | 105 | negative = torch.bmm(f_fake_part, f_real_part.permute(0, 2, 1)) / 0.5 # f_N L L 106 | positive1 = torch.bmm(f_real_part, f_real_part.permute(0, 2, 1)) / 0.5 # f_N L L 107 | positive2 = torch.bmm(f_fake_part, f_fake_part.permute(0, 2, 1)) / 0.5 # f_N L L 108 | 109 | l_neg = torch.sum(torch.exp(negative)) 110 | l_pos1 = torch.sum(torch.exp(positive1)) 111 | l_pos2 = torch.sum(torch.exp(positive2)) 112 | 113 | loss_real_intra = -torch.log(l_pos1 / (l_neg + l_pos1)) 114 | loss_fake_intra = -torch.log(l_pos2 / (l_neg + l_pos2)) 115 | 116 | 117 | #loss = loss_real_intra + loss_fake_intra 118 | loss = loss_real_intra 119 | 120 | return loss 121 | 122 | def forward(self, data_dict, clip_features, inference): 123 | 124 | image = data_dict['image'] 125 | x = self.patch_conv(image) 126 | x = x.reshape(x.shape[0], x.shape[1], -1) 127 | x = x.permute(0, 2, 1) 128 | 129 | pos_embed = self.vit_model.pos_embed # (N L D) 130 | pos_embed = pos_embed.permute(0, 2, 1) #(NDL) 131 | pos_embed = F.interpolate(pos_embed.reshape(pos_embed.shape[0],pos_embed.shape[1], 14, 14), 132 | size=(16, 16), 133 | mode='bilinear', 134 | align_corners=False).reshape(pos_embed.shape[0],pos_embed.shape[1],256).permute(0,2,1) #NDL->NLD 135 | pos_embed = torch.cat([self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed], 136 | dim=1) 137 | v_L = x.shape[1] # vision token L 196 138 | (h, w) = 16,16 # h w 16,16 139 | 140 | x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1) # (N ,Q_L+L,D) 141 | x = x + pos_embed 142 | x = self.ln_pre(x) 143 | outs = [] 144 | out_layers = [8] 145 | loss_intra = 0 146 | # self.fuse(0, x, clip_features, (h, w)) 147 | for i, block in enumerate(self.vit_model.blocks, start=1): # total 1-12 ,only use 1-8 148 | x = block(x) # (N, Q_L+L, D) 149 | self.fuse(i, x, clip_features, (h, w)) 150 | if not inference: #train 151 | loss_tmp_intra = self.intra_contra(i, x, data_dict['patch_label'], data_dict['label'], (h, w)) 152 | loss_intra = loss_tmp_intra + loss_intra if loss_tmp_intra != 0 else loss_intra 153 | if i in out_layers: 154 | n, _, d = x.shape 155 | outs.append( 156 | { 157 | 'query': x[:, :-v_L, ...], 158 | 'x': x[:, -v_L:, ...] 
159 | .permute(0, 2, 1) 160 | .reshape(n, d, h, w), 161 | } 162 | ) 163 | x = x + pos_embed 164 | if i == max(out_layers): 165 | break 166 | xray_preds = [] 167 | attn_biases = [] 168 | 169 | for feature in outs: 170 | xray_pred, attn_bias = self.mask_decoder(feature['query'], feature['x']) 171 | xray_preds.append(xray_pred) 172 | attn_biases.append(attn_bias) 173 | 174 | return attn_biases, xray_preds, loss_intra 175 | -------------------------------------------------------------------------------- /model/attn.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from model.layer import MLP 8 | 9 | class ClipIntraBlock(nn.Module): 10 | def __init__(self,num_features): 11 | super().__init__() 12 | self.num_features = num_features 13 | self.conv_first =nn.Conv1d(in_channels=self.num_features, out_channels=192, kernel_size=1) 14 | self.relu = nn.ReLU() 15 | self.conv_second =nn.Conv1d(in_channels=192, out_channels=self.num_features, kernel_size=1) 16 | 17 | def forward(self, x, data_dict, clip_L, inference): 18 | 19 | intra_x = x.permute(1, 2, 0) # LND -> NDL 20 | intra_x = intra_x[:, :, -clip_L:].float() # N clipD 256/196 21 | 22 | intra_x = self.conv_first(intra_x) # N D 256 23 | intra_x = self.relu(intra_x) 24 | intra_x = intra_x.permute(0, 2, 1) # NDL-> NLD 25 | if not inference: 26 | loss_clip = self.intra_contra(intra_x, data_dict['clip_patch_label'], data_dict['label'], (16, 16)) 27 | else: 28 | loss_clip = 0 29 | intra_x = intra_x.permute(0, 2, 1) # NLD-> NDL 30 | intra_x = self.conv_second(intra_x) # NDL 31 | # intra_x = self.relu(intra_x) 32 | 33 | intra_x = intra_x.permute(2, 0, 1) # NDL- >LND 34 | # x LND 35 | # x[-clip_L:,...] = intra_x * 0.1 + x[-clip_L:, ...] * self.intra_scale 在B14 效果不错 36 | # x[-clip_L:, ...] = intra_x * self.intra_scale + x[-clip_L:, ...] * 0.9 37 | #x[-clip_L:, ...] = intra_x * 0.15 + x[-clip_L:, ...] * 0.95 38 | return intra_x, loss_clip 39 | 40 | def intra_contra(self, x, patch_labels, image_labels, spatial_shape): 41 | if True: 42 | L = spatial_shape[0] * spatial_shape[1] # 14 * 14 = 196 16 * 16=256 43 | 44 | embeddings = x[:, -L:, ...] 
# (N 196 192) (NLD) (N 256 128) 45 | embeddings = nn.functional.normalize(embeddings, dim=-1) 46 | fake_index = (image_labels == 1).nonzero(as_tuple=True)[0] 47 | fake_embeddings = embeddings[fake_index] # f_N L D 48 | 49 | ff_patch_labels = patch_labels[fake_index] # f_N L 50 | fr_patch_labels = torch.logical_not(ff_patch_labels) 51 | 52 | f_fake_part = fake_embeddings * ff_patch_labels.unsqueeze(-1) # f_N L D * f_N L 1 -> f_N L D 53 | f_real_part = fake_embeddings * fr_patch_labels.unsqueeze(-1) # f_N L D * f_N L 1 -> f_N L D 54 | 55 | negative = torch.bmm(f_fake_part, f_real_part.permute(0, 2, 1)) / 0.5 # f_N L L 56 | positive1 = torch.bmm(f_real_part, f_real_part.permute(0, 2, 1)) / 0.5 # f_N L L 57 | l_neg = torch.sum(torch.exp(negative)) 58 | 59 | l_pos1 = torch.sum(torch.exp(positive1)) 60 | loss_real_intra = -torch.log(l_pos1 / (l_neg + l_pos1)) 61 | #---------------fake---------------------- 62 | positive2 = torch.bmm(f_fake_part, f_fake_part.permute(0, 2, 1)) / 0.5 # f_N L L 63 | l_pos2 = torch.sum(torch.exp(positive2)) 64 | loss_fake_intra = -torch.log(l_pos2 / (l_neg + l_pos2)) 65 | #loss_intra = loss_fake_intra + loss_real_intra 66 | loss_intra = loss_real_intra 67 | 68 | real_index = (image_labels == 0).nonzero(as_tuple=True)[0] 69 | real_nums = len(real_index) 70 | fake_nums = len(fake_index) 71 | 72 | if real_nums != 0 and fake_nums != 0 : 73 | real_embeddings = embeddings[real_index] # N r_L D 74 | if fake_nums >= real_nums: 75 | random_fake_index = torch.randperm(fake_nums)[:real_nums] 76 | random_fake_embeddings = fake_embeddings[random_fake_index] 77 | real_neg = torch.bmm(real_embeddings, random_fake_embeddings.permute(0, 2, 1)) / 0.5 78 | real_pos = torch.bmm(real_embeddings, real_embeddings.permute(0, 2, 1)) / 0.5 79 | 80 | else: 81 | random_real_index = torch.randperm(real_nums)[:fake_nums] 82 | random_real_embeddings = real_embeddings[random_real_index] 83 | real_neg = torch.bmm(random_real_embeddings, fake_embeddings.permute(0, 2, 1)) / 0.5 84 | real_pos = torch.bmm(random_real_embeddings, random_real_embeddings.permute(0, 2, 1)) / 0.5 85 | loss_real_neg = torch.sum(torch.exp(real_neg)) 86 | loss_real_pos = torch.sum(torch.exp(real_pos)) 87 | loss_inter = -torch.log(loss_real_pos / (loss_real_pos + loss_real_neg)) 88 | loss_clip = loss_inter + loss_intra 89 | 90 | else: 91 | loss_clip = loss_intra 92 | 93 | return loss_clip 94 | 95 | 96 | class RecAttnClip(nn.Module): 97 | def __init__(self, vit, num_quires, device): 98 | super().__init__() 99 | self.vit = vit 100 | self.resblocks = self.vit.transformer.resblocks 101 | self.first_layer = 0 102 | self.clss_nums = num_quires 103 | self.ln_post = self.vit.ln_post 104 | self.proj = self.vit.proj 105 | self.num_features = self.vit.width 106 | self.device = device 107 | self.intra_scale = nn.Parameter(torch.zeros(1)) 108 | self.intra_map = {6:0} 109 | self.clip_intra_blocks = nn.ModuleList([ClipIntraBlock(self.num_features).to(self.device) for _ in range(1)]) 110 | self._freeze() 111 | 112 | def build_attn_mask(self, attn_bias): 113 | 114 | num_heads = self.resblocks[0].attn.num_heads 115 | n, Head, q, h, w = attn_bias.shape 116 | 117 | assert ( 118 | Head == num_heads 119 | ), f"num_head={Head} is not supported. 
Modify to {num_heads}" 120 | attn_bias = attn_bias.reshape(n * Head, q, -1) 121 | l = attn_bias.shape[-1] 122 | attn_mask = attn_bias.new_zeros(q + 1 + l, q + 1 + l) 123 | attn_mask[:, :q] = -100 124 | attn_mask[torch.arange(q), torch.arange(q)] = 0 125 | attn_mask[:q, q] = -100 126 | attn_mask = attn_mask[None, ...].expand( 127 | n * Head, -1, -1 128 | ).clone() 129 | attn_mask[:, :q, -l:] = attn_bias 130 | # attn_mask (n*head,1+q+l,1+q+l) 131 | attn_biases = [attn_mask for _ in self.resblocks.children()] 132 | return attn_biases 133 | 134 | def _freeze(self): 135 | for name, param in self.named_parameters(): 136 | if 'clip_intra_blocks' in name : 137 | param.requires_grad = True 138 | else: 139 | param.requires_grad = False 140 | 141 | 142 | def forward(self, data_dict, clip_features, attn_bias,inference=False, normalize=False): 143 | cls_token = clip_features[f'layer_{self.first_layer}_cls'].unsqueeze(1).permute(1, 0, 2).clone() # ND->N1D->1ND 144 | vision_tokens = clip_features[self.first_layer].permute(1, 0, 2).clone() # NLD->LND 145 | clss_token = cls_token.repeat(self.clss_nums, 1, 1) # 1ND -> clss_nums, N,D 146 | 147 | x = torch.cat( 148 | [ 149 | clss_token, 150 | cls_token, 151 | vision_tokens 152 | ], 153 | dim=0 154 | ) # (1+Q+L,N,D) 155 | x.requires_grad = True 156 | clip_L = vision_tokens.shape[0] # 157 | 158 | attn_biases = self.build_attn_mask(attn_bias) 159 | 160 | loss_clip = 0 161 | for i, blocks in enumerate(self.resblocks.children()): 162 | 163 | x = blocks(x, attn_biases[i]) 164 | if i == 6: 165 | intra_x, loss_clip_tmp = self.clip_intra_blocks[self.intra_map[i]](x, data_dict, clip_L, inference) 166 | loss_clip = loss_clip_tmp + loss_clip 167 | x[-clip_L:, ...] = intra_x * 0.05 + x[-clip_L:, ...] 168 | 169 | x = x.permute(1, 0, 2) # LND -> NLD 170 | clss_token = x[:, :self.clss_nums, :] 171 | clss_token = self.ln_post(clss_token) 172 | if self.proj is not None: 173 | clss_token = clss_token @ self.proj 174 | if normalize: 175 | clss_token = F.normalize(clss_token, dim=-1) 176 | 177 | return clss_token, loss_clip 178 | -------------------------------------------------------------------------------- /model/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /model/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OUC-VAS/ForensicsAdapter/b510b1db9fa57531e269ff4cfb6654444607592a/model/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /model/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher 
is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 
102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 加载已训练好的state_dict权重到模型中 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | # build model 140 | model = build_model(state_dict or model.state_dict()).to(device) 141 | if str(device) == "cpu": 142 | model.float() 143 | return model, _transform(model.visual.input_resolution) 144 | 145 | # patch the device names 146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 148 | 149 | def patch_device(module): 150 | try: 151 | graphs = [module.graph] if hasattr(module, "graph") else [] 152 | except RuntimeError: 153 | graphs = [] 154 | 155 | if hasattr(module, "forward1"): 156 | graphs.append(module.forward1.graph) 157 | 158 | for graph in graphs: 159 | for node in graph.findAllNodes("prim::Constant"): 160 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 161 | node.copyAttributes(device_node) 162 | 163 | model.apply(patch_device) 164 | patch_device(model.encode_image) 165 | patch_device(model.encode_text) 166 | 167 | # patch dtype to float32 on CPU 168 | if str(device) == "cpu": 169 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 170 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 171 | float_node = float_input.node() 172 | 173 | def patch_float(module): 174 | try: 175 | graphs = [module.graph] if hasattr(module, "graph") else [] 176 | except RuntimeError: 177 | graphs = [] 178 | 179 | if hasattr(module, "forward1"): 180 | graphs.append(module.forward1.graph) 181 | 182 | for graph in graphs: 183 | for node in graph.findAllNodes("aten::to"): 184 | inputs = list(node.inputs()) 185 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 186 | if inputs[i].node()["value"] == 5: 187 | inputs[i].node().copyAttributes(float_node) 188 | 189 | model.apply(patch_float) 190 | patch_float(model.encode_image) 191 | patch_float(model.encode_text) 192 | 193 | model.float() 194 | 195 | return model, _transform(model.input_resolution.item()) 196 | 197 | 198 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, 
torch.LongTensor]: 199 | """ 200 | Returns the tokenized representation of given input string(s) 201 | 202 | Parameters 203 | ---------- 204 | texts : Union[str, List[str]] 205 | An input string or a list of input strings to tokenize 206 | 207 | context_length : int 208 | The context length to use; all CLIP models use 77 as the context length 209 | 210 | truncate: bool 211 | Whether to truncate the text in case its encoding is longer than the context length 212 | 213 | Returns 214 | ------- 215 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 216 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 217 | """ 218 | if isinstance(texts, str): 219 | texts = [texts] 220 | 221 | sot_token = _tokenizer.encoder["<|startoftext|>"] 222 | eot_token = _tokenizer.encoder["<|endoftext|>"] 223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 224 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 225 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 226 | else: 227 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 228 | 229 | for i, tokens in enumerate(all_tokens): 230 | if len(tokens) > context_length: 231 | if truncate: 232 | tokens = tokens[:context_length] 233 | tokens[-1] = eot_token 234 | else: 235 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 236 | result[i, :len(tokens)] = torch.tensor(tokens) 237 | 238 | return result 239 | -------------------------------------------------------------------------------- /model/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
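# Shape sketch for AttentionPool2d above (illustrative sizes only; the values assume the
# RN50 configuration with spacial_dim=7, embed_dim=2048, output_dim=1024):
#   x: (N, C, H, W) = (2, 2048, 7, 7)
#   flatten + permute            -> (H*W, N, C) = (49, 2, 2048)
#   concat mean token            -> (50, 2, 2048), then add positional_embedding
#   attention with query x[:1]   -> (1, 2, 1024), squeeze(0) -> (2, 1024) pooled embedding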
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | class ResidualAttentionBlock(nn.Module): 171 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 172 | super().__init__() 173 | 174 | self.attn = nn.MultiheadAttention(d_model, n_head) 175 | 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | # self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor, attn_mask): 186 | attn_mask = attn_mask.to(dtype=x.dtype, device=x.device) if attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor, attn_mask=None): 190 | x = x + self.attention(self.ln_1(x), attn_mask) 191 | x 
= x + self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class VTransformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | print(f'width:{width} layers {layers} ') 203 | 204 | def forward(self, x: torch.Tensor, extract=None): 205 | out = {} 206 | features = {} 207 | if extract is not None: 208 | features['layer_0_cls'] = x[0] # ND 209 | features[0] = x[1:].permute(1, 0, 2) # LND -> NLD 210 | for idx, layer in enumerate(self.resblocks.children(), start=1): 211 | x = layer(x) 212 | out['layer' + str(idx)] = x[0] 213 | if extract is not None and idx in extract: 214 | features[f'layer_{idx}_cls'] = x[0] # ND 215 | features[idx] = x[1:].permute(1, 0, 2) # NLD 216 | if idx == max(extract): 217 | return features 218 | return out, x 219 | 220 | # return self.resblocks(x) # This is the original code 221 | 222 | 223 | class Transformer(nn.Module): 224 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 225 | super().__init__() 226 | self.width = width 227 | self.layers = layers 228 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 229 | 230 | def forward(self, x: torch.Tensor): 231 | #out = {} 232 | #for idx, layer in enumerate(self.resblocks.children()): 233 | # x = layer(x) 234 | # out['layer' + str(idx)] = x[0] # shape:LND. choose cls token feature 选择 [CLS] 标记的特征 235 | #return out, x 236 | 237 | return self.resblocks(x) # This is the original code 238 | 239 | 240 | class VisionTransformer(nn.Module): 241 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 242 | super().__init__() 243 | self.input_resolution = input_resolution 244 | self.output_dim = output_dim 245 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 246 | scale = width ** -0.5 247 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 248 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 249 | self.ln_pre = LayerNorm(width) 250 | 251 | self.transformer = VTransformer(width, layers, heads) 252 | self.width=width 253 | self.ln_post = LayerNorm(width) 254 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 255 | self.extract = False 256 | 257 | def forward(self, x: torch.Tensor, extract=None): 258 | x = self.conv1(x) 259 | x = x.reshape(x.shape[0], x.shape[1], -1) 260 | x = x.permute(0, 2, 1) 261 | x = torch.cat( 262 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 263 | x], dim=1) # 264 | x = x + self.positional_embedding.to(x.dtype) 265 | x = self.ln_pre(x) # layer_norm 266 | 267 | x = x.permute(1, 0, 2) # NLD -> LND 268 | if extract is not None: 269 | feat = self.transformer(x, extract) 270 | return feat # NLD 271 | else: 272 | out, x = self.transformer(x) 273 | x = x.permute(1, 0, 2) # LND -> NLD 274 | x = self.ln_post(x[:, 0, :]) 275 | out['before_projection'] = x 276 | 277 | if self.proj is not None: 278 | x = x @ self.proj 279 | out['after_projection'] = x 280 | # Return both intermediate features and final clip feature 281 | # return out 282 | 283 | # This only returns CLIP features 284 | return 
x 285 | 286 | 287 | class CLIP(nn.Module): 288 | def __init__(self, 289 | embed_dim: int, 290 | # vision 291 | image_resolution: int, 292 | vision_layers: Union[Tuple[int, int, int, int], int], 293 | vision_width: int, 294 | vision_patch_size: int, 295 | # text 296 | context_length: int, 297 | vocab_size: int, 298 | transformer_width: int, 299 | transformer_heads: int, 300 | transformer_layers: int 301 | ): 302 | super().__init__() 303 | 304 | self.context_length = context_length 305 | # Resnet 306 | if isinstance(vision_layers, (tuple, list)): 307 | vision_heads = vision_width * 32 // 64 308 | self.visual = ModifiedResNet( 309 | layers=vision_layers, 310 | output_dim=embed_dim, 311 | heads=vision_heads, 312 | input_resolution=image_resolution, 313 | width=vision_width 314 | ) 315 | # VIT 316 | else: 317 | vision_heads = vision_width // 64 318 | self.visual = VisionTransformer( 319 | input_resolution=image_resolution, 320 | patch_size=vision_patch_size, 321 | width=vision_width, 322 | layers=vision_layers, 323 | heads=vision_heads, 324 | output_dim=embed_dim, 325 | ) 326 | # text 327 | self.transformer = Transformer( 328 | width=transformer_width, 329 | layers=transformer_layers, 330 | heads=transformer_heads, 331 | attn_mask=self.build_attention_mask() 332 | ) 333 | 334 | self.vocab_size = vocab_size 335 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 336 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 337 | self.ln_final = LayerNorm(transformer_width) 338 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 339 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 340 | 341 | self.initialize_parameters() 342 | 343 | def initialize_parameters(self): 344 | nn.init.normal_(self.token_embedding.weight, std=0.02) 345 | nn.init.normal_(self.positional_embedding, std=0.01) 346 | 347 | if isinstance(self.visual, ModifiedResNet): 348 | if self.visual.attnpool is not None: 349 | std = self.visual.attnpool.c_proj.in_features ** -0.5 350 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 351 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 352 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 353 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 354 | 355 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 356 | for name, param in resnet_block.named_parameters(): 357 | if name.endswith("bn3.weight"): 358 | nn.init.zeros_(param) 359 | 360 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 361 | attn_std = self.transformer.width ** -0.5 362 | fc_std = (2 * self.transformer.width) ** -0.5 363 | for block in self.transformer.resblocks: 364 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 365 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 366 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 367 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 368 | 369 | if self.text_projection is not None: 370 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 371 | 372 | def build_attention_mask(self): 373 | # lazily create causal attention mask, with full attention between the vision tokens 374 | # pytorch uses additive attention mask; fill with -inf 375 | mask = torch.empty(self.context_length, self.context_length) 376 | mask.fill_(float("-inf")) 377 | mask.triu_(1) # zero out the lower diagonal 378 | return 
mask 379 | 380 | @property 381 | def dtype(self): 382 | return self.visual.conv1.weight.dtype 383 | 384 | # get_image_embedding 385 | def encode_image(self, image): 386 | return self.visual(image.type(self.dtype)) 387 | 388 | def extract_features(self, image, extract): 389 | return self.visual(image.type(self.dtype), extract) # extract_features,x,cls 390 | 391 | # get_text_embedding 392 | def encode_text(self, text): 393 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 394 | x = x + self.positional_embedding.type(self.dtype) 395 | x = x.permute(1, 0, 2) # NLD -> LND 396 | x = self.transformer(x) 397 | x = x.permute(1, 0, 2) # LND -> NLD 398 | x = self.ln_final(x).type(self.dtype) # LayerNorm 399 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 400 | return x 401 | 402 | def forward(self, image, text): 403 | image_features = self.encode_image(image) 404 | text_features = self.encode_text(text) 405 | 406 | # normalized features 407 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 408 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 409 | 410 | # cosine similarity as logits 411 | logit_scale = self.logit_scale.exp() 412 | logits_per_image = logit_scale * image_features @ text_features.t() 413 | logits_per_text = logits_per_image.t() 414 | 415 | # shape = [global_batch_size, global_batch_size] 416 | return logits_per_image, logits_per_text 417 | 418 | 419 | def convert_weights(model: nn.Module): 420 | """Convert applicable model parameters to fp16""" 421 | 422 | def _convert_weights_to_fp16(l): 423 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 424 | l.weight.data = l.weight.data.half() 425 | if l.bias is not None: 426 | l.bias.data = l.bias.data.half() 427 | 428 | if isinstance(l, nn.MultiheadAttention): 429 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 430 | tensor = getattr(l, attr) 431 | if tensor is not None: 432 | tensor.data = tensor.data.half() 433 | 434 | for name in ["text_projection", "proj"]: 435 | if hasattr(l, name): 436 | attr = getattr(l, name) 437 | if attr is not None: 438 | attr.data = attr.data.half() 439 | 440 | model.apply(_convert_weights_to_fp16) 441 | 442 | 443 | def build_model(state_dict: dict): 444 | vit = "visual.proj" in state_dict 445 | 446 | if vit: 447 | vision_width = state_dict["visual.conv1.weight"].shape[0] 448 | vision_layers = len( 449 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 450 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 451 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 452 | image_resolution = vision_patch_size * grid_size 453 | else: 454 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in 455 | [1, 2, 3, 4]] 456 | vision_layers = tuple(counts) 457 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 458 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 459 | vision_patch_size = None 460 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 461 | image_resolution = output_width * 32 462 | 463 | embed_dim = state_dict["text_projection"].shape[1] 464 | context_length = state_dict["positional_embedding"].shape[0] 465 | vocab_size = state_dict["token_embedding.weight"].shape[0] 466 | transformer_width = 
state_dict["ln_final.weight"].shape[0] 467 | transformer_heads = transformer_width // 64 468 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 469 | 470 | model = CLIP( 471 | embed_dim, 472 | image_resolution, vision_layers, vision_width, vision_patch_size, 473 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 474 | ) 475 | 476 | for key in ["input_resolution", "context_length", "vocab_size"]: 477 | if key in state_dict: 478 | del state_dict[key] 479 | 480 | convert_weights(model) 481 | model.load_state_dict(state_dict) 482 | return model.eval() 483 | -------------------------------------------------------------------------------- /model/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /model/ds.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics 3 | from torch import nn 4 | from model.clip.clip import load 5 | import torch 6 | from model.adapters.adapter import Adapter 7 | from .attn import RecAttnClip 8 | from .layer import PostClipProcess, MaskPostXrayProcess 9 | import torch.nn.functional as F 10 | from trainer.metrics.base_metrics_class import 
calculate_metrics_for_train 11 | 12 | 13 | class DS(nn.Module): 14 | def __init__(self, clip_name, 15 | adapter_vit_name, 16 | num_quires, 17 | fusion_map, 18 | mlp_dim, 19 | mlp_out_dim, 20 | head_num, 21 | device, 22 | mode='video'): 23 | super().__init__() 24 | self.device = device 25 | self.clip_model, self.processor = load(clip_name, device=device,download_root='/data/cuixinjie/weights') 26 | self.adapter = Adapter(vit_name=adapter_vit_name, num_quires=num_quires, fusion_map=fusion_map, mlp_dim=mlp_dim, 27 | mlp_out_dim=mlp_out_dim, head_num=head_num, device=self.device) 28 | self.rec_attn_clip = RecAttnClip(self.clip_model.visual, num_quires,device=self.device) # 全部参数被冻结 29 | self.masked_xray_post_process = MaskPostXrayProcess(in_c=num_quires).to(self.device) 30 | self.clip_post_process = PostClipProcess(num_quires=num_quires, embed_dim=768) 31 | 32 | self.prob, self.label = [], [] 33 | self.correct, self.total = 0, 0 34 | self.mode = mode 35 | self._freeze() 36 | 37 | 38 | def _freeze(self): 39 | for name, param in self.named_parameters(): 40 | if 'clip_model' in name : 41 | param.requires_grad = False 42 | 43 | def get_losses(self, data_dict, pred_dict): 44 | label = data_dict['label'] #N 45 | xray = data_dict['xray'] 46 | pred = pred_dict['cls'] #N2 47 | xray_pred = pred_dict['xray_pred'] 48 | loss_intra = pred_dict['loss_intra'] 49 | loss_clip = pred_dict['loss_clip'] 50 | criterion = nn.CrossEntropyLoss() 51 | loss1 = criterion(pred.float(), label) 52 | if xray is not None: 53 | loss_mse = F.mse_loss(xray_pred.squeeze().float(), xray.squeeze().float()) # (N 1 224 224)->(N 224 224) 54 | 55 | loss = 10 * loss1 + 200 * loss_mse + 20 * loss_intra + 10 * loss_clip 56 | 57 | 58 | loss_dict = { 59 | 'cls': loss1, 60 | 'xray': loss_mse, 61 | 'intra': loss_intra, 62 | 'loss_clip':loss_clip, 63 | 'overall': loss 64 | } 65 | return loss_dict 66 | else: 67 | loss_dict = { 68 | 'cls': loss1, 69 | 'overall': loss1 70 | } 71 | return loss_dict 72 | 73 | def get_train_metrics(self, data_dict, pred_dict): 74 | label = data_dict['label'] 75 | pred = pred_dict['cls'] 76 | auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) 77 | metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} 78 | return metric_batch_dict 79 | 80 | def get_test_metrics(self): 81 | y_pred = np.concatenate(self.prob) 82 | y_true = np.concatenate(self.label) 83 | # auc 84 | fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) 85 | auc = metrics.auc(fpr, tpr) 86 | # eer 87 | fnr = 1 - tpr 88 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 89 | # ap 90 | ap = metrics.average_precision_score(y_true, y_pred) 91 | # acc 92 | acc = self.correct / self.total 93 | # reset the prob and label 94 | self.prob, self.label = [], [] 95 | self.correct, self.total = 0, 0 96 | return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true} 97 | 98 | def forward(self, data_dict, inference=False): 99 | images = data_dict['image'] 100 | clip_images = F.interpolate( 101 | images, 102 | size=(224, 224), 103 | mode='bilinear', 104 | align_corners=False, 105 | ) 106 | 107 | 108 | clip_features = self.clip_model.extract_features(clip_images, self.adapter.fusion_map.values()) 109 | 110 | attn_biases, xray_preds, loss_adapter_intra = self.adapter(data_dict, clip_features, 111 | inference) 112 | clip_output, loss_clip = self.rec_attn_clip(data_dict, clip_features, attn_biases[-1], inference, normalize=True) 113 | 114 | data_dict['if_boundary'] = 
data_dict['if_boundary'].to(self.device) 115 | xray_preds = [self.masked_xray_post_process(xray_pred, data_dict['if_boundary']) for xray_pred in xray_preds] 116 | 117 | clip_cls_output = self.clip_post_process(clip_output.float()).squeeze() # N2 118 | 119 | outputs = { 120 | 'xray_pred': xray_preds[-1], # N 1 224 224 121 | 'clip_cls_output': clip_cls_output, # N 2 122 | 123 | } 124 | 125 | prob = torch.softmax(outputs['clip_cls_output'], dim=1)[:, 1] 126 | pred_dict = { 127 | 'cls': outputs['clip_cls_output'], 128 | 'prob': prob, 129 | 'xray_pred': outputs['xray_pred'], 130 | 'loss_intra': loss_adapter_intra, 131 | 'loss_clip':loss_clip, 132 | } 133 | 134 | if inference: 135 | self.prob.append( 136 | pred_dict['prob'] 137 | .detach() 138 | .squeeze() 139 | .cpu() 140 | .numpy() 141 | ) 142 | self.label.append( 143 | data_dict['label'] 144 | .detach() 145 | .squeeze() 146 | .cpu() 147 | .numpy() 148 | ) 149 | # deal with acc 150 | _, prediction_class = torch.max(outputs['clip_cls_output'], 1) 151 | correct = (prediction_class == data_dict['label']).sum().item() 152 | self.correct += correct 153 | self.total += data_dict['label'].size(0) 154 | 155 | return pred_dict 156 | -------------------------------------------------------------------------------- /model/layer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | import warnings 5 | from timm.models.layers import to_2tuple 6 | 7 | 8 | class LayerNorm(nn.Module): 9 | """ 10 | A LayerNorm variant, popularized by Transformers, that performs point-wise mean and 11 | variance normalization over the channel dimension for inputs that have shape 12 | (batch_size, channels, height, width). 
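# Behaviour note (descriptive only): unlike nn.LayerNorm, this variant normalizes over the
# channel dimension of an NCHW tensor, i.e. for x of shape (N, C, H, W) each spatial
# position x[n, :, h, w] is standardized with its own C-channel mean and variance and then
# scaled by the per-channel weight and bias.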
13 | https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 14 | """ 15 | 16 | def __init__(self, normalized_shape, eps=1e-6): 17 | super().__init__() 18 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 19 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 20 | self.eps = eps 21 | self.normalized_shape = (normalized_shape,) 22 | 23 | def forward(self, x: torch.Tensor): 24 | u = x.mean(1, keepdim=True) 25 | s = (x - u).pow(2).mean(1, keepdim=True) 26 | x = (x - u) / torch.sqrt(s + self.eps) 27 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 28 | return x 29 | 30 | 31 | class MLP(nn.Module): 32 | """Very simple multi-layer perceptron (also called FFN)""" 33 | 34 | def __init__( 35 | self, input_dim, hidden_dim, output_dim, num_layers, affine_func=nn.Linear 36 | ): 37 | super().__init__() 38 | self.num_layers = num_layers 39 | h = [hidden_dim] * (num_layers - 1) 40 | self.layers = nn.ModuleList( 41 | affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 42 | ) 43 | 44 | def forward(self, x: torch.Tensor): 45 | for i, layer in enumerate(self.layers): 46 | # L R L R L 47 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 48 | return x 49 | 50 | 51 | class Fusion(nn.Module): 52 | def __init__(self, clip_dim, adapter_dim): 53 | super().__init__() 54 | self.clip_dim = clip_dim 55 | self.adapter_dim = adapter_dim 56 | self.proj = nn.Sequential( 57 | LayerNorm(clip_dim), 58 | nn.Conv2d(clip_dim, adapter_dim, kernel_size=1), 59 | ) 60 | 61 | def forward(self, x, clip_x, spatial_shape): 62 | h, w = spatial_shape 63 | n, l, d = clip_x.shape 64 | 65 | if l == h*w: 66 | clip_x = clip_x.permute(0, 2, 1).view(n, d, h, w) # NLD->NDL->NDhw 67 | else: 68 | clip_x = clip_x.permute(0, 2, 1).view(n, d, 14, 14) # NLD->NDL->NDhw 69 | clip_x = F.interpolate( 70 | clip_x.contiguous(), 71 | size= (16, 16), 72 | mode="bilinear", 73 | align_corners=False, 74 | ) #ND 14 14 => N D 16 16 75 | clip_x = self.proj(clip_x).view(n, self.adapter_dim, h * w).permute(0, 2, 1) 76 | x = x + clip_x # NLD 77 | 78 | return x 79 | 80 | 81 | 82 | class MaskPostXrayProcess(nn.Module): 83 | def __init__(self, in_c): 84 | super().__init__() 85 | 86 | self.process = nn.Sequential( 87 | nn.Conv2d(in_channels=in_c, out_channels=in_c // 2, kernel_size=3, stride=1, padding=1), # (N Q h,w)->(N 64 h,w)) 88 | nn.BatchNorm2d(in_c // 2), 89 | nn.ReLU(), 90 | nn.Conv2d(in_channels=in_c // 2, out_channels=in_c // 4, kernel_size=3, stride=1, padding=1), # (N 32 h,w) 91 | nn.BatchNorm2d(in_c // 4), 92 | nn.ReLU(), 93 | nn.Conv2d(in_channels=in_c // 4, out_channels=1, kernel_size=1, stride=1, padding=0), # (N 16 h,w) 94 | 95 | nn.ConvTranspose2d(in_channels=1, out_channels=1 , kernel_size=16, stride=16) # (N 16 h,w)->(N 1 256 256) 96 | #nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True) # (N 1 256 256) 97 | ) 98 | 99 | 100 | def forward(self, x, if_boundaries): 101 | x = x.reshape(x.shape[0], x.shape[1], -1) #(N Q 256) 102 | x = x.permute(0, 2, 1)#(N L Q) 103 | if_boundaries = if_boundaries.unsqueeze(-1) # (NL1) 不是boundry的patch块置为0 104 | 105 | x = x * if_boundaries #(N L Q) * (N L 1) 106 | x = x.permute(0, 2, 1) #(N Q L) 107 | x = x.reshape(x.shape[0], x.shape[1], 16, 16) 108 | 109 | post_x = self.process(x) #(N 1 224 224) 110 | return post_x 111 | 112 | 113 | 114 | class PostClipProcess(nn.Module): 115 | """ 116 | NQD -> ND -> N2 117 | 118 | """ 119 | 120 | def __init__(self, num_quires, embed_dim): 121 | 
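# Shape-flow note (descriptive only): first_process treats the Q query tokens as Conv1d
# channels and collapses (N, Q, D) -> (N, 1, D); after squeeze(), the second_process MLP
# maps the D-dimensional vector to 2 logits, matching the NQD -> ND -> N2 comment above.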
super().__init__() 122 | 123 | self.first_process = nn.Sequential( 124 | nn.Conv1d(in_channels=num_quires, out_channels=num_quires // 2, kernel_size=3, stride=1, padding=1), # NQD->N1D 125 | nn.BatchNorm1d(num_quires // 2), 126 | nn.ReLU(), 127 | nn.Conv1d(in_channels=num_quires // 2, out_channels=num_quires // 4 , kernel_size=3, stride=1, padding=1), 128 | nn.BatchNorm1d(num_quires // 4), 129 | nn.ReLU(), 130 | nn.Conv1d(in_channels=num_quires // 4, out_channels=1, kernel_size=3, stride=1, padding=1), 131 | 132 | ) 133 | #self.norm = VT_LN(embed_dim) 134 | self.second_process = nn.Sequential( # ND->N2 135 | nn.Linear(in_features=embed_dim, out_features=embed_dim // 2), 136 | nn.ReLU(), 137 | nn.Linear(in_features=embed_dim // 2, out_features=embed_dim // 4), 138 | nn.ReLU(), 139 | #nn.Linear(in_features=embed_dim // 4, out_features=embed_dim // 8), 140 | #nn.ReLU(), 141 | nn.Linear(in_features=embed_dim // 4, out_features=2) 142 | ) 143 | 144 | 145 | def forward(self, x): 146 | x = self.first_process(x) #NQD->N1D 147 | x = x.squeeze() 148 | x = self.second_process(x) 149 | return x 150 | 151 | 152 | 153 | class VT_LN(nn.LayerNorm): 154 | def forward(self, x: torch.Tensor): 155 | orig_type = x.dtype 156 | ret = super().forward(x.type(torch.float32)) 157 | return ret.type(orig_type) 158 | 159 | 160 | class PatchEmbed(nn.Module): 161 | 162 | def __init__( 163 | self, 164 | img_size=256, 165 | patch_size=16, 166 | in_chans=3, 167 | embed_dim=192, 168 | norm_layer=None, 169 | bias=False, 170 | **kwargs 171 | ): 172 | super().__init__() 173 | img_size = to_2tuple(img_size) 174 | patch_size = to_2tuple(patch_size) 175 | self.img_size = img_size 176 | self.patch_size = patch_size 177 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 178 | self.num_patches = self.grid_size[0] * self.grid_size[1] 179 | 180 | self.proj = nn.Conv2d( 181 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias 182 | ) 183 | 184 | self.norm = VT_LN(embed_dim) 185 | 186 | def forward(self, x): 187 | x = self.proj(x) 188 | x = x.reshape(x.shape[0], x.shape[1], -1) # NDL 189 | x = x.permute(0, 2, 1) # NDL->NLD 190 | # x = self.norm(x) 191 | return x 192 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | eval pretained model. 
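# Example invocation for this script (illustrative; the YAML path, dataset names, and
# checkpoint path below are placeholders, not files shipped with the repository):
#
#     python test.py \
#         --detector_path ./config/test.yaml \
#         --test_dataset CelebDFv2 DFDC \
#         --weights_path ./weights/ckpt_best.pth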
3 | """ 4 | import numpy as np 5 | import random 6 | import yaml 7 | from tqdm import tqdm 8 | from trainer.metrics.utils import get_test_metrics 9 | import torch 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.utils.data 13 | 14 | from dataset.abstract_dataset import DeepfakeAbstractBaseDataset 15 | from model.ds import DS 16 | 17 | import argparse 18 | 19 | parser = argparse.ArgumentParser(description='Process some paths.') 20 | parser.add_argument('--detector_path', type=str, 21 | default='/data/cuixinjie/FA/config/test.yaml', 22 | help='path to detector YAML file') 23 | parser.add_argument("--test_dataset", nargs="+") 24 | parser.add_argument('--weights_path', type=str, 25 | # ds_ _2024-12-20-21-48-57ds_ _2024-12-30-18-07-52 26 | #ds_ _2024-12-22-15-55-57 FFIW 83 71 WDF auc: 0.8351 video_auc: 0.8747 DF10 video_auc: 0.98225 auc: 0.961988821 27 | #ds_ _2024-12-26-16-47-41 WDF 85 86 DF10 0.95505 video_auc: 0.97841 28 | default='/data/cuixinjie/DsClip_L_V2/logs/ds_ _2024-09-25-14-26-04/test/avg/ckpt_best.pth') # 29 | #parser.add_argument("--lmdb", action='store_true', default=False) 30 | args = parser.parse_args() 31 | 32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 33 | 34 | def init_seed(config): 35 | if config['manualSeed'] is None: 36 | config['manualSeed'] = random.randint(1, 10000) 37 | random.seed(config['manualSeed']) 38 | torch.manual_seed(config['manualSeed']) 39 | if config['cuda']: 40 | torch.cuda.manual_seed_all(config['manualSeed']) 41 | 42 | 43 | def prepare_testing_data(config): 44 | def get_test_data_loader(config, test_name): 45 | # update the config dictionary with the specific testing dataset 46 | config = config.copy() # create a copy of config to avoid altering the original one 47 | config['test_dataset'] = test_name # specify the current test dataset 48 | test_set = DeepfakeAbstractBaseDataset( 49 | config=config, 50 | mode='test', 51 | ) 52 | test_data_loader = \ 53 | torch.utils.data.DataLoader( 54 | dataset=test_set, 55 | batch_size=config['test_batchSize'], 56 | shuffle=False, 57 | num_workers=int(config['workers']), 58 | collate_fn=test_set.collate_fn, 59 | drop_last=False 60 | ) 61 | return test_data_loader 62 | 63 | test_data_loaders = {} 64 | for one_test_name in config['test_dataset']: 65 | test_data_loaders[one_test_name] = get_test_data_loader(config, one_test_name) 66 | return test_data_loaders 67 | 68 | 69 | def choose_metric(config): 70 | metric_scoring = config['metric_scoring'] 71 | if metric_scoring not in ['eer', 'auc', 'acc', 'ap']: 72 | raise NotImplementedError('metric {} is not implemented'.format(metric_scoring)) 73 | return metric_scoring 74 | 75 | 76 | def test_one_dataset(model, data_loader): 77 | prediction_lists = [] 78 | #feature_lists = [] 79 | label_lists = [] 80 | for i, data_dict in tqdm(enumerate(data_loader), total=len(data_loader)): 81 | # get data 82 | data, label, mask, landmark = \ 83 | data_dict['image'], data_dict['label'], data_dict['mask'], data_dict['landmark'] 84 | label = torch.where(data_dict['label'] != 0, 1, 0) 85 | # move data to GPU 86 | data_dict['image'], data_dict['label'] = data.to(device), label.to(device) 87 | if mask is not None: 88 | data_dict['mask'] = mask.to(device) 89 | if landmark is not None: 90 | data_dict['landmark'] = landmark.to(device) 91 | 92 | # model forward without considering gradient computation 93 | predictions = inference(model, data_dict) 94 | label_lists += list(data_dict['label'].cpu().detach().numpy()) 95 | 
prediction_lists += list(predictions['prob'].cpu().detach().numpy()) 96 | #feature_lists += list(predictions['feat'].cpu().detach().numpy()) 97 | 98 | return np.array(prediction_lists), np.array(label_lists)#,np.array(feature_lists) 99 | 100 | def test_epoch(model, test_data_loaders): 101 | # set model to eval mode 102 | model.eval() 103 | 104 | # define test recorder 105 | metrics_all_datasets = {} 106 | 107 | # testing for all test data 108 | keys = test_data_loaders.keys() 109 | for key in keys: 110 | data_dict = test_data_loaders[key].dataset.data_dict 111 | # compute loss for each dataset 112 | predictions_nps, label_nps = test_one_dataset(model, test_data_loaders[key]) 113 | print(f'name {data_dict.keys()}') 114 | # compute metric for each dataset 115 | metric_one_dataset = get_test_metrics(y_pred=predictions_nps, y_true=label_nps, 116 | img_names=data_dict['image']) 117 | metrics_all_datasets[key] = metric_one_dataset 118 | 119 | # info for each dataset 120 | tqdm.write(f"dataset: {key}") 121 | for k, v in metric_one_dataset.items(): 122 | tqdm.write(f"{k}: {v}") 123 | 124 | return metrics_all_datasets 125 | 126 | @torch.no_grad() 127 | def inference(model, data_dict): 128 | predictions = model(data_dict, inference=True) 129 | return predictions 130 | 131 | 132 | def main(): 133 | # parse options and load config 134 | with open(args.detector_path, 'r') as f: 135 | config = yaml.safe_load(f) 136 | 137 | weights_path = None 138 | # If arguments are provided, they will overwrite the yaml settings 139 | if args.test_dataset: 140 | config['test_dataset'] = args.test_dataset 141 | if args.weights_path: 142 | config['weights_path'] = args.weights_path 143 | weights_path = args.weights_path 144 | 145 | # init seed 146 | init_seed(config) 147 | 148 | # set cudnn benchmark if needed 149 | if config['cudnn']: 150 | cudnn.benchmark = True 151 | 152 | # prepare the testing data loader 153 | test_data_loaders = prepare_testing_data(config) 154 | 155 | # prepare the model (detector) 156 | 157 | model = DS(clip_name=config['clip_model_name'], 158 | adapter_vit_name=config['vit_name'], 159 | num_quires=config['num_quires'], 160 | fusion_map=config['fusion_map'], 161 | mlp_dim=config['mlp_dim'], 162 | mlp_out_dim=config['mlp_out_dim'], 163 | head_num=config['head_num'], 164 | device=config['device']) 165 | epoch = 0 166 | #weights_paths = [ 167 | # '/data/cuixinjie/DsClip_L_V2/logs/ds_ _2024-09-25-14-26-04/test/avg/ckpt_best.pth', 168 | # ] 169 | 170 | try: 171 | epoch = int(weights_path.split('/')[-1].split('.')[0].split('_')[2]) 172 | except: 173 | epoch = 0 174 | ckpt = torch.load(weights_path, map_location=device) 175 | model.load_state_dict(ckpt, strict=False) 176 | model.to(device) 177 | 178 | print(f'===> Load {weights_path} done!') 179 | 180 | # start testing 181 | best_metric = test_epoch(model, test_data_loaders) 182 | print('===> Test Done!') 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import datetime 5 | import yaml 6 | from datetime import timedelta 7 | import torch 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.utils.data 11 | import torch.optim as optim 12 | from torch.utils.data.distributed import DistributedSampler 13 | import torch.distributed as dist 14 | 15 | from dataset.abstract_dataset import 
DeepfakeAbstractBaseDataset 16 | from model.ds import DS 17 | from trainer.trainer import Trainer 18 | from trainer.metrics.utils import parse_metric_for_print 19 | from logger import create_logger, RankFilter 20 | 21 | parser = argparse.ArgumentParser(description='Process some paths.') 22 | parser.add_argument('--config_path', type=str, 23 | default='/data/cuixinjie/FA/config/train.yaml', 24 | help='path to detector YAML file') 25 | parser.add_argument("--train_dataset", nargs="+") 26 | parser.add_argument("--test_dataset", nargs="+") 27 | parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True) 28 | parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=False) 29 | parser.add_argument("--ddp", action='store_true', default=False) 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | parser.add_argument('--task_target', type=str, default="", help='specify the target of current training task') 32 | args = parser.parse_args() 33 | torch.cuda.set_device(args.local_rank) 34 | 35 | 36 | 37 | def init_seed(config): 38 | if config['manualSeed'] is None: 39 | config['manualSeed'] = random.randint(1, 10000) 40 | random.seed(config['manualSeed']) 41 | if config['cuda']: 42 | torch.manual_seed(config['manualSeed']) 43 | torch.cuda.manual_seed_all(config['manualSeed']) 44 | 45 | 46 | def prepare_training_data(config): 47 | # Only use the blending dataset class in training 48 | if True: 49 | train_set = DeepfakeAbstractBaseDataset(config, mode='train') 50 | print(' Train set : FF++ 23') 51 | 52 | if config['ddp']: 53 | sample = DistributedSampler(train_set) 54 | train_data_loader = \ 55 | torch.utils.data.DataLoader( 56 | dataset=train_set, 57 | batch_size=config['train_batchSize'], 58 | num_workers=int(config['workers']), 59 | collate_fn=train_set.collate_fn, 60 | sampler=sample 61 | ) 62 | else: 63 | train_data_loader = \ 64 | torch.utils.data.DataLoader( 65 | dataset=train_set, 66 | batch_size=config['train_batchSize'], 67 | shuffle=True, 68 | num_workers=int(config['workers']), 69 | collate_fn=train_set.collate_fn, 70 | ) 71 | 72 | return train_data_loader 73 | 74 | 75 | def prepare_testing_data(config): 76 | def get_test_data_loader(config, test_name): 77 | # update the config dictionary with the specific testing dataset 78 | config = config.copy() # create a copy of config to avoid altering the original one 79 | config['test_dataset'] = test_name # specify the current test dataset 80 | test_set = DeepfakeAbstractBaseDataset( 81 | config=config, 82 | mode='test', 83 | ) 84 | test_data_loader = \ 85 | torch.utils.data.DataLoader( 86 | dataset=test_set, 87 | batch_size=config['test_batchSize'], 88 | shuffle=False, 89 | num_workers=int(config['workers']), 90 | collate_fn=test_set.collate_fn, 91 | ) 92 | return test_data_loader 93 | 94 | test_data_loaders = {} 95 | for one_test_name in config['test_dataset']: 96 | test_data_loaders[one_test_name] = get_test_data_loader(config, one_test_name) 97 | return test_data_loaders 98 | 99 | 100 | def choose_optimizer(model, config): 101 | opt_name = config['optimizer']['type'] 102 | 103 | if opt_name == 'sgd': 104 | optimizer = optim.SGD( 105 | params=model.parameters(), 106 | lr=config['optimizer'][opt_name]['lr'], 107 | momentum=config['optimizer'][opt_name]['momentum'], 108 | weight_decay=config['optimizer'][opt_name]['weight_decay'] 109 | ) 110 | return optimizer 111 | elif opt_name == 'adam': 112 | optimizer = optim.Adam( 113 | params=model.parameters(), 114 | 
lr=config['optimizer'][opt_name]['lr'], 115 | weight_decay=config['optimizer'][opt_name]['weight_decay'], 116 | betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']), 117 | eps=config['optimizer'][opt_name]['eps'], 118 | amsgrad=config['optimizer'][opt_name]['amsgrad'], 119 | ) 120 | return optimizer 121 | else: 122 | raise NotImplementedError('Optimizer {} is not implemented'.format(config['optimizer'])) 123 | 124 | 125 | 126 | def choose_scheduler(config, optimizer): 127 | if config['lr_scheduler'] is None: 128 | return None 129 | elif config['lr_scheduler'] == 'step': 130 | scheduler = optim.lr_scheduler.StepLR( 131 | optimizer, 132 | step_size=config['lr_step'], 133 | gamma=config['lr_gamma'], 134 | ) 135 | return scheduler 136 | elif config['lr_scheduler'] == 'cosine': 137 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 138 | optimizer, 139 | T_max=config['lr_T_max'], 140 | eta_min=config['lr_eta_min'], 141 | ) 142 | return scheduler 143 | else: 144 | raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler'])) 145 | 146 | 147 | def choose_metric(config): 148 | metric_scoring = config['metric_scoring'] 149 | if metric_scoring not in ['eer', 'auc', 'acc', 'ap']: 150 | raise NotImplementedError('metric {} is not implemented'.format(metric_scoring)) 151 | return metric_scoring 152 | 153 | 154 | def main(): 155 | # parse options and load config 156 | 157 | with open(args.config_path, 'r') as f: 158 | config = yaml.safe_load(f) 159 | 160 | config['local_rank'] = args.local_rank 161 | if config['dry_run']: 162 | config['nEpochs'] = 0 163 | config['save_feat']=False 164 | # If arguments are provided, they will overwrite the yaml settings 165 | if args.train_dataset: 166 | config['train_dataset'] = args.train_dataset 167 | if args.test_dataset: 168 | config['test_dataset'] = args.test_dataset 169 | config['save_ckpt'] = args.save_ckpt 170 | config['save_feat'] = args.save_feat 171 | 172 | # create logger 173 | timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 174 | task_str = f"_{config['task_target']}" if config['task_target'] is not None else "" 175 | logger_path = os.path.join( 176 | config['log_dir'], 177 | config['model_name'] + task_str + '_' + timenow 178 | ) 179 | os.makedirs(logger_path, exist_ok=True) 180 | logger = create_logger(os.path.join(logger_path, 'training.log')) 181 | logger.info('Save log to {}'.format(logger_path)) 182 | config['ddp']= args.ddp 183 | # print configuration 184 | logger.info("--------------- Configuration ---------------") 185 | params_string = "Parameters: \n" 186 | for key, value in config.items(): 187 | params_string += "{}: {}".format(key, value) + "\n" 188 | logger.info(params_string) 189 | 190 | # init seed 191 | init_seed(config) 192 | 193 | # set cudnn benchmark if needed 194 | if config['cudnn']: 195 | cudnn.benchmark = True 196 | if config['ddp']: 197 | # dist.init_process_group(backend='gloo') 198 | dist.init_process_group( 199 | backend='nccl', 200 | timeout=timedelta(minutes=30) 201 | ) 202 | logger.addFilter(RankFilter(0)) 203 | # prepare the training data loader 204 | train_data_loader = prepare_training_data(config) 205 | 206 | # prepare the testing data loader 207 | test_data_loaders = prepare_testing_data(config) 208 | 209 | # prepare the model 210 | 211 | model = DS(clip_name=config['clip_model_name'], 212 | adapter_vit_name=config['vit_name'], 213 | num_quires=config['num_quires'], 214 | fusion_map=config['fusion_map'], 215 | mlp_dim=config['mlp_dim'], 216 | 
mlp_out_dim=config['mlp_out_dim'], 217 | head_num=config['head_num'], 218 | device=config['device']) 219 | 220 | 221 | # prepare the optimizer 222 | optimizer = choose_optimizer(model, config) 223 | 224 | # prepare the scheduler 225 | scheduler = choose_scheduler(config, optimizer) 226 | 227 | # prepare the metric 228 | metric_scoring = choose_metric(config) 229 | 230 | # prepare the trainer 231 | trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring) 232 | # start training 233 | #trainer.load_ckpt('/data/cuixinjie/DsClip_L_V2/logs/ds_ _2024-10-08-20-28-02/test/avg/ckpt_best.pth') 234 | #trainer.load_ckpt('/data/cuixinjie/DsClip_L_V2/logs/ds_ _2024-09-25-14-26-04/test/avg/ckpt_best.pth') 235 | 236 | for epoch in range(config['start_epoch'], config['nEpochs'] + 1): 237 | trainer.model.epoch = epoch 238 | best_metric = trainer.train_epoch( 239 | epoch=epoch, 240 | train_data_loader=train_data_loader, 241 | test_data_loaders=test_data_loaders, 242 | ) 243 | if best_metric is not None: 244 | logger.info(f"===> Epoch[{epoch}] end with testing {metric_scoring}: {parse_metric_for_print(best_metric)}!") 245 | logger.info("Stop Training on best Testing metric {}".format(parse_metric_for_print(best_metric))) 246 | # update 247 | if 'svdd' in config['model_name']: 248 | model.update_R(epoch) 249 | if scheduler is not None: 250 | scheduler.step() 251 | 252 | # close the tensorboard writers 253 | for writer in trainer.writers.values(): 254 | writer.close() 255 | 256 | 257 | 258 | if __name__ == '__main__': 259 | main() 260 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_file_path = os.path.abspath(__file__) 4 | parent_dir = os.path.dirname(os.path.dirname(current_file_path)) 5 | project_root_dir = os.path.dirname(parent_dir) 6 | sys.path.append(parent_dir) 7 | sys.path.append(project_root_dir) 8 | 9 | -------------------------------------------------------------------------------- /trainer/base_trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from copy import deepcopy 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class BaseTrainer(ABC): 7 | """ 8 | """ 9 | 10 | def __init__( 11 | self, 12 | config, 13 | model, 14 | optimizer, 15 | scheduler, 16 | writer, 17 | ): 18 | # check if all the necessary components are implemented 19 | if config is None or model is None or optimizer is None or scheduler is None or writer is None: 20 | raise NotImplementedError("config, model, optimizier, scheduler, and tensorboard writer must be implemented") 21 | 22 | self.config = config 23 | self.model = model 24 | self.optimizer = optimizer 25 | self.scheduler = scheduler 26 | self.writer = writer 27 | 28 | @abstractmethod 29 | def speed_up(self): 30 | pass 31 | 32 | @abstractmethod 33 | def setTrain(self): 34 | pass 35 | 36 | @abstractmethod 37 | def setEval(self): 38 | pass 39 | 40 | @abstractmethod 41 | def load_ckpt(self, model_path): 42 | pass 43 | 44 | @abstractmethod 45 | def save_ckpt(self, dataset, epoch, iters, best=False): 46 | pass 47 | 48 | @abstractmethod 49 | def inference(self, data_dict): 50 | pass 51 | -------------------------------------------------------------------------------- /trainer/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 
current_file_path = os.path.abspath(__file__) 4 | parent_dir = os.path.dirname(os.path.dirname(current_file_path)) 5 | project_root_dir = os.path.dirname(parent_dir) 6 | sys.path.append(parent_dir) 7 | sys.path.append(project_root_dir) 8 | -------------------------------------------------------------------------------- /trainer/metrics/base_metrics_class.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics 3 | from collections import defaultdict 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def get_accracy(output, label): 9 | _, prediction = torch.max(output, 1) # argmax 10 | correct = (prediction == label).sum().item() 11 | accuracy = correct / prediction.size(0) 12 | return accuracy 13 | 14 | 15 | def get_prediction(output, label): 16 | prob = nn.functional.softmax(output, dim=1)[:, 1] 17 | prob = prob.view(prob.size(0), 1) 18 | label = label.view(label.size(0), 1) 19 | #print(prob.size(), label.size()) 20 | datas = torch.cat((prob, label.float()), dim=1) 21 | return datas 22 | 23 | 24 | def calculate_metrics_for_train(label, output): 25 | if output.size(1) == 2: 26 | prob = torch.softmax(output, dim=1)[:, 1] 27 | else: 28 | prob = output 29 | 30 | # Accuracy 31 | _, prediction = torch.max(output, 1) 32 | correct = (prediction == label).sum().item() 33 | accuracy = correct / prediction.size(0) 34 | 35 | # Average Precision 36 | y_true = label.cpu().detach().numpy() 37 | y_pred = prob.cpu().detach().numpy() 38 | #print(y_pred) 39 | ap = metrics.average_precision_score(y_true, y_pred) 40 | 41 | # AUC and EER 42 | try: 43 | fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(), 44 | prob.squeeze().cpu().numpy(), 45 | pos_label=1) 46 | except: 47 | # for the case when we only have one sample 48 | return None, None, accuracy, ap 49 | 50 | if np.isnan(fpr[0]) or np.isnan(tpr[0]): 51 | # for the case when all the samples within a batch is fake/real 52 | auc, eer = None, None 53 | else: 54 | auc = metrics.auc(fpr, tpr) 55 | fnr = 1 - tpr 56 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 57 | 58 | return auc, eer, accuracy, ap 59 | 60 | 61 | # ------------ compute average metrics of batches--------------------- 62 | class Metrics_batch(): 63 | def __init__(self): 64 | self.tprs = [] 65 | self.mean_fpr = np.linspace(0, 1, 100) 66 | self.aucs = [] 67 | self.eers = [] 68 | self.aps = [] 69 | 70 | self.correct = 0 71 | self.total = 0 72 | self.losses = [] 73 | 74 | def update(self, label, output): 75 | acc = self._update_acc(label, output) 76 | if output.size(1) == 2: 77 | prob = torch.softmax(output, dim=1)[:, 1] 78 | else: 79 | prob = output 80 | #label = 1-label 81 | #prob = torch.softmax(output, dim=1)[:, 1] 82 | auc, eer = self._update_auc(label, prob) 83 | ap = self._update_ap(label, prob) 84 | 85 | return acc, auc, eer, ap 86 | 87 | def _update_auc(self, lab, prob): 88 | fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(), 89 | prob.squeeze().cpu().numpy(), 90 | pos_label=1) 91 | if np.isnan(fpr[0]) or np.isnan(tpr[0]): 92 | return -1, -1 93 | 94 | auc = metrics.auc(fpr, tpr) 95 | interp_tpr = np.interp(self.mean_fpr, fpr, tpr) 96 | interp_tpr[0] = 0.0 97 | self.tprs.append(interp_tpr) 98 | self.aucs.append(auc) 99 | 100 | # return auc 101 | 102 | # EER 103 | fnr = 1 - tpr 104 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 105 | self.eers.append(eer) 106 | 107 | return auc, eer 108 | 109 | def _update_acc(self, lab, output): 110 | _, prediction = 
torch.max(output, 1) # argmax 111 | correct = (prediction == lab).sum().item() 112 | accuracy = correct / prediction.size(0) 113 | # self.accs.append(accuracy) 114 | self.correct = self.correct+correct 115 | self.total = self.total+lab.size(0) 116 | return accuracy 117 | 118 | def _update_ap(self, label, prob): 119 | y_true = label.cpu().detach().numpy() 120 | y_pred = prob.cpu().detach().numpy() 121 | ap = metrics.average_precision_score(y_true,y_pred) 122 | self.aps.append(ap) 123 | 124 | return np.mean(ap) 125 | 126 | def get_mean_metrics(self): 127 | mean_acc, std_acc = self.correct/self.total, 0 128 | mean_auc, std_auc = self._mean_auc() 129 | mean_err, std_err = np.mean(self.eers), np.std(self.eers) 130 | mean_ap, std_ap = np.mean(self.aps), np.std(self.aps) 131 | 132 | return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap} 133 | 134 | def _mean_auc(self): 135 | mean_tpr = np.mean(self.tprs, axis=0) 136 | mean_tpr[-1] = 1.0 137 | mean_auc = metrics.auc(self.mean_fpr, mean_tpr) 138 | std_auc = np.std(self.aucs) 139 | return mean_auc, std_auc 140 | 141 | def clear(self): 142 | self.tprs.clear() 143 | self.aucs.clear() 144 | # self.accs.clear() 145 | self.correct=0 146 | self.total=0 147 | self.eers.clear() 148 | self.aps.clear() 149 | self.losses.clear() 150 | 151 | 152 | # ------------ compute average metrics of all data --------------------- 153 | class Metrics_all(): 154 | def __init__(self): 155 | self.probs = [] 156 | self.labels = [] 157 | self.correct = 0 158 | self.total = 0 159 | 160 | def store(self, label, output): 161 | prob = torch.softmax(output, dim=1)[:, 1] 162 | _, prediction = torch.max(output, 1) # argmax 163 | correct = (prediction == label).sum().item() 164 | self.correct += correct 165 | self.total += label.size(0) 166 | self.labels.append(label.squeeze().cpu().numpy()) 167 | self.probs.append(prob.squeeze().cpu().numpy()) 168 | 169 | def get_metrics(self): 170 | y_pred = np.concatenate(self.probs) 171 | y_true = np.concatenate(self.labels) 172 | # auc 173 | fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1) 174 | auc = metrics.auc(fpr, tpr) 175 | # eer 176 | fnr = 1 - tpr 177 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 178 | # ap 179 | ap = metrics.average_precision_score(y_true,y_pred) 180 | # acc 181 | acc = self.correct / self.total 182 | return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap} 183 | 184 | def clear(self): 185 | self.probs.clear() 186 | self.labels.clear() 187 | self.correct = 0 188 | self.total = 0 189 | 190 | 191 | # only used to record a series of scalar value 192 | class Recorder: 193 | def __init__(self): 194 | self.sum = 0 195 | self.num = 0 196 | def update(self, item, num=1): 197 | if item is not None: 198 | self.sum += item * num 199 | self.num += num 200 | def average(self): 201 | if self.num == 0: 202 | return None 203 | return self.sum/self.num 204 | def clear(self): 205 | self.sum = 0 206 | self.num = 0 207 | -------------------------------------------------------------------------------- /trainer/metrics/registry.py: -------------------------------------------------------------------------------- 1 | class Registry(object): 2 | def __init__(self): 3 | self.data = {} 4 | 5 | def register_module(self, module_name=None): 6 | def _register(cls): 7 | name = module_name 8 | if module_name is None: 9 | name = cls.__name__ 10 | self.data[name] = cls 11 | return cls 12 | return _register 13 | 14 | def __getitem__(self, key): 15 | return self.data[key] 16 | 17 | BACKBONE = Registry() 18 | DETECTOR = 
Registry() 19 | TRAINER = Registry() 20 | LOSSFUNC = Registry() 21 | -------------------------------------------------------------------------------- /trainer/metrics/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | import numpy as np 3 | 4 | 5 | def parse_metric_for_print(metric_dict): 6 | if metric_dict is None: 7 | return "\n" 8 | str = "\n" 9 | str += "================================ Each dataset best metric ================================ \n" 10 | for key, value in metric_dict.items(): 11 | if key != 'avg': 12 | str= str+ f"| {key}: " 13 | for k,v in value.items(): 14 | str = str + f" {k}={v} " 15 | str= str+ "| \n" 16 | else: 17 | str += "============================================================================================= \n" 18 | str += "================================== Average best metric ====================================== \n" 19 | avg_dict = value 20 | for avg_key, avg_value in avg_dict.items(): 21 | if avg_key == 'dataset_dict': 22 | for key,value in avg_value.items(): 23 | str = str + f"| {key}: {value} | \n" 24 | else: 25 | str = str + f"| avg {avg_key}: {avg_value} | \n" 26 | str += "=============================================================================================" 27 | return str 28 | 29 | 30 | def get_test_metrics(y_pred, y_true, img_names): 31 | def get_video_metrics(image, pred, label): 32 | result_dict = {} 33 | new_label = [] 34 | new_pred = [] 35 | # print(image[0]) 36 | # print(pred.shape) 37 | # print(label.shape) 38 | for item in np.transpose(np.stack((image, pred, label)), (1, 0)): 39 | 40 | s = item[0] 41 | if '\\' in s: 42 | parts = s.split('\\') 43 | else: 44 | parts = s.split('/') 45 | a = parts[-2] 46 | b = parts[-1] 47 | 48 | if a not in result_dict: 49 | result_dict[a] = [] 50 | 51 | result_dict[a].append(item) 52 | image_arr = list(result_dict.values()) 53 | 54 | for video in image_arr: 55 | pred_sum = 0 56 | label_sum = 0 57 | leng = 0 58 | for frame in video: 59 | pred_sum += float(frame[1]) 60 | label_sum += int(frame[2]) 61 | leng += 1 62 | new_pred.append(pred_sum / leng) 63 | new_label.append(int(label_sum / leng)) 64 | fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred) 65 | v_auc = metrics.auc(fpr, tpr) 66 | fnr = 1 - tpr 67 | v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 68 | return v_auc, v_eer 69 | 70 | 71 | y_pred = y_pred.squeeze() 72 | # auc 73 | fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) 74 | auc = metrics.auc(fpr, tpr) 75 | # eer 76 | fnr = 1 - tpr 77 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 78 | # ap 79 | ap = metrics.average_precision_score(y_true, y_pred) 80 | # acc 81 | prediction_class = (y_pred > 0.5).astype(int) 82 | correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item() 83 | acc = correct / len(prediction_class) 84 | if type(img_names[0]) is not list: 85 | # calculate video-level auc for the frame-level methods. 
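        # get_video_metrics (defined above) groups frames by the second-to-last component of each
        # image path (used as the video id), averages the frame probabilities and labels per video,
        # and recomputes AUC/EER on those per-video scores. Illustrative, hypothetical layout:
        # frames saved as <dataset>/<video_id>/<frame>.png all collapse into one score per <video_id>.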
86 | v_auc, _ = get_video_metrics(img_names, y_pred, y_true) 87 | else: 88 | # video-level methods 89 | v_auc=auc 90 | 91 | return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'video_auc': v_auc, 'label': y_true} 92 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_file_path = os.path.abspath(__file__) 4 | parent_dir = os.path.dirname(os.path.dirname(current_file_path)) 5 | project_root_dir = os.path.dirname(parent_dir) 6 | sys.path.append(parent_dir) 7 | sys.path.append(project_root_dir) 8 | 9 | import pickle 10 | import datetime 11 | import numpy as np 12 | from collections import defaultdict 13 | from tqdm import tqdm 14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | from metrics.base_metrics_class import Recorder 17 | from torch.optim.swa_utils import AveragedModel 18 | from torch import distributed as dist 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | from trainer.metrics.utils import get_test_metrics 21 | from torch.nn import DataParallel 22 | FFpp_pool = ['FaceForensics++', ] 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | 27 | class Trainer(object): 28 | def __init__( 29 | self, 30 | config, 31 | model, 32 | optimizer, 33 | scheduler, 34 | logger, 35 | metric_scoring='auc', 36 | time_now=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), 37 | swa_model=None 38 | ): 39 | # check if all the necessary components are implemented 40 | if config is None or model is None or optimizer is None or logger is None: 41 | raise ValueError("config, model, optimizer, and logger must be provided") 42 | 43 | self.config = config 44 | self.model = model 45 | self.optimizer = optimizer 46 | self.scheduler = scheduler 47 | self.swa_model = swa_model 48 | self.writers = {} # dict to maintain different tensorboard writers for each dataset and metric 49 | self.logger = logger 50 | self.metric_scoring = metric_scoring 51 | # maintain the best metric of all epochs 52 | self.best_metrics_all_time = defaultdict( 53 | lambda: defaultdict(lambda: float('-inf') 54 | if self.metric_scoring != 'eer' else float('inf')) 55 | ) 56 | self.speed_up() # move model to GPU 57 | 58 | # get current time 59 | self.timenow = time_now 60 | # create directory path 61 | if 'task_target' not in config: 62 | self.log_dir = os.path.join( 63 | self.config['log_dir'], 64 | self.config['model_name'] + '_' + self.timenow 65 | ) 66 | else: 67 | task_str = f"_{config['task_target']}" if config['task_target'] is not None else "" 68 | self.log_dir = os.path.join( 69 | self.config['log_dir'], 70 | self.config['model_name'] + task_str + '_' + self.timenow 71 | ) 72 | os.makedirs(self.log_dir, exist_ok=True) 73 | 74 | def get_writer(self, phase, dataset_key, metric_key): 75 | writer_key = f"{phase}-{dataset_key}-{metric_key}" 76 | if writer_key not in self.writers: 77 | # update directory path 78 | writer_path = os.path.join( 79 | self.log_dir, 80 | phase, 81 | dataset_key, 82 | metric_key, 83 | "metric_board" 84 | ) 85 | os.makedirs(writer_path, exist_ok=True) 86 | # update writers dictionary 87 | self.writers[writer_key] = SummaryWriter(writer_path) 88 | return self.writers[writer_key] 89 | 90 | def speed_up(self): 91 | self.model.to(device) 92 | self.model.device = device 93 | #self.model = 
DataParallel(self.model,device_ids=[0,1]) 94 | 95 | if self.config['ddp'] == True: 96 | num_gpus = torch.cuda.device_count() 97 | print(f'avai gpus: {num_gpus}') 98 | # local_rank=[i for i in range(0,num_gpus)] 99 | self.model = DDP(self.model, device_ids=[self.config['local_rank']], find_unused_parameters=True, 100 | output_device=self.config['local_rank']) 101 | # self.optimizer = nn.DataParallel(self.optimizer, device_ids=[int(os.environ['LOCAL_RANK'])]) 102 | 103 | 104 | def DP_speed_up(self): 105 | self.model.device = device 106 | self.model.to(device) 107 | self.model = DataParallel(self.model,device_ids=[0,1]) 108 | 109 | 110 | def setTrain(self): 111 | self.model.train() 112 | self.train = True 113 | 114 | def setEval(self): 115 | self.model.eval() 116 | self.train = False 117 | 118 | def load_ckpt(self, model_path): 119 | if os.path.isfile(model_path): 120 | saved = torch.load(model_path, map_location='cpu') 121 | suffix = model_path.split('.')[-1] 122 | if suffix == 'p': 123 | self.model.load_state_dict(saved.state_dict()) 124 | else: 125 | self.model.load_state_dict(saved) 126 | self.logger.info('Model found in {}'.format(model_path)) 127 | else: 128 | raise NotImplementedError( 129 | "=> no model found at '{}'".format(model_path)) 130 | 131 | def save_ckpt(self, phase, dataset_key, ckpt_info=None): 132 | save_dir = os.path.join(self.log_dir, phase, dataset_key) 133 | os.makedirs(save_dir, exist_ok=True) 134 | ckpt_name = f"ckpt_best.pth" 135 | save_path = os.path.join(save_dir, ckpt_name) 136 | if self.config['ddp'] == True: 137 | torch.save(self.model.state_dict(), save_path) 138 | else: 139 | if 'svdd' in self.config['model_name']: 140 | torch.save({'R': self.model.R, 141 | 'c': self.model.c, 142 | 'state_dict': self.model.state_dict(), }, save_path) 143 | else: 144 | torch.save(self.model.state_dict(), save_path) 145 | self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}") 146 | 147 | def save_swa_ckpt(self): 148 | save_dir = self.log_dir 149 | os.makedirs(save_dir, exist_ok=True) 150 | ckpt_name = f"swa.pth" 151 | save_path = os.path.join(save_dir, ckpt_name) 152 | torch.save(self.swa_model.state_dict(), save_path) 153 | self.logger.info(f"SWA Checkpoint saved to {save_path}") 154 | 155 | def save_feat(self, phase, fea, dataset_key): 156 | save_dir = os.path.join(self.log_dir, phase, dataset_key) 157 | os.makedirs(save_dir, exist_ok=True) 158 | features = fea 159 | feat_name = f"feat_best.npy" 160 | save_path = os.path.join(save_dir, feat_name) 161 | np.save(save_path, features) 162 | self.logger.info(f"Feature saved to {save_path}") 163 | 164 | def save_data_dict(self, phase, data_dict, dataset_key): 165 | save_dir = os.path.join(self.log_dir, phase, dataset_key) 166 | os.makedirs(save_dir, exist_ok=True) 167 | file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle') 168 | with open(file_path, 'wb') as file: 169 | pickle.dump(data_dict, file) 170 | self.logger.info(f"data_dict saved to {file_path}") 171 | 172 | def save_metrics(self, phase, metric_one_dataset, dataset_key): 173 | save_dir = os.path.join(self.log_dir, phase, dataset_key) 174 | os.makedirs(save_dir, exist_ok=True) 175 | file_path = os.path.join(save_dir, 'metric_dict_best.pickle') 176 | with open(file_path, 'wb') as file: 177 | pickle.dump(metric_one_dataset, file) 178 | self.logger.info(f"Metrics saved to {file_path}") 179 | 180 | def train_step(self, data_dict): 181 | if self.config['optimizer']['type'] == 'sam': 182 | for i in range(2): 183 | predictions = 
self.model(data_dict) 184 | losses = self.model.get_losses(data_dict, predictions) 185 | if i == 0: 186 | pred_first = predictions 187 | losses_first = losses 188 | self.optimizer.zero_grad() 189 | losses['overall'].backward() 190 | if i == 0: 191 | self.optimizer.first_step(zero_grad=True) 192 | else: 193 | self.optimizer.second_step(zero_grad=True) 194 | return losses_first, pred_first 195 | else: 196 | self.optimizer.zero_grad() 197 | predictions = self.model(data_dict) 198 | if type(self.model) is DDP: 199 | losses = self.model.module.get_losses(data_dict, predictions) 200 | else: 201 | losses = self.model.get_losses(data_dict, predictions) 202 | #self.optimizer.zero_grad() 203 | losses['overall'].backward() 204 | self.optimizer.step() 205 | 206 | return losses, predictions 207 | 208 | def train_epoch( 209 | self, 210 | epoch, 211 | train_data_loader, 212 | test_data_loaders=None, 213 | ): 214 | 215 | self.logger.info("===> Epoch[{}] start!".format(epoch)) 216 | 217 | if epoch >= 3: 218 | times_per_epoch = 1 219 | else: 220 | times_per_epoch = 4 221 | 222 | 223 | test_step = len(train_data_loader) // times_per_epoch 224 | step_cnt = epoch * len(train_data_loader) 225 | 226 | # define training recorder 227 | train_recorder_loss = defaultdict(Recorder) 228 | train_recorder_metric = defaultdict(Recorder) 229 | 230 | for iteration, data_dict in tqdm(enumerate(train_data_loader), total=len(train_data_loader)): 231 | 232 | self.setTrain() 233 | for key in data_dict.keys(): 234 | if data_dict[key] != None and key != 'name': 235 | data_dict[key] = data_dict[key].to(device) 236 | 237 | losses, predictions = self.train_step(data_dict) 238 | if 'SWA' in self.config and self.config['SWA'] and epoch > self.config['swa_start']: 239 | self.swa_model.update_parameters(self.model) 240 | 241 | if type(self.model) is DDP: 242 | batch_metrics = self.model.module.get_train_metrics(data_dict, predictions) 243 | else: 244 | batch_metrics = self.model.get_train_metrics(data_dict, predictions) 245 | 246 | for name, value in batch_metrics.items(): 247 | train_recorder_metric[name].update(value) 248 | ## store loss 249 | for name, value in losses.items(): 250 | train_recorder_loss[name].update(value) 251 | 252 | # run tensorboard to visualize the training process 253 | if iteration % 300 == 0: 254 | if 'SWA' in self.config and self.config['SWA'] and (epoch > self.config['swa_start'] or self.config['dry_run']): 255 | self.scheduler.step() 256 | # info for loss 257 | loss_str = f"Iter: {step_cnt} " 258 | for k, v in train_recorder_loss.items(): 259 | v_avg = v.average() 260 | if v_avg == None: 261 | loss_str += f"training-loss, {k}: not calculated" 262 | continue 263 | loss_str += f"training-loss, {k}: {v_avg} " 264 | # tensorboard-1. loss 265 | writer = self.get_writer('train', ','.join(self.config['train_dataset']), k) 266 | writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt) 267 | self.logger.info(loss_str) 268 | # info for metric 269 | metric_str = f"Iter: {step_cnt} " 270 | for k, v in train_recorder_metric.items(): 271 | v_avg = v.average() 272 | if v_avg == None: 273 | metric_str += f"training-metric, {k}: not calculated " 274 | continue 275 | metric_str += f"training-metric, {k}: {v_avg} " 276 | # tensorboard-2. metric 277 | writer = self.get_writer('train', ','.join(self.config['train_dataset']), k) 278 | writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt) 279 | self.logger.info(metric_str) 280 | 281 | # clear recorder. 
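                    # (Both the loss and the metric recorders are reset below, so each logged point
                    # reflects only the most recent logging window; the scheduler.step() above runs
                    # every 300 iterations only on SWA runs, while a plain run steps the scheduler
                    # once per epoch in main().)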
282 | # Note we only consider the current 300 samples for computing batch-level loss/metric 283 | for name, recorder in train_recorder_loss.items(): # clear loss recorder 284 | recorder.clear() 285 | for name, recorder in train_recorder_metric.items(): # clear metric recorder 286 | recorder.clear() 287 | 288 | # run test 289 | if (step_cnt + 1) % test_step == 0: 290 | if test_data_loaders is not None: 291 | self.logger.info("===> Test start!") 292 | test_best_metric = self.test_epoch( 293 | epoch, 294 | iteration, 295 | test_data_loaders, 296 | step_cnt, 297 | ) 298 | 299 | # total_end_time = time.time() 300 | # total_elapsed_time = total_end_time - total_start_time 301 | # print("Total elapsed time: {:.2f} seconds".format(total_elapsed_time)) 302 | step_cnt += 1 303 | return test_best_metric if test_data_loaders is not None else None # avoid NameError when no test loaders are given 304 | 305 | def get_respect_acc(self, prob, label): 306 | pred = np.where(prob > 0.5, 1, 0) 307 | judge = (pred == label) 308 | zero_num = len(label) - np.count_nonzero(label) 309 | acc_fake = np.count_nonzero(judge[zero_num:]) / len(judge[zero_num:]) 310 | acc_real = np.count_nonzero(judge[:zero_num]) / len(judge[:zero_num]) 311 | return acc_real, acc_fake 312 | 313 | def test_one_dataset(self, data_loader): 314 | # define test recorder 315 | test_recorder_loss = defaultdict(Recorder) 316 | prediction_lists = [] 317 | feature_lists = [] 318 | label_lists = [] 319 | for i, data_dict in tqdm(enumerate(data_loader), total=len(data_loader)): 320 | # get data 321 | if 'label_spe' in data_dict: 322 | data_dict.pop('label_spe') # remove the specific label 323 | data_dict['label'] = torch.where(data_dict['label'] != 0, 1, 0) # fix the label to 0 and 1 only 324 | # move data to GPU elegantly 325 | for key in data_dict.keys(): 326 | if data_dict[key] is not None: 327 | data_dict[key] = data_dict[key].cuda() 328 | # model forward without considering gradient computation 329 | predictions = self.inference(data_dict) 330 | label_lists += list(data_dict['label'].cpu().detach().numpy()) 331 | prediction_lists += list(predictions['prob'].cpu().detach().numpy()) 332 | #feature_lists += list(predictions['feat'].cpu().detach().numpy()) 333 | if type(self.model) is not AveragedModel: 334 | # compute all losses for each batch data 335 | if type(self.model) is DDP: 336 | losses = self.model.module.get_losses(data_dict, predictions) 337 | else: 338 | losses = self.model.get_losses(data_dict, predictions) 339 | 340 | # store data by recorder 341 | for name, value in losses.items(): 342 | test_recorder_loss[name].update(value) 343 | 344 | #return test_recorder_loss, np.array(prediction_lists), np.array(label_lists), np.array(feature_lists) 345 | return test_recorder_loss, np.array(prediction_lists), np.array(label_lists) 346 | 347 | def save_best(self, epoch, iteration, step, losses_one_dataset_recorder, key, metric_one_dataset): 348 | best_metric = self.best_metrics_all_time[key].get(self.metric_scoring, 349 | float('-inf') if self.metric_scoring != 'eer' else float( 350 | 'inf')) 351 | # Check if the current score is an improvement 352 | improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else ( 353 | metric_one_dataset[self.metric_scoring] < best_metric) 354 | if improved: 355 | # Update the best metric 356 | self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring] 357 | if key == 'avg': 358 | self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict'] 359 | # Save checkpoint, feature, and metrics if specified in config 360 | if 
self.config['save_ckpt'] and key not in FFpp_pool: 361 | self.save_ckpt('test', key, f"{epoch}+{iteration}") 362 | self.save_metrics('test', metric_one_dataset, key) 363 | if losses_one_dataset_recorder is not None: 364 | # info for each dataset 365 | loss_str = f"dataset: {key} step: {step} " 366 | for k, v in losses_one_dataset_recorder.items(): 367 | writer = self.get_writer('test', key, k) 368 | v_avg = v.average() 369 | if v_avg == None: 370 | print(f'{k} is not calculated') 371 | continue 372 | # tensorboard-1. loss 373 | writer.add_scalar(f'test_losses/{k}', v_avg, global_step=step) 374 | loss_str += f"testing-loss, {k}: {v_avg} " 375 | self.logger.info(loss_str) 376 | # tqdm.write(loss_str) 377 | metric_str = f"dataset: {key} step: {step} " 378 | for k, v in metric_one_dataset.items(): 379 | if k == 'pred' or k == 'label' or k == 'dataset_dict': 380 | continue 381 | metric_str += f"testing-metric, {k}: {v} " 382 | # tensorboard-2. metric 383 | writer = self.get_writer('test', key, k) 384 | writer.add_scalar(f'test_metrics/{k}', v, global_step=step) 385 | if 'pred' in metric_one_dataset: 386 | acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label']) 387 | metric_str += f'testing-metric, acc_real:{acc_real}; acc_fake:{acc_fake}' 388 | writer.add_scalar(f'test_metrics/acc_real', acc_real, global_step=step) 389 | writer.add_scalar(f'test_metrics/acc_fake', acc_fake, global_step=step) 390 | self.logger.info(metric_str) 391 | 392 | def test_epoch(self, epoch, iteration, test_data_loaders, step): 393 | # set model to eval mode 394 | self.setEval() 395 | 396 | # define test recorder 397 | losses_all_datasets = {} 398 | metrics_all_datasets = {} 399 | best_metrics_per_dataset = defaultdict(dict) # best metric for each dataset, for each metric 400 | avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0, 'video_auc': 0, 'dataset_dict': {}} 401 | # testing for all test data 402 | keys = test_data_loaders.keys() 403 | for key in keys: 404 | # save the testing data_dict 405 | data_dict = test_data_loaders[key].dataset.data_dict 406 | self.save_data_dict('test', data_dict, key) 407 | 408 | # compute loss for each dataset 409 | #losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.test_one_dataset( 410 | # test_data_loaders[key]) 411 | losses_one_dataset_recorder, predictions_nps, label_nps = self.test_one_dataset(test_data_loaders[key]) 412 | # print(f'stack len:{predictions_nps.shape};{label_nps.shape};{len(data_dict["image"])}') 413 | losses_all_datasets[key] = losses_one_dataset_recorder 414 | metric_one_dataset = get_test_metrics(y_pred=predictions_nps, y_true=label_nps, 415 | img_names=data_dict['image']) 416 | for metric_name, value in metric_one_dataset.items(): 417 | if metric_name in avg_metric: 418 | avg_metric[metric_name] += value 419 | avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring] 420 | if type(self.model) is AveragedModel: 421 | metric_str = f"Iter Final for SWA: " 422 | for k, v in metric_one_dataset.items(): 423 | metric_str += f"testing-metric, {k}: {v} " 424 | self.logger.info(metric_str) 425 | continue 426 | self.save_best(epoch, iteration, step, losses_one_dataset_recorder, key, metric_one_dataset) 427 | 428 | if len(keys) > 0 and self.config.get('save_avg', False): 429 | # calculate avg value 430 | for key in avg_metric: 431 | if key != 'dataset_dict': 432 | avg_metric[key] /= len(keys) 433 | self.save_best(epoch, iteration, step, None, 'avg', avg_metric) 434 | 435 | 
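        # best_metrics_all_time records, for each test dataset key (plus an 'avg' entry when
        # save_avg is enabled), the best value of the chosen scoring metric seen so far;
        # main() logs it through parse_metric_for_print after every epoch.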
self.logger.info('===> Test Done!') 436 | return self.best_metrics_all_time # return all types of mean metrics for determining the best ckpt 437 | 438 | @torch.no_grad() 439 | def inference(self, data_dict): 440 | predictions = self.model(data_dict, inference=True) 441 | return predictions 442 | --------------------------------------------------------------------------------