├── LICENSE ├── README.md ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── datasets.cpython-38.pyc │ └── datasets.cpython-39.pyc └── datasets.py ├── dataset_paths.py ├── models ├── __init__.py ├── clip │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── clip.cpython-310.pyc │ │ ├── clip.cpython-38.pyc │ │ ├── clip.cpython-39.pyc │ │ ├── model.cpython-310.pyc │ │ ├── model.cpython-38.pyc │ │ ├── model.cpython-39.pyc │ │ ├── simple_tokenizer.cpython-310.pyc │ │ ├── simple_tokenizer.cpython-38.pyc │ │ └── simple_tokenizer.cpython-39.pyc │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── clip_models.py ├── imagenet_models.py ├── resnet.py ├── vgg.py ├── vision_transformer.py ├── vision_transformer_misc.py └── vision_transformer_utils.py ├── networks ├── __init__.py ├── base_model.py ├── lpf.py ├── resnet_lpf.py └── trainer.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── pretrained_weights └── fc_weights.pth ├── resources └── teaser.png ├── test.sh ├── train.py └── validate.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Wisconsin AI and Vision Lab (WAIV) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Detecting fake images 2 | 3 | **Towards Universal Fake Image Detectors that Generalize Across Generative Models**
4 | [Utkarsh Ojha*](https://utkarshojha.github.io/), [Yuheng Li*](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
5 | (*Equal contribution)
6 | CVPR 2023 7 | 8 | [[Project Page](https://utkarshojha.github.io/universal-fake-detection/)] [[Paper](https://arxiv.org/abs/2302.10174)] 9 | 10 | <p align="center"> 11 | <img src="resources/teaser.png"> <br> 12 | Using images from one type of generative model (e.g., GAN), detect fake images from other types of generative models (e.g., diffusion models). 13 | </p>
14 | 15 | ## Contents 16 | 17 | - [Setup](#setup) 18 | - [Pretrained model](#weights) 19 | - [Data](#data) 20 | - [Evaluation](#evaluation) 21 | - [Training](#training) 22 | 23 | 24 | ## Setup 25 | 26 | 1. Clone this repository 27 | ```bash 28 | git clone https://github.com/Yuheng-Li/UniversalFakeDetect 29 | cd UniversalFakeDetect 30 | ``` 31 | 32 | 2. Install the necessary libraries 33 | ```bash 34 | pip install torch torchvision 35 | ``` 36 | 37 | ## Data 38 | 39 | - Of the 19 models studied overall (Table 1/2 in the main paper), 11 are taken from a [previous work](https://arxiv.org/abs/1912.11035). Download the test set, i.e., real/fake images for those 11 models given by the authors, from [here](https://drive.google.com/file/d/1z_fD3UKgWQyOTZIBbYSaQ-hz4AzUrLC1/view) (dataset size ~19GB). 40 | - Download the file and unzip it in `datasets/test`. You could also use the bash scripts provided by the authors, as described [here](https://github.com/PeterWang512/CNNDetection#download-the-dataset) in their code repository. 41 | - This should create a directory structure as follows: 42 | ``` 43 | 44 | datasets 45 | └── test 46 | ├── progan 47 | │── cyclegan 48 | │── biggan 49 | │ . 50 | │ . 51 | 52 | ``` 53 | - Each directory (e.g., progan) will contain real/fake images under `0_real` and `1_fake` folders respectively. 54 | - The dataset for the diffusion models (e.g., LDM/Glide) can be found [here](https://drive.google.com/file/d/1FXlGIRh_Ud3cScMgSVDbEWmPDmjcrm1t/view?usp=drive_link). Note that in the paper (Table 2/3), we reported the results over 10k randomly sampled images. Since providing that many images for all the domains would take up too much space, we are only releasing 1k images for each domain; i.e., 1k fake images and 1k real images for each domain (e.g., LDM-200). 55 | - Download and unzip the file into the `./diffusion_datasets` directory. 56 | 57 | 58 | ## Evaluation 59 | 60 | - You can evaluate the model on all the datasets at once by running: 61 | ```bash 62 | python validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14 63 | ``` 64 | 65 | - You can also evaluate the model on a single generative model by specifying the paths of the real and fake datasets: 66 | ```bash 67 | python validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14 --real_path datasets/test/progan/0_real --fake_path datasets/test/progan/1_fake 68 | ``` 69 | 70 | Note that if no arguments are provided for `real_path` and `fake_path`, the script will perform the evaluation on all the domains specified in `dataset_paths.py`. 71 | 72 | - The results will be stored in `results/` in two files: `ap.txt` stores the Average Precision for each of the test domains, and `acc.txt` stores the accuracy (with 0.5 as the threshold) for the same domains. 73 | 74 | ## Training 75 | 76 | - Our main model is trained on the same dataset used by the authors of [this work](https://arxiv.org/abs/1912.11035). Download the official training dataset provided [here](https://drive.google.com/file/d/1iVNBV0glknyTYGA9bCxT_d0CVTOgGcKh/view) (dataset size ~72GB). 77 | 78 | - Download and unzip the dataset into the `datasets/train` directory. The overall structure should look like the following: 79 | ``` 80 | datasets 81 | └── train 82 | └── progan 83 | ├── airplane 84 | │── bird 85 | │── boat 86 | │ . 87 | │ . 
88 | ``` 89 | - There are a total of 20 different object categories, with each category folder containing the corresponding real and fake images in `0_real` and `1_fake` folders. 90 | - The model can then be trained with the following command: 91 | ```bash 92 | python train.py --name=clip_vitl14 --wang2020_data_path=datasets/ --data_mode=wang2020 --arch=CLIP:ViT-L/14 --fix_backbone 93 | ``` 94 | - **Important**: do not forget to use the `--fix_backbone` argument during training, which ensures that only the linear layer's parameters will be trained. 95 | 96 | ## Acknowledgement 97 | 98 | We would like to thank [Sheng-Yu Wang](https://github.com/PeterWang512) for releasing the real/fake images from different generative models. Our training pipeline is also inspired by his [open-source code](https://github.com/PeterWang512/CNNDetection). We would also like to thank [CompVis](https://github.com/CompVis) for releasing the pre-trained [LDMs](https://github.com/CompVis/latent-diffusion) and [LAION](https://laion.ai/) for open-sourcing the [LAION-400M dataset](https://laion.ai/blog/laion-400-open-dataset/). 99 | 100 | ## Citation 101 | 102 | If you find our work helpful in your research, please cite it using the following: 103 | ```bibtex 104 | @inproceedings{ojha2023fakedetect, 105 | title={Towards Universal Fake Image Detectors that Generalize Across Generative Models}, 106 | author={Ojha, Utkarsh and Li, Yuheng and Lee, Yong Jae}, 107 | booktitle={CVPR}, 108 | year={2023}, 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data.sampler import WeightedRandomSampler 4 | 5 | from .datasets import RealFakeDataset 6 | 7 | 8 | 9 | def get_bal_sampler(dataset): 10 | targets = [] 11 | for d in dataset.datasets: 12 | targets.extend(d.targets) 13 | 14 | ratio = np.bincount(targets) 15 | w = 1. 
/ torch.tensor(ratio, dtype=torch.float) 16 | sample_weights = w[targets] 17 | sampler = WeightedRandomSampler(weights=sample_weights, 18 | num_samples=len(sample_weights)) 19 | return sampler 20 | 21 | 22 | def create_dataloader(opt, preprocess=None): 23 | shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False 24 | dataset = RealFakeDataset(opt) 25 | if '2b' in opt.arch: 26 | dataset.transform = preprocess 27 | sampler = get_bal_sampler(dataset) if opt.class_bal else None 28 | 29 | data_loader = torch.utils.data.DataLoader(dataset, 30 | batch_size=opt.batch_size, 31 | shuffle=shuffle, 32 | sampler=sampler, 33 | num_workers=int(opt.num_threads)) 34 | return data_loader 35 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/datasets.cpython-39.pyc -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torchvision.datasets as datasets 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as TF 6 | from torch.utils.data import Dataset 7 | from random import random, choice, shuffle 8 | from io import BytesIO 9 | from PIL import Image 10 | from PIL import ImageFile 11 | from scipy.ndimage.filters import gaussian_filter 12 | import pickle 13 | import os 14 | from skimage.io import imread 15 | from copy import deepcopy 16 | 17 | ImageFile.LOAD_TRUNCATED_IMAGES = True 18 | 19 | 20 | MEAN = { 21 | "imagenet":[0.485, 0.456, 0.406], 22 | "clip":[0.48145466, 0.4578275, 0.40821073] 23 | } 24 | 25 | STD = { 26 | "imagenet":[0.229, 0.224, 0.225], 27 | "clip":[0.26862954, 0.26130258, 0.27577711] 28 | } 29 | 30 | 31 | 32 | 33 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg"]): 34 | out = [] 35 | for r, d, f in os.walk(rootdir): 36 | for file in f: 37 | if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): 38 | out.append(os.path.join(r, file)) 39 | return out 40 | 41 | 42 | def get_list(path, must_contain=''): 43 | if ".pickle" in path: 44 | with open(path, 'rb') as f: 45 | 
image_list = pickle.load(f) 46 | image_list = [ item for item in image_list if must_contain in item ] 47 | else: 48 | image_list = recursively_read(path, must_contain) 49 | return image_list 50 | 51 | 52 | 53 | 54 | class RealFakeDataset(Dataset): 55 | def __init__(self, opt): 56 | assert opt.data_label in ["train", "val"] 57 | #assert opt.data_mode in ["ours", "wang2020", "ours_wang2020"] 58 | self.data_label = opt.data_label 59 | if opt.data_mode == 'ours': 60 | pickle_name = "train.pickle" if opt.data_label=="train" else "val.pickle" 61 | real_list = get_list( os.path.join(opt.real_list_path, pickle_name) ) 62 | fake_list = get_list( os.path.join(opt.fake_list_path, pickle_name) ) 63 | elif opt.data_mode == 'wang2020': 64 | temp = 'train/progan' if opt.data_label == 'train' else 'test/progan' 65 | real_list = get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='0_real' ) 66 | fake_list = get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='1_fake' ) 67 | elif opt.data_mode == 'ours_wang2020': 68 | pickle_name = "train.pickle" if opt.data_label=="train" else "val.pickle" 69 | real_list = get_list( os.path.join(opt.real_list_path, pickle_name) ) 70 | fake_list = get_list( os.path.join(opt.fake_list_path, pickle_name) ) 71 | temp = 'train/progan' if opt.data_label == 'train' else 'test/progan' 72 | real_list += get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='0_real' ) 73 | fake_list += get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='1_fake' ) 74 | 75 | 76 | 77 | # setting the labels for the dataset 78 | self.labels_dict = {} 79 | for i in real_list: 80 | self.labels_dict[i] = 0 81 | for i in fake_list: 82 | self.labels_dict[i] = 1 83 | 84 | self.total_list = real_list + fake_list 85 | shuffle(self.total_list) 86 | if opt.isTrain: 87 | crop_func = transforms.RandomCrop(opt.cropSize) 88 | elif opt.no_crop: 89 | crop_func = transforms.Lambda(lambda img: img) 90 | else: 91 | crop_func = transforms.CenterCrop(opt.cropSize) 92 | 93 | if opt.isTrain and not opt.no_flip: 94 | flip_func = transforms.RandomHorizontalFlip() 95 | else: 96 | flip_func = transforms.Lambda(lambda img: img) 97 | if not opt.isTrain and opt.no_resize: 98 | rz_func = transforms.Lambda(lambda img: img) 99 | else: 100 | rz_func = transforms.Lambda(lambda img: custom_resize(img, opt)) 101 | 102 | 103 | stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip" 104 | 105 | print("mean and std stats are from: ", stat_from) 106 | if '2b' not in opt.arch: 107 | print ("using Official CLIP's normalization") 108 | self.transform = transforms.Compose([ 109 | rz_func, 110 | transforms.Lambda(lambda img: data_augment(img, opt)), 111 | crop_func, 112 | flip_func, 113 | transforms.ToTensor(), 114 | transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ), 115 | ]) 116 | else: 117 | print ("Using CLIP 2B transform") 118 | self.transform = None # will be initialized in trainer.py 119 | 120 | 121 | def __len__(self): 122 | return len(self.total_list) 123 | 124 | 125 | def __getitem__(self, idx): 126 | img_path = self.total_list[idx] 127 | label = self.labels_dict[img_path] 128 | img = Image.open(img_path).convert("RGB") 129 | img = self.transform(img) 130 | return img, label 131 | 132 | 133 | def data_augment(img, opt): 134 | img = np.array(img) 135 | if img.ndim == 2: 136 | img = np.expand_dims(img, axis=2) 137 | img = np.repeat(img, 3, axis=2) 138 | 139 | if random() < opt.blur_prob: 140 | sig = sample_continuous(opt.blur_sig) 141 | 
gaussian_blur(img, sig) 142 | 143 | if random() < opt.jpg_prob: 144 | method = sample_discrete(opt.jpg_method) 145 | qual = sample_discrete(opt.jpg_qual) 146 | img = jpeg_from_key(img, qual, method) 147 | 148 | return Image.fromarray(img) 149 | 150 | 151 | def sample_continuous(s): 152 | if len(s) == 1: 153 | return s[0] 154 | if len(s) == 2: 155 | rg = s[1] - s[0] 156 | return random() * rg + s[0] 157 | raise ValueError("Length of iterable s should be 1 or 2.") 158 | 159 | 160 | def sample_discrete(s): 161 | if len(s) == 1: 162 | return s[0] 163 | return choice(s) 164 | 165 | 166 | def gaussian_blur(img, sigma): 167 | gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) 168 | gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) 169 | gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) 170 | 171 | 172 | def cv2_jpg(img, compress_val): 173 | img_cv2 = img[:,:,::-1] 174 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val] 175 | result, encimg = cv2.imencode('.jpg', img_cv2, encode_param) 176 | decimg = cv2.imdecode(encimg, 1) 177 | return decimg[:,:,::-1] 178 | 179 | 180 | def pil_jpg(img, compress_val): 181 | out = BytesIO() 182 | img = Image.fromarray(img) 183 | img.save(out, format='jpeg', quality=compress_val) 184 | img = Image.open(out) 185 | # load from memory before ByteIO closes 186 | img = np.array(img) 187 | out.close() 188 | return img 189 | 190 | 191 | jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg} 192 | def jpeg_from_key(img, compress_val, key): 193 | method = jpeg_dict[key] 194 | return method(img, compress_val) 195 | 196 | 197 | rz_dict = {'bilinear': Image.BILINEAR, 198 | 'bicubic': Image.BICUBIC, 199 | 'lanczos': Image.LANCZOS, 200 | 'nearest': Image.NEAREST} 201 | def custom_resize(img, opt): 202 | interp = sample_discrete(opt.rz_interp) 203 | return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp]) 204 | -------------------------------------------------------------------------------- /dataset_paths.py: -------------------------------------------------------------------------------- 1 | DATASET_PATHS = [ 2 | 3 | 4 | dict( 5 | real_path='../FAKE_IMAGES/CNN/test/progan', 6 | fake_path='../FAKE_IMAGES/CNN/test/progan', 7 | data_mode='wang2020', 8 | key='progan' 9 | ), 10 | 11 | dict( 12 | real_path='../FAKE_IMAGES/CNN/test/cyclegan', 13 | fake_path='../FAKE_IMAGES/CNN/test/cyclegan', 14 | data_mode='wang2020', 15 | key='cyclegan' 16 | ), 17 | 18 | dict( 19 | real_path='../FAKE_IMAGES/CNN/test/biggan/', # Imagenet 20 | fake_path='../FAKE_IMAGES/CNN/test/biggan/', 21 | data_mode='wang2020', 22 | key='biggan' 23 | ), 24 | 25 | 26 | dict( 27 | real_path='../FAKE_IMAGES/CNN/test/stylegan', 28 | fake_path='../FAKE_IMAGES/CNN/test/stylegan', 29 | data_mode='wang2020', 30 | key='stylegan' 31 | ), 32 | 33 | 34 | dict( 35 | real_path='../FAKE_IMAGES/CNN/test/gaugan', # It is COCO 36 | fake_path='../FAKE_IMAGES/CNN/test/gaugan', 37 | data_mode='wang2020', 38 | key='gaugan' 39 | ), 40 | 41 | 42 | dict( 43 | real_path='../FAKE_IMAGES/CNN/test/stargan', 44 | fake_path='../FAKE_IMAGES/CNN/test/stargan', 45 | data_mode='wang2020', 46 | key='stargan' 47 | ), 48 | 49 | 50 | dict( 51 | real_path='../FAKE_IMAGES/CNN/test/deepfake', 52 | fake_path='../FAKE_IMAGES/CNN/test/deepfake', 53 | data_mode='wang2020', 54 | key='deepfake' 55 | ), 56 | 57 | 58 | dict( 59 | real_path='../FAKE_IMAGES/CNN/test/seeingdark', 60 | fake_path='../FAKE_IMAGES/CNN/test/seeingdark', 61 | data_mode='wang2020', 62 | key='sitd' 63 | ), 64 | 65 | 66 | dict( 67 | 
real_path='../FAKE_IMAGES/CNN/test/san', 68 | fake_path='../FAKE_IMAGES/CNN/test/san', 69 | data_mode='wang2020', 70 | key='san' 71 | ), 72 | 73 | 74 | dict( 75 | real_path='../FAKE_IMAGES/CNN/test/crn', # Images from some video games 76 | fake_path='../FAKE_IMAGES/CNN/test/crn', 77 | data_mode='wang2020', 78 | key='crn' 79 | ), 80 | 81 | 82 | dict( 83 | real_path='../FAKE_IMAGES/CNN/test/imle', # Images from some video games 84 | fake_path='../FAKE_IMAGES/CNN/test/imle', 85 | data_mode='wang2020', 86 | key='imle' 87 | ), 88 | 89 | 90 | dict( 91 | real_path='./diffusion_datasets/imagenet', 92 | fake_path='./diffusion_datasets/guided', 93 | data_mode='wang2020', 94 | key='guided' 95 | ), 96 | 97 | 98 | dict( 99 | real_path='./diffusion_datasets/laion', 100 | fake_path='./diffusion_datasets/ldm_200', 101 | data_mode='wang2020', 102 | key='ldm_200' 103 | ), 104 | 105 | dict( 106 | real_path='./diffusion_datasets/laion', 107 | fake_path='./diffusion_datasets/ldm_200_cfg', 108 | data_mode='wang2020', 109 | key='ldm_200_cfg' 110 | ), 111 | 112 | dict( 113 | real_path='./diffusion_datasets/laion', 114 | fake_path='./diffusion_datasets/ldm_100', 115 | data_mode='wang2020', 116 | key='ldm_100' 117 | ), 118 | 119 | 120 | dict( 121 | real_path='./diffusion_datasets/laion', 122 | fake_path='./diffusion_datasets/glide_100_27', 123 | data_mode='wang2020', 124 | key='glide_100_27' 125 | ), 126 | 127 | 128 | dict( 129 | real_path='./diffusion_datasets/laion', 130 | fake_path='./diffusion_datasets/glide_50_27', 131 | data_mode='wang2020', 132 | key='glide_50_27' 133 | ), 134 | 135 | 136 | dict( 137 | real_path='./diffusion_datasets/laion', 138 | fake_path='./diffusion_datasets/glide_100_10', 139 | data_mode='wang2020', 140 | key='glide_100_10' 141 | ), 142 | 143 | 144 | dict( 145 | real_path='./diffusion_datasets/laion', 146 | fake_path='./diffusion_datasets/dalle', 147 | data_mode='wang2020', 148 | key='dalle' 149 | ), 150 | 151 | 152 | 153 | ] 154 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_models import CLIPModel 2 | from .imagenet_models import ImagenetModel 3 | 4 | 5 | VALID_NAMES = [ 6 | 'Imagenet:resnet18', 7 | 'Imagenet:resnet34', 8 | 'Imagenet:resnet50', 9 | 'Imagenet:resnet101', 10 | 'Imagenet:resnet152', 11 | 'Imagenet:vgg11', 12 | 'Imagenet:vgg19', 13 | 'Imagenet:swin-b', 14 | 'Imagenet:swin-s', 15 | 'Imagenet:swin-t', 16 | 'Imagenet:vit_b_16', 17 | 'Imagenet:vit_b_32', 18 | 'Imagenet:vit_l_16', 19 | 'Imagenet:vit_l_32', 20 | 21 | 'CLIP:RN50', 22 | 'CLIP:RN101', 23 | 'CLIP:RN50x4', 24 | 'CLIP:RN50x16', 25 | 'CLIP:RN50x64', 26 | 'CLIP:ViT-B/32', 27 | 'CLIP:ViT-B/16', 28 | 'CLIP:ViT-L/14', 29 | 'CLIP:ViT-L/14@336px', 30 | ] 31 | 32 | 33 | 34 | 35 | 36 | def get_model(name): 37 | assert name in VALID_NAMES 38 | if name.startswith("Imagenet:"): 39 | return ImagenetModel(name[9:]) 40 | elif name.startswith("CLIP:"): 41 | return CLIPModel(name[5:]) 42 | else: 43 | assert False 44 | -------------------------------------------------------------------------------- /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /models/clip/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/clip.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/clip.cpython-310.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/clip.cpython-39.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/simple_tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc 
-------------------------------------------------------------------------------- /models/clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/simple_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | 
os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
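For example, tokenize(["a diagram", "a dog"]) yields a tensor of shape [2, 77], where each row is [sot_token, <text tokens>, eot_token] followed by zero padding.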
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /models/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = 
torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | 
return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | out = {} 204 | for idx, layer in enumerate(self.resblocks.children()): 205 | x = layer(x) 206 | out['layer'+str(idx)] = x[0] # shape:LND. choose cls token feature 207 | return out, x 208 | 209 | # return self.resblocks(x) # This is the original code 210 | 211 | 212 | class VisionTransformer(nn.Module): 213 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 214 | super().__init__() 215 | self.input_resolution = input_resolution 216 | self.output_dim = output_dim 217 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 218 | 219 | scale = width ** -0.5 220 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 221 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 222 | self.ln_pre = LayerNorm(width) 223 | 224 | self.transformer = Transformer(width, layers, heads) 225 | 226 | self.ln_post = LayerNorm(width) 227 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 228 | 229 | 230 | 231 | def forward(self, x: torch.Tensor): 232 | x = self.conv1(x) # shape = [*, width, grid, grid] 233 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 234 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 235 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 236 | x = x + self.positional_embedding.to(x.dtype) 237 | x = self.ln_pre(x) 238 | 239 | x = x.permute(1, 0, 2) # NLD -> LND 240 | out, x = self.transformer(x) 241 | x = x.permute(1, 0, 2) # LND -> NLD 242 | 243 | x = self.ln_post(x[:, 0, :]) 244 | 245 | 246 | out['before_projection'] = x 247 | 248 | if self.proj is not None: 249 | x = x @ self.proj 250 | out['after_projection'] = x 251 | 252 | # Return both intermediate features and final clip feature 253 | # return out 254 | 255 | # This only returns CLIP features 256 | return x 257 | 258 | 259 | class CLIP(nn.Module): 
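# CLIP couples two encoders: an image tower (ModifiedResNet when `vision_layers` is a tuple/list, otherwise VisionTransformer) and a text Transformer.
# encode_image()/encode_text() produce the respective embeddings; forward() returns their scaled cosine-similarity logits (logits_per_image, logits_per_text).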
260 | def __init__(self, 261 | embed_dim: int, 262 | # vision 263 | image_resolution: int, 264 | vision_layers: Union[Tuple[int, int, int, int], int], 265 | vision_width: int, 266 | vision_patch_size: int, 267 | # text 268 | context_length: int, 269 | vocab_size: int, 270 | transformer_width: int, 271 | transformer_heads: int, 272 | transformer_layers: int 273 | ): 274 | super().__init__() 275 | 276 | self.context_length = context_length 277 | 278 | if isinstance(vision_layers, (tuple, list)): 279 | vision_heads = vision_width * 32 // 64 280 | self.visual = ModifiedResNet( 281 | layers=vision_layers, 282 | output_dim=embed_dim, 283 | heads=vision_heads, 284 | input_resolution=image_resolution, 285 | width=vision_width 286 | ) 287 | else: 288 | vision_heads = vision_width // 64 289 | self.visual = VisionTransformer( 290 | input_resolution=image_resolution, 291 | patch_size=vision_patch_size, 292 | width=vision_width, 293 | layers=vision_layers, 294 | heads=vision_heads, 295 | output_dim=embed_dim 296 | ) 297 | 298 | self.transformer = Transformer( 299 | width=transformer_width, 300 | layers=transformer_layers, 301 | heads=transformer_heads, 302 | attn_mask=self.build_attention_mask() 303 | ) 304 | 305 | self.vocab_size = vocab_size 306 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 307 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 308 | self.ln_final = LayerNorm(transformer_width) 309 | 310 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 311 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 312 | 313 | self.initialize_parameters() 314 | 315 | def initialize_parameters(self): 316 | nn.init.normal_(self.token_embedding.weight, std=0.02) 317 | nn.init.normal_(self.positional_embedding, std=0.01) 318 | 319 | if isinstance(self.visual, ModifiedResNet): 320 | if self.visual.attnpool is not None: 321 | std = self.visual.attnpool.c_proj.in_features ** -0.5 322 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 323 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 324 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 325 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 326 | 327 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 328 | for name, param in resnet_block.named_parameters(): 329 | if name.endswith("bn3.weight"): 330 | nn.init.zeros_(param) 331 | 332 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 333 | attn_std = self.transformer.width ** -0.5 334 | fc_std = (2 * self.transformer.width) ** -0.5 335 | for block in self.transformer.resblocks: 336 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 337 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 338 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 339 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 340 | 341 | if self.text_projection is not None: 342 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 343 | 344 | def build_attention_mask(self): 345 | # lazily create causal attention mask, with full attention between the vision tokens 346 | # pytorch uses additive attention mask; fill with -inf 347 | mask = torch.empty(self.context_length, self.context_length) 348 | mask.fill_(float("-inf")) 349 | mask.triu_(1) # zero out the lower diagonal 350 | return mask 351 | 352 | @property 353 | def dtype(self): 354 | return 
self.visual.conv1.weight.dtype 355 | 356 | def encode_image(self, image): 357 | return self.visual(image.type(self.dtype)) 358 | 359 | def encode_text(self, text): 360 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 361 | 362 | x = x + self.positional_embedding.type(self.dtype) 363 | x = x.permute(1, 0, 2) # NLD -> LND 364 | x = self.transformer(x) 365 | x = x.permute(1, 0, 2) # LND -> NLD 366 | x = self.ln_final(x).type(self.dtype) 367 | 368 | # x.shape = [batch_size, n_ctx, transformer.width] 369 | # take features from the eot embedding (eot_token is the highest number in each sequence) 370 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 371 | 372 | return x 373 | 374 | def forward(self, image, text): 375 | image_features = self.encode_image(image) 376 | text_features = self.encode_text(text) 377 | 378 | # normalized features 379 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 380 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 381 | 382 | # cosine similarity as logits 383 | logit_scale = self.logit_scale.exp() 384 | logits_per_image = logit_scale * image_features @ text_features.t() 385 | logits_per_text = logits_per_image.t() 386 | 387 | # shape = [global_batch_size, global_batch_size] 388 | return logits_per_image, logits_per_text 389 | 390 | 391 | def convert_weights(model: nn.Module): 392 | """Convert applicable model parameters to fp16""" 393 | 394 | def _convert_weights_to_fp16(l): 395 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 396 | l.weight.data = l.weight.data.half() 397 | if l.bias is not None: 398 | l.bias.data = l.bias.data.half() 399 | 400 | if isinstance(l, nn.MultiheadAttention): 401 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 402 | tensor = getattr(l, attr) 403 | if tensor is not None: 404 | tensor.data = tensor.data.half() 405 | 406 | for name in ["text_projection", "proj"]: 407 | if hasattr(l, name): 408 | attr = getattr(l, name) 409 | if attr is not None: 410 | attr.data = attr.data.half() 411 | 412 | model.apply(_convert_weights_to_fp16) 413 | 414 | 415 | def build_model(state_dict: dict): 416 | vit = "visual.proj" in state_dict 417 | 418 | if vit: 419 | vision_width = state_dict["visual.conv1.weight"].shape[0] 420 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 421 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 422 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 423 | image_resolution = vision_patch_size * grid_size 424 | else: 425 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 426 | vision_layers = tuple(counts) 427 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 428 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 429 | vision_patch_size = None 430 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 431 | image_resolution = output_width * 32 432 | 433 | embed_dim = state_dict["text_projection"].shape[1] 434 | context_length = state_dict["positional_embedding"].shape[0] 435 | vocab_size = state_dict["token_embedding.weight"].shape[0] 436 | transformer_width = state_dict["ln_final.weight"].shape[0] 437 | transformer_heads = transformer_width // 64 438 | 
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 439 | 440 | model = CLIP( 441 | embed_dim, 442 | image_resolution, vision_layers, vision_width, vision_patch_size, 443 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 444 | ) 445 | 446 | for key in ["input_resolution", "context_length", "vocab_size"]: 447 | if key in state_dict: 448 | del state_dict[key] 449 | 450 | convert_weights(model) 451 | model.load_state_dict(state_dict) 452 | return model.eval() 453 | -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
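For example, get_pairs(('h', 'e', 'll', 'o')) returns {('h', 'e'), ('e', 'll'), ('ll', 'o')}.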
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /models/clip_models.py: -------------------------------------------------------------------------------- 1 | from .clip import clip 2 | from PIL import Image 3 | import torch.nn as nn 4 | 5 | 6 | CHANNELS = { 7 | "RN50" : 1024, 8 | "ViT-L/14" : 768 9 | } 10 | 11 | class CLIPModel(nn.Module): 12 | def __init__(self, name, num_classes=1): 13 | super(CLIPModel, self).__init__() 14 | 15 | self.model, self.preprocess = clip.load(name, device="cpu") # self.preprecess will not be 
used during training, which is handled in Dataset class 16 | self.fc = nn.Linear( CHANNELS[name], num_classes ) 17 | 18 | 19 | def forward(self, x, return_feature=False): 20 | features = self.model.encode_image(x) 21 | if return_feature: 22 | return features 23 | return self.fc(features) 24 | 25 | -------------------------------------------------------------------------------- /models/imagenet_models.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 2 | from .vision_transformer import vit_b_16, vit_b_32, vit_l_16, vit_l_32 3 | 4 | from torchvision import transforms 5 | from PIL import Image 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | model_dict = { 11 | 'resnet18': resnet18, 12 | 'resnet34': resnet34, 13 | 'resnet50': resnet50, 14 | 'resnet101': resnet101, 15 | 'resnet152': resnet152, 16 | 'vit_b_16': vit_b_16, 17 | 'vit_b_32': vit_b_32, 18 | 'vit_l_16': vit_l_16, 19 | 'vit_l_32': vit_l_32 20 | } 21 | 22 | 23 | CHANNELS = { 24 | "resnet50" : 2048, 25 | "vit_b_16" : 768, 26 | } 27 | 28 | 29 | 30 | class ImagenetModel(nn.Module): 31 | def __init__(self, name, num_classes=1): 32 | super(ImagenetModel, self).__init__() 33 | 34 | self.model = model_dict[name](pretrained=True) 35 | self.fc = nn.Linear(CHANNELS[name], num_classes) #manually define a fc layer here 36 | 37 | 38 | def forward(self, x): 39 | feature = self.model(x)["penultimate"] 40 | return self.fc(feature) 41 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from typing import Type, Any, Callable, Union, List, Optional 5 | 6 | try: 7 | from torch.hub import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | 25 | 26 | 27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion: int = 1 40 | 41 | def __init__( 42 | self, 43 | inplanes: int, 44 | planes: int, 45 | stride: int = 1, 46 | downsample: Optional[nn.Module] = None, 47 | groups: int = 1, 
48 | base_width: int = 64, 49 | dilation: int = 1, 50 | norm_layer: Optional[Callable[..., nn.Module]] = None 51 | ) -> None: 52 | super(BasicBlock, self).__init__() 53 | if norm_layer is None: 54 | norm_layer = nn.BatchNorm2d 55 | if groups != 1 or base_width != 64: 56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 57 | if dilation > 1: 58 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = norm_layer(planes) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = norm_layer(planes) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | identity = self.downsample(x) 80 | 81 | out += identity 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 90 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 91 | # This variant is also known as ResNet V1.5 and improves accuracy according to 92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 93 | 94 | expansion: int = 4 95 | 96 | def __init__( 97 | self, 98 | inplanes: int, 99 | planes: int, 100 | stride: int = 1, 101 | downsample: Optional[nn.Module] = None, 102 | groups: int = 1, 103 | base_width: int = 64, 104 | dilation: int = 1, 105 | norm_layer: Optional[Callable[..., nn.Module]] = None 106 | ) -> None: 107 | super(Bottleneck, self).__init__() 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | width = int(planes * (base_width / 64.)) * groups 111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 112 | self.conv1 = conv1x1(inplanes, width) 113 | self.bn1 = norm_layer(width) 114 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 115 | self.bn2 = norm_layer(width) 116 | self.conv3 = conv1x1(width, planes * self.expansion) 117 | self.bn3 = norm_layer(planes * self.expansion) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | def forward(self, x: Tensor) -> Tensor: 123 | identity = x 124 | 125 | out = self.conv1(x) 126 | out = self.bn1(out) 127 | out = self.relu(out) 128 | 129 | out = self.conv2(out) 130 | out = self.bn2(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv3(out) 134 | out = self.bn3(out) 135 | 136 | if self.downsample is not None: 137 | identity = self.downsample(x) 138 | 139 | out += identity 140 | out = self.relu(out) 141 | 142 | return out 143 | 144 | 145 | class ResNet(nn.Module): 146 | 147 | def __init__( 148 | self, 149 | block: Type[Union[BasicBlock, Bottleneck]], 150 | layers: List[int], 151 | num_classes: int = 1000, 152 | zero_init_residual: bool = False, 153 | groups: int = 1, 154 | width_per_group: int = 64, 155 | replace_stride_with_dilation: Optional[List[bool]] = None, 156 | norm_layer: Optional[Callable[..., nn.Module]] = None 157 | ) -> None: 158 | super(ResNet, self).__init__() 159 
| if norm_layer is None: 160 | norm_layer = nn.BatchNorm2d 161 | self._norm_layer = norm_layer 162 | 163 | self.inplanes = 64 164 | self.dilation = 1 165 | if replace_stride_with_dilation is None: 166 | # each element in the tuple indicates if we should replace 167 | # the 2x2 stride with a dilated convolution instead 168 | replace_stride_with_dilation = [False, False, False] 169 | if len(replace_stride_with_dilation) != 3: 170 | raise ValueError("replace_stride_with_dilation should be None " 171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 172 | self.groups = groups 173 | self.base_width = width_per_group 174 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 175 | bias=False) 176 | self.bn1 = norm_layer(self.inplanes) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.layer1 = self._make_layer(block, 64, layers[0]) 180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 181 | dilate=replace_stride_with_dilation[0]) 182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 183 | dilate=replace_stride_with_dilation[1]) 184 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 185 | dilate=replace_stride_with_dilation[2]) 186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 187 | self.fc = nn.Linear(512 * block.expansion, num_classes) 188 | 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 192 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 193 | nn.init.constant_(m.weight, 1) 194 | nn.init.constant_(m.bias, 0) 195 | 196 | # Zero-initialize the last BN in each residual branch, 197 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
198 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 199 | if zero_init_residual: 200 | for m in self.modules(): 201 | if isinstance(m, Bottleneck): 202 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 203 | elif isinstance(m, BasicBlock): 204 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 205 | 206 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 207 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 208 | norm_layer = self._norm_layer 209 | downsample = None 210 | previous_dilation = self.dilation 211 | if dilate: 212 | self.dilation *= stride 213 | stride = 1 214 | if stride != 1 or self.inplanes != planes * block.expansion: 215 | downsample = nn.Sequential( 216 | conv1x1(self.inplanes, planes * block.expansion, stride), 217 | norm_layer(planes * block.expansion), 218 | ) 219 | 220 | layers = [] 221 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 222 | self.base_width, previous_dilation, norm_layer)) 223 | self.inplanes = planes * block.expansion 224 | for _ in range(1, blocks): 225 | layers.append(block(self.inplanes, planes, groups=self.groups, 226 | base_width=self.base_width, dilation=self.dilation, 227 | norm_layer=norm_layer)) 228 | 229 | return nn.Sequential(*layers) 230 | 231 | def _forward_impl(self, x): 232 | # The comment resolution is based on input size is 224*224 imagenet 233 | out = {} 234 | x = self.conv1(x) 235 | x = self.bn1(x) 236 | x = self.relu(x) 237 | x = self.maxpool(x) 238 | out['f0'] = x # N*64*56*56 239 | 240 | x = self.layer1(x) 241 | out['f1'] = x # N*64*56*56 242 | 243 | x = self.layer2(x) 244 | out['f2'] = x # N*128*28*28 245 | 246 | x = self.layer3(x) 247 | out['f3'] = x # N*256*14*14 248 | 249 | x = self.layer4(x) 250 | out['f4'] = x # N*512*7*7 251 | 252 | x = self.avgpool(x) 253 | x = torch.flatten(x, 1) 254 | out['penultimate'] = x # N*512 255 | 256 | x = self.fc(x) 257 | out['logits'] = x # N*1000 258 | 259 | # return all features 260 | return out 261 | 262 | # return final classification result 263 | # return x 264 | 265 | def forward(self, x): 266 | return self._forward_impl(x) 267 | 268 | 269 | def _resnet( 270 | arch: str, 271 | block: Type[Union[BasicBlock, Bottleneck]], 272 | layers: List[int], 273 | pretrained: bool, 274 | progress: bool, 275 | **kwargs: Any 276 | ) -> ResNet: 277 | model = ResNet(block, layers, **kwargs) 278 | if pretrained: 279 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 280 | model.load_state_dict(state_dict) 281 | return model 282 | 283 | 284 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 285 | r"""ResNet-18 model from 286 | `"Deep Residual Learning for Image Recognition" `_. 287 | 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 293 | 294 | 295 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 296 | r"""ResNet-34 model from 297 | `"Deep Residual Learning for Image Recognition" `_. 
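Note (added for clarity; a minimal usage sketch with an illustrative input shape): unlike the stock torchvision ResNet, ``forward`` here returns a dict of features, so ``resnet34(pretrained=True)(torch.randn(1, 3, 224, 224))['penultimate']`` gives the pooled pre-logit vector and ``[...]['logits']`` the classification output.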
298 | 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) 304 | 305 | 306 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 307 | r"""ResNet-50 model from 308 | `"Deep Residual Learning for Image Recognition" `_. 309 | 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 315 | 316 | 317 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 318 | r"""ResNet-101 model from 319 | `"Deep Residual Learning for Image Recognition" `_. 320 | 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | progress (bool): If True, displays a progress bar of the download to stderr 324 | """ 325 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 326 | 327 | 328 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 329 | r"""ResNet-152 model from 330 | `"Deep Residual Learning for Image Recognition" `_. 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) 337 | 338 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Union, List, Dict, Any, cast 4 | import torchvision 5 | import torch.nn.functional as F 6 | 7 | 8 | 9 | 10 | 11 | class VGG(torch.nn.Module): 12 | def __init__(self, arch_type, pretrained, progress): 13 | super().__init__() 14 | 15 | self.layer1 = torch.nn.Sequential() 16 | self.layer2 = torch.nn.Sequential() 17 | self.layer3 = torch.nn.Sequential() 18 | self.layer4 = torch.nn.Sequential() 19 | self.layer5 = torch.nn.Sequential() 20 | 21 | if arch_type == 'vgg11': 22 | official_vgg = torchvision.models.vgg11(pretrained=pretrained, progress=progress) 23 | blocks = [ [0,2], [2,5], [5,10], [10,15], [15,20] ] 24 | last_idx = 20 25 | elif arch_type == 'vgg19': 26 | official_vgg = torchvision.models.vgg19(pretrained=pretrained, progress=progress) 27 | blocks = [ [0,4], [4,9], [9,18], [18,27], [27,36] ] 28 | last_idx = 36 29 | else: 30 | raise NotImplementedError 31 | 32 | 33 | for x in range( *blocks[0] ): 34 | self.layer1.add_module(str(x), official_vgg.features[x]) 35 | for x in range( *blocks[1] ): 36 | self.layer2.add_module(str(x), official_vgg.features[x]) 37 | for x in range( *blocks[2] ): 38 | self.layer3.add_module(str(x), official_vgg.features[x]) 39 | for x in range( *blocks[3] ): 40 | self.layer4.add_module(str(x), official_vgg.features[x]) 41 | for x in range( *blocks[4] ): 42 | self.layer5.add_module(str(x), official_vgg.features[x]) 43 | 44 | self.max_pool = official_vgg.features[last_idx] 45 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 46 | 47 | self.fc1 = official_vgg.classifier[0] 48 | self.fc2 = official_vgg.classifier[3] 49 | self.fc3 = official_vgg.classifier[6] 
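# Note (added for clarity): classifier indices 0, 3 and 6 of the official torchvision VGG are its three Linear layers;
# keeping them as separate fc1/fc2/fc3 lets forward() below tap the activation after fc2 (before the final dropout and fc3) as the 'penultimate' feature.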
50 | self.dropout = nn.Dropout() 51 | 52 | 53 | def forward(self, x): 54 | out = {} 55 | 56 | x = self.layer1(x) 57 | out['f0'] = x 58 | 59 | x = self.layer2(x) 60 | out['f1'] = x 61 | 62 | x = self.layer3(x) 63 | out['f2'] = x 64 | 65 | x = self.layer4(x) 66 | out['f3'] = x 67 | 68 | x = self.layer5(x) 69 | out['f4'] = x 70 | 71 | x = self.max_pool(x) 72 | x = self.avgpool(x) 73 | x = x.view(-1,512*7*7) 74 | 75 | x = self.fc1(x) 76 | x = F.relu(x) 77 | x = self.dropout(x) 78 | x = self.fc2(x) 79 | x = F.relu(x) 80 | out['penultimate'] = x 81 | x = self.dropout(x) 82 | x = self.fc3(x) 83 | out['logits'] = x 84 | 85 | return out 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | def vgg11(pretrained=False, progress=True): 97 | r"""VGG 11-layer model (configuration "A") from 98 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 99 | 100 | Args: 101 | pretrained (bool): If True, returns a model pre-trained on ImageNet 102 | progress (bool): If True, displays a progress bar of the download to stderr 103 | """ 104 | return VGG('vgg11', pretrained, progress) 105 | 106 | 107 | 108 | def vgg19(pretrained=False, progress=True): 109 | r"""VGG 19-layer model (configuration "E") 110 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 111 | 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | progress (bool): If True, displays a progress bar of the download to stderr 115 | """ 116 | return VGG('vgg19', pretrained, progress) 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | from functools import partial 4 | from typing import Any, Callable, List, NamedTuple, Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | # from .._internally_replaced_utils import load_state_dict_from_url 10 | from .vision_transformer_misc import ConvNormActivation 11 | from .vision_transformer_utils import _log_api_usage_once 12 | 13 | try: 14 | from torch.hub import load_state_dict_from_url 15 | except ImportError: 16 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 17 | 18 | # __all__ = [ 19 | # "VisionTransformer", 20 | # "vit_b_16", 21 | # "vit_b_32", 22 | # "vit_l_16", 23 | # "vit_l_32", 24 | # ] 25 | 26 | model_urls = { 27 | "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", 28 | "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", 29 | "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", 30 | "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", 31 | } 32 | 33 | 34 | class ConvStemConfig(NamedTuple): 35 | out_channels: int 36 | kernel_size: int 37 | stride: int 38 | norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d 39 | activation_layer: Callable[..., nn.Module] = nn.ReLU 40 | 41 | 42 | class MLPBlock(nn.Sequential): 43 | """Transformer MLP block.""" 44 | 45 | def __init__(self, in_dim: int, mlp_dim: int, dropout: float): 46 | super().__init__() 47 | self.linear_1 = nn.Linear(in_dim, mlp_dim) 48 | self.act = nn.GELU() 49 | self.dropout_1 = nn.Dropout(dropout) 50 | self.linear_2 = nn.Linear(mlp_dim, in_dim) 51 | self.dropout_2 = nn.Dropout(dropout) 52 | 53 | nn.init.xavier_uniform_(self.linear_1.weight) 54 | nn.init.xavier_uniform_(self.linear_2.weight) 55 | 
nn.init.normal_(self.linear_1.bias, std=1e-6) 56 | nn.init.normal_(self.linear_2.bias, std=1e-6) 57 | 58 | 59 | class EncoderBlock(nn.Module): 60 | """Transformer encoder block.""" 61 | 62 | def __init__( 63 | self, 64 | num_heads: int, 65 | hidden_dim: int, 66 | mlp_dim: int, 67 | dropout: float, 68 | attention_dropout: float, 69 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 70 | ): 71 | super().__init__() 72 | self.num_heads = num_heads 73 | 74 | # Attention block 75 | self.ln_1 = norm_layer(hidden_dim) 76 | self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) 77 | self.dropout = nn.Dropout(dropout) 78 | 79 | # MLP block 80 | self.ln_2 = norm_layer(hidden_dim) 81 | self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) 82 | 83 | def forward(self, input: torch.Tensor): 84 | torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}") 85 | x = self.ln_1(input) 86 | x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) 87 | x = self.dropout(x) 88 | x = x + input 89 | 90 | y = self.ln_2(x) 91 | y = self.mlp(y) 92 | return x + y 93 | 94 | 95 | class Encoder(nn.Module): 96 | """Transformer Model Encoder for sequence to sequence translation.""" 97 | 98 | def __init__( 99 | self, 100 | seq_length: int, 101 | num_layers: int, 102 | num_heads: int, 103 | hidden_dim: int, 104 | mlp_dim: int, 105 | dropout: float, 106 | attention_dropout: float, 107 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 108 | ): 109 | super().__init__() 110 | # Note that batch_size is on the first dim because 111 | # we have batch_first=True in nn.MultiAttention() by default 112 | self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT 113 | self.dropout = nn.Dropout(dropout) 114 | layers: OrderedDict[str, nn.Module] = OrderedDict() 115 | for i in range(num_layers): 116 | layers[f"encoder_layer_{i}"] = EncoderBlock( 117 | num_heads, 118 | hidden_dim, 119 | mlp_dim, 120 | dropout, 121 | attention_dropout, 122 | norm_layer, 123 | ) 124 | self.layers = nn.Sequential(layers) 125 | self.ln = norm_layer(hidden_dim) 126 | 127 | def forward(self, input: torch.Tensor): 128 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") 129 | input = input + self.pos_embedding 130 | return self.ln(self.layers(self.dropout(input))) 131 | 132 | 133 | class VisionTransformer(nn.Module): 134 | """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" 135 | 136 | def __init__( 137 | self, 138 | image_size: int, 139 | patch_size: int, 140 | num_layers: int, 141 | num_heads: int, 142 | hidden_dim: int, 143 | mlp_dim: int, 144 | dropout: float = 0.0, 145 | attention_dropout: float = 0.0, 146 | num_classes: int = 1000, 147 | representation_size: Optional[int] = None, 148 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 149 | conv_stem_configs: Optional[List[ConvStemConfig]] = None, 150 | ): 151 | super().__init__() 152 | _log_api_usage_once(self) 153 | torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") 154 | self.image_size = image_size 155 | self.patch_size = patch_size 156 | self.hidden_dim = hidden_dim 157 | self.mlp_dim = mlp_dim 158 | self.attention_dropout = attention_dropout 159 | self.dropout = dropout 160 | self.num_classes = num_classes 161 | self.representation_size = representation_size 162 | 
self.norm_layer = norm_layer 163 | 164 | if conv_stem_configs is not None: 165 | # As per https://arxiv.org/abs/2106.14881 166 | seq_proj = nn.Sequential() 167 | prev_channels = 3 168 | for i, conv_stem_layer_config in enumerate(conv_stem_configs): 169 | seq_proj.add_module( 170 | f"conv_bn_relu_{i}", 171 | ConvNormActivation( 172 | in_channels=prev_channels, 173 | out_channels=conv_stem_layer_config.out_channels, 174 | kernel_size=conv_stem_layer_config.kernel_size, 175 | stride=conv_stem_layer_config.stride, 176 | norm_layer=conv_stem_layer_config.norm_layer, 177 | activation_layer=conv_stem_layer_config.activation_layer, 178 | ), 179 | ) 180 | prev_channels = conv_stem_layer_config.out_channels 181 | seq_proj.add_module( 182 | "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1) 183 | ) 184 | self.conv_proj: nn.Module = seq_proj 185 | else: 186 | self.conv_proj = nn.Conv2d( 187 | in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size 188 | ) 189 | 190 | seq_length = (image_size // patch_size) ** 2 191 | 192 | # Add a class token 193 | self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) 194 | seq_length += 1 195 | 196 | self.encoder = Encoder( 197 | seq_length, 198 | num_layers, 199 | num_heads, 200 | hidden_dim, 201 | mlp_dim, 202 | dropout, 203 | attention_dropout, 204 | norm_layer, 205 | ) 206 | self.seq_length = seq_length 207 | 208 | heads_layers: OrderedDict[str, nn.Module] = OrderedDict() 209 | if representation_size is None: 210 | heads_layers["head"] = nn.Linear(hidden_dim, num_classes) 211 | else: 212 | heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) 213 | heads_layers["act"] = nn.Tanh() 214 | heads_layers["head"] = nn.Linear(representation_size, num_classes) 215 | 216 | self.heads = nn.Sequential(heads_layers) 217 | 218 | if isinstance(self.conv_proj, nn.Conv2d): 219 | # Init the patchify stem 220 | fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] 221 | nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) 222 | if self.conv_proj.bias is not None: 223 | nn.init.zeros_(self.conv_proj.bias) 224 | elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d): 225 | # Init the last 1x1 conv of the conv stem 226 | nn.init.normal_( 227 | self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels) 228 | ) 229 | if self.conv_proj.conv_last.bias is not None: 230 | nn.init.zeros_(self.conv_proj.conv_last.bias) 231 | 232 | if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): 233 | fan_in = self.heads.pre_logits.in_features 234 | nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) 235 | nn.init.zeros_(self.heads.pre_logits.bias) 236 | 237 | if isinstance(self.heads.head, nn.Linear): 238 | nn.init.zeros_(self.heads.head.weight) 239 | nn.init.zeros_(self.heads.head.bias) 240 | 241 | def _process_input(self, x: torch.Tensor) -> torch.Tensor: 242 | n, c, h, w = x.shape 243 | p = self.patch_size 244 | torch._assert(h == self.image_size, "Wrong image height!") 245 | torch._assert(w == self.image_size, "Wrong image width!") 246 | n_h = h // p 247 | n_w = w // p 248 | 249 | # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) 250 | x = self.conv_proj(x) 251 | # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) 252 | x = x.reshape(n, self.hidden_dim, n_h * n_w) 253 | 254 | # (n, hidden_dim, (n_h * 
n_w)) -> (n, (n_h * n_w), hidden_dim) 255 | # The self attention layer expects inputs in the format (N, S, E) 256 | # where S is the source sequence length, N is the batch size, E is the 257 | # embedding dimension 258 | x = x.permute(0, 2, 1) 259 | 260 | return x 261 | 262 | def forward(self, x: torch.Tensor): 263 | out = {} 264 | 265 | # Reshape and permute the input tensor 266 | x = self._process_input(x) 267 | n = x.shape[0] 268 | 269 | # Expand the class token to the full batch 270 | batch_class_token = self.class_token.expand(n, -1, -1) 271 | x = torch.cat([batch_class_token, x], dim=1) 272 | 273 | 274 | x = self.encoder(x) 275 | img_feature = x[:,1:] 276 | H = W = int(self.image_size / self.patch_size) 277 | out['f4'] = img_feature.view(n, H, W, self.hidden_dim).permute(0,3,1,2) 278 | 279 | # Classifier "token" as used by standard language architectures 280 | x = x[:, 0] 281 | out['penultimate'] = x 282 | 283 | x = self.heads(x) # I checked that for all pretrained ViT, this is just a fc 284 | out['logits'] = x 285 | 286 | return out 287 | 288 | 289 | def _vision_transformer( 290 | arch: str, 291 | patch_size: int, 292 | num_layers: int, 293 | num_heads: int, 294 | hidden_dim: int, 295 | mlp_dim: int, 296 | pretrained: bool, 297 | progress: bool, 298 | **kwargs: Any, 299 | ) -> VisionTransformer: 300 | image_size = kwargs.pop("image_size", 224) 301 | 302 | model = VisionTransformer( 303 | image_size=image_size, 304 | patch_size=patch_size, 305 | num_layers=num_layers, 306 | num_heads=num_heads, 307 | hidden_dim=hidden_dim, 308 | mlp_dim=mlp_dim, 309 | **kwargs, 310 | ) 311 | 312 | if pretrained: 313 | if arch not in model_urls: 314 | raise ValueError(f"No checkpoint is available for model type '{arch}'!") 315 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 316 | model.load_state_dict(state_dict) 317 | 318 | return model 319 | 320 | 321 | def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: 322 | """ 323 | Constructs a vit_b_16 architecture from 324 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 325 | 326 | Args: 327 | pretrained (bool): If True, returns a model pre-trained on ImageNet 328 | progress (bool): If True, displays a progress bar of the download to stderr 329 | """ 330 | return _vision_transformer( 331 | arch="vit_b_16", 332 | patch_size=16, 333 | num_layers=12, 334 | num_heads=12, 335 | hidden_dim=768, 336 | mlp_dim=3072, 337 | pretrained=pretrained, 338 | progress=progress, 339 | **kwargs, 340 | ) 341 | 342 | 343 | def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: 344 | """ 345 | Constructs a vit_b_32 architecture from 346 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 347 | 348 | Args: 349 | pretrained (bool): If True, returns a model pre-trained on ImageNet 350 | progress (bool): If True, displays a progress bar of the download to stderr 351 | """ 352 | return _vision_transformer( 353 | arch="vit_b_32", 354 | patch_size=32, 355 | num_layers=12, 356 | num_heads=12, 357 | hidden_dim=768, 358 | mlp_dim=3072, 359 | pretrained=pretrained, 360 | progress=progress, 361 | **kwargs, 362 | ) 363 | 364 | 365 | def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: 366 | """ 367 | Constructs a vit_l_16 architecture from 368 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
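Note (added for clarity; a minimal sketch, input shape illustrative): as with the other constructors in this file, the returned model's ``forward`` yields a dict, e.g. ``out = vit_l_16(pretrained=True)(torch.randn(1, 3, 224, 224))`` exposes the class-token features in ``out['penultimate']`` and the patch-token feature map in ``out['f4']``.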
369 | 370 | Args: 371 | pretrained (bool): If True, returns a model pre-trained on ImageNet 372 | progress (bool): If True, displays a progress bar of the download to stderr 373 | """ 374 | return _vision_transformer( 375 | arch="vit_l_16", 376 | patch_size=16, 377 | num_layers=24, 378 | num_heads=16, 379 | hidden_dim=1024, 380 | mlp_dim=4096, 381 | pretrained=pretrained, 382 | progress=progress, 383 | **kwargs, 384 | ) 385 | 386 | 387 | def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: 388 | """ 389 | Constructs a vit_l_32 architecture from 390 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 391 | 392 | Args: 393 | pretrained (bool): If True, returns a model pre-trained on ImageNet 394 | progress (bool): If True, displays a progress bar of the download to stderr 395 | """ 396 | return _vision_transformer( 397 | arch="vit_l_32", 398 | patch_size=32, 399 | num_layers=24, 400 | num_heads=16, 401 | hidden_dim=1024, 402 | mlp_dim=4096, 403 | pretrained=pretrained, 404 | progress=progress, 405 | **kwargs, 406 | ) 407 | 408 | 409 | def interpolate_embeddings( 410 | image_size: int, 411 | patch_size: int, 412 | model_state: "OrderedDict[str, torch.Tensor]", 413 | interpolation_mode: str = "bicubic", 414 | reset_heads: bool = False, 415 | ) -> "OrderedDict[str, torch.Tensor]": 416 | """This function helps interpolating positional embeddings during checkpoint loading, 417 | especially when you want to apply a pre-trained model on images with different resolution. 418 | 419 | Args: 420 | image_size (int): Image size of the new model. 421 | patch_size (int): Patch size of the new model. 422 | model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model. 423 | interpolation_mode (str): The algorithm used for upsampling. Default: bicubic. 424 | reset_heads (bool): If true, not copying the state of heads. Default: False. 425 | 426 | Returns: 427 | OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model. 428 | """ 429 | # Shape of pos_embedding is (1, seq_length, hidden_dim) 430 | pos_embedding = model_state["encoder.pos_embedding"] 431 | n, seq_length, hidden_dim = pos_embedding.shape 432 | if n != 1: 433 | raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}") 434 | 435 | new_seq_length = (image_size // patch_size) ** 2 + 1 436 | 437 | # Need to interpolate the weights for the position embedding. 438 | # We do this by reshaping the positions embeddings to a 2d grid, performing 439 | # an interpolation in the (h, w) space and then reshaping back to a 1d grid. 440 | if new_seq_length != seq_length: 441 | # The class token embedding shouldn't be interpolated so we split it up. 442 | seq_length -= 1 443 | new_seq_length -= 1 444 | pos_embedding_token = pos_embedding[:, :1, :] 445 | pos_embedding_img = pos_embedding[:, 1:, :] 446 | 447 | # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length) 448 | pos_embedding_img = pos_embedding_img.permute(0, 2, 1) 449 | seq_length_1d = int(math.sqrt(seq_length)) 450 | torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!") 451 | 452 | # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d) 453 | pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d) 454 | new_seq_length_1d = image_size // patch_size 455 | 456 | # Perform interpolation. 
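# Note (added for clarity): only the 2-D grid of patch-position embeddings is resized here;
# the class-token embedding was split off above and is concatenated back in afterwards.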
457 | # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) 458 | new_pos_embedding_img = nn.functional.interpolate( 459 | pos_embedding_img, 460 | size=new_seq_length_1d, 461 | mode=interpolation_mode, 462 | align_corners=True, 463 | ) 464 | 465 | # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length) 466 | new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length) 467 | 468 | # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim) 469 | new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1) 470 | new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1) 471 | 472 | model_state["encoder.pos_embedding"] = new_pos_embedding 473 | 474 | if reset_heads: 475 | model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict() 476 | for k, v in model_state.items(): 477 | if not k.startswith("heads"): 478 | model_state_copy[k] = v 479 | model_state = model_state_copy 480 | 481 | return model_state 482 | -------------------------------------------------------------------------------- /models/vision_transformer_misc.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from .vision_transformer_utils import _log_api_usage_once 7 | 8 | 9 | interpolate = torch.nn.functional.interpolate 10 | 11 | 12 | # This is not in nn 13 | class FrozenBatchNorm2d(torch.nn.Module): 14 | """ 15 | BatchNorm2d where the batch statistics and the affine parameters are fixed 16 | 17 | Args: 18 | num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)`` 19 | eps (float): a value added to the denominator for numerical stability. Default: 1e-5 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_features: int, 25 | eps: float = 1e-5, 26 | ): 27 | super().__init__() 28 | _log_api_usage_once(self) 29 | self.eps = eps 30 | self.register_buffer("weight", torch.ones(num_features)) 31 | self.register_buffer("bias", torch.zeros(num_features)) 32 | self.register_buffer("running_mean", torch.zeros(num_features)) 33 | self.register_buffer("running_var", torch.ones(num_features)) 34 | 35 | def _load_from_state_dict( 36 | self, 37 | state_dict: dict, 38 | prefix: str, 39 | local_metadata: dict, 40 | strict: bool, 41 | missing_keys: List[str], 42 | unexpected_keys: List[str], 43 | error_msgs: List[str], 44 | ): 45 | num_batches_tracked_key = prefix + "num_batches_tracked" 46 | if num_batches_tracked_key in state_dict: 47 | del state_dict[num_batches_tracked_key] 48 | 49 | super()._load_from_state_dict( 50 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 51 | ) 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | # move reshapes to the beginning 55 | # to make it fuser-friendly 56 | w = self.weight.reshape(1, -1, 1, 1) 57 | b = self.bias.reshape(1, -1, 1, 1) 58 | rv = self.running_var.reshape(1, -1, 1, 1) 59 | rm = self.running_mean.reshape(1, -1, 1, 1) 60 | scale = w * (rv + self.eps).rsqrt() 61 | bias = b - rm * scale 62 | return x * scale + bias 63 | 64 | def __repr__(self) -> str: 65 | return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" 66 | 67 | 68 | class ConvNormActivation(torch.nn.Sequential): 69 | """ 70 | Configurable block used for Convolution-Normalzation-Activation blocks. 
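For example (an illustrative sketch, not part of the original docstring), ``ConvNormActivation(3, 64, kernel_size=3, stride=2)`` builds a ``Conv2d -> BatchNorm2d -> ReLU`` stack, with padding derived from the kernel size and dilation as described below.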
71 | 72 | Args: 73 | in_channels (int): Number of channels in the input image 74 | out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block 75 | kernel_size: (int, optional): Size of the convolving kernel. Default: 3 76 | stride (int, optional): Stride of the convolution. Default: 1 77 | padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` 78 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 79 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolutiuon layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` 80 | activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` 81 | dilation (int): Spacing between kernel elements. Default: 1 82 | inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` 83 | bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 84 | 85 | """ 86 | 87 | def __init__( 88 | self, 89 | in_channels: int, 90 | out_channels: int, 91 | kernel_size: int = 3, 92 | stride: int = 1, 93 | padding: Optional[int] = None, 94 | groups: int = 1, 95 | norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, 96 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, 97 | dilation: int = 1, 98 | inplace: Optional[bool] = True, 99 | bias: Optional[bool] = None, 100 | ) -> None: 101 | if padding is None: 102 | padding = (kernel_size - 1) // 2 * dilation 103 | if bias is None: 104 | bias = norm_layer is None 105 | layers = [ 106 | torch.nn.Conv2d( 107 | in_channels, 108 | out_channels, 109 | kernel_size, 110 | stride, 111 | padding, 112 | dilation=dilation, 113 | groups=groups, 114 | bias=bias, 115 | ) 116 | ] 117 | if norm_layer is not None: 118 | layers.append(norm_layer(out_channels)) 119 | if activation_layer is not None: 120 | params = {} if inplace is None else {"inplace": inplace} 121 | layers.append(activation_layer(**params)) 122 | super().__init__(*layers) 123 | _log_api_usage_once(self) 124 | self.out_channels = out_channels 125 | 126 | 127 | class SqueezeExcitation(torch.nn.Module): 128 | """ 129 | This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). 130 | Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3. 131 | 132 | Args: 133 | input_channels (int): Number of channels in the input image 134 | squeeze_channels (int): Number of squeeze channels 135 | activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` 136 | scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. 
Default: ``torch.nn.Sigmoid`` 137 | """ 138 | 139 | def __init__( 140 | self, 141 | input_channels: int, 142 | squeeze_channels: int, 143 | activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, 144 | scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, 145 | ) -> None: 146 | super().__init__() 147 | _log_api_usage_once(self) 148 | self.avgpool = torch.nn.AdaptiveAvgPool2d(1) 149 | self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) 150 | self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) 151 | self.activation = activation() 152 | self.scale_activation = scale_activation() 153 | 154 | def _scale(self, input: Tensor) -> Tensor: 155 | scale = self.avgpool(input) 156 | scale = self.fc1(scale) 157 | scale = self.activation(scale) 158 | scale = self.fc2(scale) 159 | return self.scale_activation(scale) 160 | 161 | def forward(self, input: Tensor) -> Tensor: 162 | scale = self._scale(input) 163 | return scale * input 164 | -------------------------------------------------------------------------------- /models/vision_transformer_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | import warnings 4 | from types import FunctionType 5 | from typing import Any, BinaryIO, List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image, ImageColor, ImageDraw, ImageFont 10 | 11 | __all__ = [ 12 | "make_grid", 13 | "save_image", 14 | "draw_bounding_boxes", 15 | "draw_segmentation_masks", 16 | "draw_keypoints", 17 | "flow_to_image", 18 | ] 19 | 20 | 21 | @torch.no_grad() 22 | def make_grid( 23 | tensor: Union[torch.Tensor, List[torch.Tensor]], 24 | nrow: int = 8, 25 | padding: int = 2, 26 | normalize: bool = False, 27 | value_range: Optional[Tuple[int, int]] = None, 28 | scale_each: bool = False, 29 | pad_value: float = 0.0, 30 | **kwargs, 31 | ) -> torch.Tensor: 32 | """ 33 | Make a grid of images. 34 | 35 | Args: 36 | tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) 37 | or a list of images all of the same size. 38 | nrow (int, optional): Number of images displayed in each row of the grid. 39 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``. 40 | padding (int, optional): amount of padding. Default: ``2``. 41 | normalize (bool, optional): If True, shift the image to the range (0, 1), 42 | by the min and max values specified by ``value_range``. Default: ``False``. 43 | value_range (tuple, optional): tuple (min, max) where min and max are numbers, 44 | then these numbers are used to normalize the image. By default, min and max 45 | are computed from the tensor. 46 | range (tuple. optional): 47 | .. warning:: 48 | This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` 49 | instead. 50 | scale_each (bool, optional): If ``True``, scale each image in the batch of 51 | images separately rather than the (min, max) over all images. Default: ``False``. 52 | pad_value (float, optional): Value for the padded pixels. Default: ``0``. 53 | 54 | Returns: 55 | grid (Tensor): the tensor containing grid of images. 
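Example (an added sketch; the random tensor is only illustrative): ``grid = make_grid(torch.rand(16, 3, 64, 64), nrow=4, normalize=True)`` returns a single ``(3, H, W)`` image tensor with the 16 inputs tiled four per row.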
56 | """ 57 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 58 | _log_api_usage_once(make_grid) 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") 61 | 62 | if "range" in kwargs.keys(): 63 | warnings.warn( 64 | "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " 65 | "Please use 'value_range' instead." 66 | ) 67 | value_range = kwargs["range"] 68 | 69 | # if list of tensors, convert to a 4D mini-batch Tensor 70 | if isinstance(tensor, list): 71 | tensor = torch.stack(tensor, dim=0) 72 | 73 | if tensor.dim() == 2: # single image H x W 74 | tensor = tensor.unsqueeze(0) 75 | if tensor.dim() == 3: # single image 76 | if tensor.size(0) == 1: # if single-channel, convert to 3-channel 77 | tensor = torch.cat((tensor, tensor, tensor), 0) 78 | tensor = tensor.unsqueeze(0) 79 | 80 | if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images 81 | tensor = torch.cat((tensor, tensor, tensor), 1) 82 | 83 | if normalize is True: 84 | tensor = tensor.clone() # avoid modifying tensor in-place 85 | if value_range is not None: 86 | assert isinstance( 87 | value_range, tuple 88 | ), "value_range has to be a tuple (min, max) if specified. min and max are numbers" 89 | 90 | def norm_ip(img, low, high): 91 | img.clamp_(min=low, max=high) 92 | img.sub_(low).div_(max(high - low, 1e-5)) 93 | 94 | def norm_range(t, value_range): 95 | if value_range is not None: 96 | norm_ip(t, value_range[0], value_range[1]) 97 | else: 98 | norm_ip(t, float(t.min()), float(t.max())) 99 | 100 | if scale_each is True: 101 | for t in tensor: # loop over mini-batch dimension 102 | norm_range(t, value_range) 103 | else: 104 | norm_range(tensor, value_range) 105 | 106 | assert isinstance(tensor, torch.Tensor) 107 | if tensor.size(0) == 1: 108 | return tensor.squeeze(0) 109 | 110 | # make the mini-batch of images into a grid 111 | nmaps = tensor.size(0) 112 | xmaps = min(nrow, nmaps) 113 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 114 | height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) 115 | num_channels = tensor.size(1) 116 | grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) 117 | k = 0 118 | for y in range(ymaps): 119 | for x in range(xmaps): 120 | if k >= nmaps: 121 | break 122 | # Tensor.copy_() is a valid method but seems to be missing from the stubs 123 | # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ 124 | grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] 125 | 2, x * width + padding, width - padding 126 | ).copy_(tensor[k]) 127 | k = k + 1 128 | return grid 129 | 130 | 131 | @torch.no_grad() 132 | def save_image( 133 | tensor: Union[torch.Tensor, List[torch.Tensor]], 134 | fp: Union[str, pathlib.Path, BinaryIO], 135 | format: Optional[str] = None, 136 | **kwargs, 137 | ) -> None: 138 | """ 139 | Save a given Tensor into an image file. 140 | 141 | Args: 142 | tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, 143 | saves the tensor as a grid of images by calling ``make_grid``. 144 | fp (string or file object): A filename or a file object 145 | format(Optional): If omitted, the format to use is determined from the filename extension. 146 | If a file object was used instead of a filename, this parameter should always be used. 
147 | **kwargs: Other arguments are documented in ``make_grid``. 148 | """ 149 | 150 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 151 | _log_api_usage_once(save_image) 152 | grid = make_grid(tensor, **kwargs) 153 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 154 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 155 | im = Image.fromarray(ndarr) 156 | im.save(fp, format=format) 157 | 158 | 159 | @torch.no_grad() 160 | def draw_bounding_boxes( 161 | image: torch.Tensor, 162 | boxes: torch.Tensor, 163 | labels: Optional[List[str]] = None, 164 | colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, 165 | fill: Optional[bool] = False, 166 | width: int = 1, 167 | font: Optional[str] = None, 168 | font_size: int = 10, 169 | ) -> torch.Tensor: 170 | 171 | """ 172 | Draws bounding boxes on given image. 173 | The values of the input image should be uint8 between 0 and 255. 174 | If fill is True, Resulting Tensor should be saved as PNG image. 175 | 176 | Args: 177 | image (Tensor): Tensor of shape (C x H x W) and dtype uint8. 178 | boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that 179 | the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and 180 | `0 <= ymin < ymax < H`. 181 | labels (List[str]): List containing the labels of bounding boxes. 182 | colors (color or list of colors, optional): List containing the colors 183 | of the boxes or single color for all boxes. The color can be represented as 184 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 185 | By default, random colors are generated for boxes. 186 | fill (bool): If `True` fills the bounding box with specified color. 187 | width (int): Width of bounding box. 188 | font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may 189 | also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, 190 | `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. 191 | font_size (int): The requested font size in points. 192 | 193 | Returns: 194 | img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. 195 | """ 196 | 197 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 198 | _log_api_usage_once(draw_bounding_boxes) 199 | if not isinstance(image, torch.Tensor): 200 | raise TypeError(f"Tensor expected, got {type(image)}") 201 | elif image.dtype != torch.uint8: 202 | raise ValueError(f"Tensor uint8 expected, got {image.dtype}") 203 | elif image.dim() != 3: 204 | raise ValueError("Pass individual images, not batches") 205 | elif image.size(0) not in {1, 3}: 206 | raise ValueError("Only grayscale and RGB images are supported") 207 | 208 | num_boxes = boxes.shape[0] 209 | 210 | if labels is None: 211 | labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] 212 | elif len(labels) != num_boxes: 213 | raise ValueError( 214 | f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." 215 | ) 216 | 217 | if colors is None: 218 | colors = _generate_color_palette(num_boxes) 219 | elif isinstance(colors, list): 220 | if len(colors) < num_boxes: 221 | raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). 
") 222 | else: # colors specifies a single color for all boxes 223 | colors = [colors] * num_boxes 224 | 225 | colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] 226 | 227 | # Handle Grayscale images 228 | if image.size(0) == 1: 229 | image = torch.tile(image, (3, 1, 1)) 230 | 231 | ndarr = image.permute(1, 2, 0).cpu().numpy() 232 | img_to_draw = Image.fromarray(ndarr) 233 | img_boxes = boxes.to(torch.int64).tolist() 234 | 235 | if fill: 236 | draw = ImageDraw.Draw(img_to_draw, "RGBA") 237 | else: 238 | draw = ImageDraw.Draw(img_to_draw) 239 | 240 | txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) 241 | 242 | for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] 243 | if fill: 244 | fill_color = color + (100,) 245 | draw.rectangle(bbox, width=width, outline=color, fill=fill_color) 246 | else: 247 | draw.rectangle(bbox, width=width, outline=color) 248 | 249 | if label is not None: 250 | margin = width + 1 251 | draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) 252 | 253 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) 254 | 255 | 256 | @torch.no_grad() 257 | def draw_segmentation_masks( 258 | image: torch.Tensor, 259 | masks: torch.Tensor, 260 | alpha: float = 0.8, 261 | colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, 262 | ) -> torch.Tensor: 263 | 264 | """ 265 | Draws segmentation masks on given RGB image. 266 | The values of the input image should be uint8 between 0 and 255. 267 | 268 | Args: 269 | image (Tensor): Tensor of shape (3, H, W) and dtype uint8. 270 | masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. 271 | alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 272 | 0 means full transparency, 1 means no transparency. 273 | colors (color or list of colors, optional): List containing the colors 274 | of the masks or single color for all masks. The color can be represented as 275 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 276 | By default, random colors are generated for each mask. 277 | 278 | Returns: 279 | img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. 280 | """ 281 | 282 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 283 | _log_api_usage_once(draw_segmentation_masks) 284 | if not isinstance(image, torch.Tensor): 285 | raise TypeError(f"The image must be a tensor, got {type(image)}") 286 | elif image.dtype != torch.uint8: 287 | raise ValueError(f"The image dtype must be uint8, got {image.dtype}") 288 | elif image.dim() != 3: 289 | raise ValueError("Pass individual images, not batches") 290 | elif image.size()[0] != 3: 291 | raise ValueError("Pass an RGB image. Other Image formats are not supported") 292 | if masks.ndim == 2: 293 | masks = masks[None, :, :] 294 | if masks.ndim != 3: 295 | raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") 296 | if masks.dtype != torch.bool: 297 | raise ValueError(f"The masks must be of dtype bool. 
Got {masks.dtype}") 298 | if masks.shape[-2:] != image.shape[-2:]: 299 | raise ValueError("The image and the masks must have the same height and width") 300 | 301 | num_masks = masks.size()[0] 302 | if colors is not None and num_masks > len(colors): 303 | raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") 304 | 305 | if colors is None: 306 | colors = _generate_color_palette(num_masks) 307 | 308 | if not isinstance(colors, list): 309 | colors = [colors] 310 | if not isinstance(colors[0], (tuple, str)): 311 | raise ValueError("colors must be a tuple or a string, or a list thereof") 312 | if isinstance(colors[0], tuple) and len(colors[0]) != 3: 313 | raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") 314 | 315 | out_dtype = torch.uint8 316 | 317 | colors_ = [] 318 | for color in colors: 319 | if isinstance(color, str): 320 | color = ImageColor.getrgb(color) 321 | colors_.append(torch.tensor(color, dtype=out_dtype)) 322 | 323 | img_to_draw = image.detach().clone() 324 | # TODO: There might be a way to vectorize this 325 | for mask, color in zip(masks, colors_): 326 | img_to_draw[:, mask] = color[:, None] 327 | 328 | out = image * (1 - alpha) + img_to_draw * alpha 329 | return out.to(out_dtype) 330 | 331 | 332 | @torch.no_grad() 333 | def draw_keypoints( 334 | image: torch.Tensor, 335 | keypoints: torch.Tensor, 336 | connectivity: Optional[List[Tuple[int, int]]] = None, 337 | colors: Optional[Union[str, Tuple[int, int, int]]] = None, 338 | radius: int = 2, 339 | width: int = 3, 340 | ) -> torch.Tensor: 341 | 342 | """ 343 | Draws Keypoints on given RGB image. 344 | The values of the input image should be uint8 between 0 and 255. 345 | 346 | Args: 347 | image (Tensor): Tensor of shape (3, H, W) and dtype uint8. 348 | keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, 349 | in the format [x, y]. 350 | connectivity (List[Tuple[int, int]]]): A List of tuple where, 351 | each tuple contains pair of keypoints to be connected. 352 | colors (str, Tuple): The color can be represented as 353 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 354 | radius (int): Integer denoting radius of keypoint. 355 | width (int): Integer denoting width of line connecting keypoints. 356 | 357 | Returns: 358 | img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. 359 | """ 360 | 361 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 362 | _log_api_usage_once(draw_keypoints) 363 | if not isinstance(image, torch.Tensor): 364 | raise TypeError(f"The image must be a tensor, got {type(image)}") 365 | elif image.dtype != torch.uint8: 366 | raise ValueError(f"The image dtype must be uint8, got {image.dtype}") 367 | elif image.dim() != 3: 368 | raise ValueError("Pass individual images, not batches") 369 | elif image.size()[0] != 3: 370 | raise ValueError("Pass an RGB image. 
Other Image formats are not supported") 371 | 372 | if keypoints.ndim != 3: 373 | raise ValueError("keypoints must be of shape (num_instances, K, 2)") 374 | 375 | ndarr = image.permute(1, 2, 0).cpu().numpy() 376 | img_to_draw = Image.fromarray(ndarr) 377 | draw = ImageDraw.Draw(img_to_draw) 378 | img_kpts = keypoints.to(torch.int64).tolist() 379 | 380 | for kpt_id, kpt_inst in enumerate(img_kpts): 381 | for inst_id, kpt in enumerate(kpt_inst): 382 | x1 = kpt[0] - radius 383 | x2 = kpt[0] + radius 384 | y1 = kpt[1] - radius 385 | y2 = kpt[1] + radius 386 | draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) 387 | 388 | if connectivity: 389 | for connection in connectivity: 390 | start_pt_x = kpt_inst[connection[0]][0] 391 | start_pt_y = kpt_inst[connection[0]][1] 392 | 393 | end_pt_x = kpt_inst[connection[1]][0] 394 | end_pt_y = kpt_inst[connection[1]][1] 395 | 396 | draw.line( 397 | ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), 398 | width=width, 399 | ) 400 | 401 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) 402 | 403 | 404 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization 405 | @torch.no_grad() 406 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor: 407 | 408 | """ 409 | Converts a flow to an RGB image. 410 | 411 | Args: 412 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. 413 | 414 | Returns: 415 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds 416 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. 417 | """ 418 | 419 | if flow.dtype != torch.float: 420 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") 421 | 422 | orig_shape = flow.shape 423 | if flow.ndim == 3: 424 | flow = flow[None] # Add batch dim 425 | 426 | if flow.ndim != 4 or flow.shape[1] != 2: 427 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") 428 | 429 | max_norm = torch.sum(flow ** 2, dim=1).sqrt().max() 430 | epsilon = torch.finfo((flow).dtype).eps 431 | normalized_flow = flow / (max_norm + epsilon) 432 | img = _normalized_flow_to_image(normalized_flow) 433 | 434 | if len(orig_shape) == 3: 435 | img = img[0] # Remove batch dim 436 | return img 437 | 438 | 439 | @torch.no_grad() 440 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: 441 | 442 | """ 443 | Converts a batch of normalized flow to an RGB image. 444 | 445 | Args: 446 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) 447 | Returns: 448 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 
449 | """ 450 | 451 | N, _, H, W = normalized_flow.shape 452 | device = normalized_flow.device 453 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) 454 | colorwheel = _make_colorwheel().to(device) # shape [55x3] 455 | num_cols = colorwheel.shape[0] 456 | norm = torch.sum(normalized_flow ** 2, dim=1).sqrt() 457 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi 458 | fk = (a + 1) / 2 * (num_cols - 1) 459 | k0 = torch.floor(fk).to(torch.long) 460 | k1 = k0 + 1 461 | k1[k1 == num_cols] = 0 462 | f = fk - k0 463 | 464 | for c in range(colorwheel.shape[1]): 465 | tmp = colorwheel[:, c] 466 | col0 = tmp[k0] / 255.0 467 | col1 = tmp[k1] / 255.0 468 | col = (1 - f) * col0 + f * col1 469 | col = 1 - norm * (1 - col) 470 | flow_image[:, c, :, :] = torch.floor(255 * col) 471 | return flow_image 472 | 473 | 474 | def _make_colorwheel() -> torch.Tensor: 475 | """ 476 | Generates a color wheel for optical flow visualization as presented in: 477 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 478 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. 479 | 480 | Returns: 481 | colorwheel (Tensor[55, 3]): Colorwheel Tensor. 482 | """ 483 | 484 | RY = 15 485 | YG = 6 486 | GC = 4 487 | CB = 11 488 | BM = 13 489 | MR = 6 490 | 491 | ncols = RY + YG + GC + CB + BM + MR 492 | colorwheel = torch.zeros((ncols, 3)) 493 | col = 0 494 | 495 | # RY 496 | colorwheel[0:RY, 0] = 255 497 | colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) 498 | col = col + RY 499 | # YG 500 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) 501 | colorwheel[col : col + YG, 1] = 255 502 | col = col + YG 503 | # GC 504 | colorwheel[col : col + GC, 1] = 255 505 | colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) 506 | col = col + GC 507 | # CB 508 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) 509 | colorwheel[col : col + CB, 2] = 255 510 | col = col + CB 511 | # BM 512 | colorwheel[col : col + BM, 2] = 255 513 | colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) 514 | col = col + BM 515 | # MR 516 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) 517 | colorwheel[col : col + MR, 0] = 255 518 | return colorwheel 519 | 520 | 521 | def _generate_color_palette(num_objects: int): 522 | palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) 523 | return [tuple((i * palette) % 255) for i in range(num_objects)] 524 | 525 | 526 | def _log_api_usage_once(obj: Any) -> None: 527 | 528 | """ 529 | Logs API usage(module and name) within an organization. 530 | In a large ecosystem, it's often useful to track the PyTorch and 531 | TorchVision APIs usage. This API provides the similar functionality to the 532 | logging module in the Python stdlib. It can be used for debugging purpose 533 | to log which methods are used and by default it is inactive, unless the user 534 | manually subscribes a logger via the `SetAPIUsageLogger method `_. 535 | Please note it is triggered only once for the same API call within a process. 536 | It does not collect any data from open-source users since it is no-op by default. 
537 | For more information, please refer to 538 | * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; 539 | * Logging policy: https://github.com/pytorch/vision/issues/5052; 540 | 541 | Args: 542 | obj (class instance or method): an object to extract info from. 543 | """ 544 | if not obj.__module__.startswith("torchvision"): 545 | return 546 | name = obj.__class__.__name__ 547 | if isinstance(obj, FunctionType): 548 | name = obj.__name__ 549 | torch._C._log_api_usage_once(f"{obj.__module__}.{name}") 550 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/networks/__init__.py -------------------------------------------------------------------------------- /networks/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from torch.optim import lr_scheduler 6 | 7 | 8 | class BaseModel(nn.Module): 9 | def __init__(self, opt): 10 | super(BaseModel, self).__init__() 11 | self.opt = opt 12 | self.total_steps = 0 13 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 14 | self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') 15 | 16 | def save_networks(self, save_filename): 17 | save_path = os.path.join(self.save_dir, save_filename) 18 | 19 | # serialize model and optimizer to dict 20 | state_dict = { 21 | 'model': self.model.state_dict(), 22 | 'optimizer' : self.optimizer.state_dict(), 23 | 'total_steps' : self.total_steps, 24 | } 25 | 26 | torch.save(state_dict, save_path) 27 | 28 | 29 | def eval(self): 30 | self.model.eval() 31 | 32 | def test(self): 33 | with torch.no_grad(): 34 | self.forward() 35 | 36 | 37 | def init_weights(net, init_type='normal', gain=0.02): 38 | def init_func(m): 39 | classname = m.__class__.__name__ 40 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 41 | if init_type == 'normal': 42 | init.normal_(m.weight.data, 0.0, gain) 43 | elif init_type == 'xavier': 44 | init.xavier_normal_(m.weight.data, gain=gain) 45 | elif init_type == 'kaiming': 46 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 47 | elif init_type == 'orthogonal': 48 | init.orthogonal_(m.weight.data, gain=gain) 49 | else: 50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | elif classname.find('BatchNorm2d') != -1: 54 | init.normal_(m.weight.data, 1.0, gain) 55 | init.constant_(m.bias.data, 0.0) 56 | 57 | print('initialize network with %s' % init_type) 58 | net.apply(init_func) 59 | -------------------------------------------------------------------------------- /networks/lpf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, Adobe Inc. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4 | # 4.0 International Public License. To view a copy of this license, visit 5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 
6 | 7 | import torch 8 | import torch.nn.parallel 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from IPython import embed 13 | 14 | class Downsample(nn.Module): 15 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 16 | super(Downsample, self).__init__() 17 | self.filt_size = filt_size 18 | self.pad_off = pad_off 19 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] 20 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] 21 | self.stride = stride 22 | self.off = int((self.stride-1)/2.) 23 | self.channels = channels 24 | 25 | # print('Filter size [%i]'%filt_size) 26 | if(self.filt_size==1): 27 | a = np.array([1.,]) 28 | elif(self.filt_size==2): 29 | a = np.array([1., 1.]) 30 | elif(self.filt_size==3): 31 | a = np.array([1., 2., 1.]) 32 | elif(self.filt_size==4): 33 | a = np.array([1., 3., 3., 1.]) 34 | elif(self.filt_size==5): 35 | a = np.array([1., 4., 6., 4., 1.]) 36 | elif(self.filt_size==6): 37 | a = np.array([1., 5., 10., 10., 5., 1.]) 38 | elif(self.filt_size==7): 39 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 40 | 41 | filt = torch.Tensor(a[:,None]*a[None,:]) 42 | filt = filt/torch.sum(filt) 43 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) 44 | 45 | self.pad = get_pad_layer(pad_type)(self.pad_sizes) 46 | 47 | def forward(self, inp): 48 | if(self.filt_size==1): 49 | if(self.pad_off==0): 50 | return inp[:,:,::self.stride,::self.stride] 51 | else: 52 | return self.pad(inp)[:,:,::self.stride,::self.stride] 53 | else: 54 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 55 | 56 | def get_pad_layer(pad_type): 57 | if(pad_type in ['refl','reflect']): 58 | PadLayer = nn.ReflectionPad2d 59 | elif(pad_type in ['repl','replicate']): 60 | PadLayer = nn.ReplicationPad2d 61 | elif(pad_type=='zero'): 62 | PadLayer = nn.ZeroPad2d 63 | else: 64 | print('Pad type [%s] not recognized'%pad_type) 65 | return PadLayer 66 | 67 | 68 | class Downsample1D(nn.Module): 69 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 70 | super(Downsample1D, self).__init__() 71 | self.filt_size = filt_size 72 | self.pad_off = pad_off 73 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] 74 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] 75 | self.stride = stride 76 | self.off = int((self.stride - 1) / 2.) 
77 | self.channels = channels 78 | 79 | # print('Filter size [%i]' % filt_size) 80 | if(self.filt_size == 1): 81 | a = np.array([1., ]) 82 | elif(self.filt_size == 2): 83 | a = np.array([1., 1.]) 84 | elif(self.filt_size == 3): 85 | a = np.array([1., 2., 1.]) 86 | elif(self.filt_size == 4): 87 | a = np.array([1., 3., 3., 1.]) 88 | elif(self.filt_size == 5): 89 | a = np.array([1., 4., 6., 4., 1.]) 90 | elif(self.filt_size == 6): 91 | a = np.array([1., 5., 10., 10., 5., 1.]) 92 | elif(self.filt_size == 7): 93 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 94 | 95 | filt = torch.Tensor(a) 96 | filt = filt / torch.sum(filt) 97 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) 98 | 99 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) 100 | 101 | def forward(self, inp): 102 | if(self.filt_size == 1): 103 | if(self.pad_off == 0): 104 | return inp[:, :, ::self.stride] 105 | else: 106 | return self.pad(inp)[:, :, ::self.stride] 107 | else: 108 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 109 | 110 | 111 | def get_pad_layer_1d(pad_type): 112 | if(pad_type in ['refl', 'reflect']): 113 | PadLayer = nn.ReflectionPad1d 114 | elif(pad_type in ['repl', 'replicate']): 115 | PadLayer = nn.ReplicationPad1d 116 | elif(pad_type == 'zero'): 117 | PadLayer = nn.ZeroPad1d 118 | else: 119 | print('Pad type [%s] not recognized' % pad_type) 120 | return PadLayer 121 | -------------------------------------------------------------------------------- /networks/resnet_lpf.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | # 5 | # ========================================================================================== 6 | # 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | # 12 | # ========================================================================================== 13 | # 14 | # BSD-3 License 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | # 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | # 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | # 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | 40 | import torch.nn as nn 41 | import torch.utils.model_zoo as model_zoo 42 | from .lpf import * 43 | 44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 45 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 46 | 47 | 48 | # model_urls = { 49 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 50 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 51 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 52 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 53 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 54 | # } 55 | 56 | 57 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 58 | """3x3 convolution with padding""" 59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 60 | padding=1, groups=groups, bias=False) 61 | 62 | def conv1x1(in_planes, out_planes, stride=1): 63 | """1x1 convolution""" 64 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 65 | 66 | class BasicBlock(nn.Module): 67 | expansion = 1 68 | 69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): 70 | super(BasicBlock, self).__init__() 71 | if norm_layer is None: 72 | norm_layer = nn.BatchNorm2d 73 | if groups != 1: 74 | raise ValueError('BasicBlock only supports groups=1') 75 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 76 | self.conv1 = conv3x3(inplanes, planes) 77 | self.bn1 = norm_layer(planes) 78 | self.relu = nn.ReLU(inplace=True) 79 | if(stride==1): 80 | self.conv2 = conv3x3(planes,planes) 81 | else: 82 | self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), 83 | conv3x3(planes, planes),) 84 | self.bn2 = norm_layer(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | identity = self.downsample(x) 100 | 101 | out += identity 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | expansion = 4 109 | 110 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): 111 | super(Bottleneck, self).__init__() 112 | if norm_layer is None: 113 | norm_layer = nn.BatchNorm2d 114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 115 | self.conv1 = conv1x1(inplanes, planes) 116 | self.bn1 = norm_layer(planes) 117 | self.conv2 = conv3x3(planes, planes, groups) # stride moved 118 | self.bn2 = norm_layer(planes) 119 | if(stride==1): 120 | self.conv3 = conv1x1(planes, planes * self.expansion) 121 | else: 122 | self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), 123 | conv1x1(planes, planes * 
self.expansion)) 124 | self.bn3 = norm_layer(planes * self.expansion) 125 | self.relu = nn.ReLU(inplace=True) 126 | self.downsample = downsample 127 | self.stride = stride 128 | 129 | def forward(self, x): 130 | identity = x 131 | 132 | out = self.conv1(x) 133 | out = self.bn1(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv2(out) 137 | out = self.bn2(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv3(out) 141 | out = self.bn3(out) 142 | 143 | if self.downsample is not None: 144 | identity = self.downsample(x) 145 | 146 | out += identity 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | 152 | class ResNet(nn.Module): 153 | 154 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 155 | groups=1, width_per_group=64, norm_layer=None, filter_size=1, pool_only=True): 156 | super(ResNet, self).__init__() 157 | if norm_layer is None: 158 | norm_layer = nn.BatchNorm2d 159 | planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] 160 | self.inplanes = planes[0] 161 | 162 | if(pool_only): 163 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False) 164 | else: 165 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False) 166 | self.bn1 = norm_layer(planes[0]) 167 | self.relu = nn.ReLU(inplace=True) 168 | 169 | if(pool_only): 170 | self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1), 171 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) 172 | else: 173 | self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]), 174 | nn.MaxPool2d(kernel_size=2, stride=1), 175 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) 176 | 177 | self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer) 178 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 179 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 180 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 182 | self.fc = nn.Linear(planes[3] * block.expansion, num_classes) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None): 187 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics 188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 189 | else: 190 | print('Not initializing') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1): 206 | if norm_layer is None: 207 | norm_layer = nn.BatchNorm2d 208 | downsample = None 209 | if stride != 1 or self.inplanes != planes * block.expansion: 210 | # downsample = nn.Sequential( 211 | # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size), 212 | # norm_layer(planes * block.expansion), 213 | # ) 214 | 215 | downsample = [Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes),] if(stride !=1) else [] 216 | downsample += [conv1x1(self.inplanes, planes * block.expansion, 1), 217 | norm_layer(planes * block.expansion)] 218 | # print(downsample) 219 | downsample = nn.Sequential(*downsample) 220 | 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer, filter_size=filter_size)) 223 | self.inplanes = planes * block.expansion 224 | for _ in range(1, blocks): 225 | layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer, filter_size=filter_size)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.relu(x) 233 | x = self.maxpool(x) 234 | 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | 240 | x = self.avgpool(x) 241 | x = x.view(x.size(0), -1) 242 | x = self.fc(x) 243 | 244 | return x 245 | 246 | 247 | def resnet18(pretrained=False, filter_size=1, pool_only=True, **kwargs): 248 | """Constructs a ResNet-18 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(BasicBlock, [2, 2, 2, 2], filter_size=filter_size, pool_only=pool_only, **kwargs) 253 | if pretrained: 254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 255 | return model 256 | 257 | 258 | def resnet34(pretrained=False, filter_size=1, pool_only=True, **kwargs): 259 | """Constructs a ResNet-34 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | """ 263 | model = ResNet(BasicBlock, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 264 | if pretrained: 265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 266 | return model 267 | 268 | 269 | def resnet50(pretrained=False, filter_size=1, pool_only=True, **kwargs): 270 | """Constructs a ResNet-50 model. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | model = ResNet(Bottleneck, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 275 | if pretrained: 276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 277 | return model 278 | 279 | 280 | def resnet101(pretrained=False, filter_size=1, pool_only=True, **kwargs): 281 | """Constructs a ResNet-101 model. 
282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | """ 285 | model = ResNet(Bottleneck, [3, 4, 23, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 286 | if pretrained: 287 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 288 | return model 289 | 290 | 291 | def resnet152(pretrained=False, filter_size=1, pool_only=True, **kwargs): 292 | """Constructs a ResNet-152 model. 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | """ 296 | model = ResNet(Bottleneck, [3, 8, 36, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 297 | if pretrained: 298 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 299 | return model 300 | 301 | 302 | def resnext50_32x4d(pretrained=False, filter_size=1, pool_only=True, **kwargs): 303 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs) 304 | # if pretrained: 305 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 306 | return model 307 | 308 | 309 | def resnext101_32x8d(pretrained=False, filter_size=1, pool_only=True, **kwargs): 310 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs) 311 | # if pretrained: 312 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 313 | return model 314 | -------------------------------------------------------------------------------- /networks/trainer.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | from networks.base_model import BaseModel, init_weights 5 | import sys 6 | from models import get_model 7 | 8 | class Trainer(BaseModel): 9 | def name(self): 10 | return 'Trainer' 11 | 12 | def __init__(self, opt): 13 | super(Trainer, self).__init__(opt) 14 | self.opt = opt 15 | self.model = get_model(opt.arch) 16 | torch.nn.init.normal_(self.model.fc.weight.data, 0.0, opt.init_gain) 17 | 18 | if opt.fix_backbone: 19 | params = [] 20 | for name, p in self.model.named_parameters(): 21 | if name=="fc.weight" or name=="fc.bias": 22 | params.append(p) 23 | else: 24 | p.requires_grad = False 25 | else: 26 | print("Your backbone is not fixed. Are you sure you want to proceed? If this is a mistake, enable the --fix_backbone command during training and rerun") 27 | import time 28 | time.sleep(3) 29 | params = self.model.parameters() 30 | 31 | 32 | 33 | if opt.optim == 'adam': 34 | self.optimizer = torch.optim.AdamW(params, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay) 35 | elif opt.optim == 'sgd': 36 | self.optimizer = torch.optim.SGD(params, lr=opt.lr, momentum=0.0, weight_decay=opt.weight_decay) 37 | else: 38 | raise ValueError("optim should be [adam, sgd]") 39 | 40 | self.loss_fn = nn.BCEWithLogitsLoss() 41 | 42 | self.model.to(opt.gpu_ids[0]) 43 | 44 | 45 | def adjust_learning_rate(self, min_lr=1e-6): 46 | for param_group in self.optimizer.param_groups: 47 | param_group['lr'] /= 10. 
48 | if param_group['lr'] < min_lr: 49 | return False 50 | return True 51 | 52 | 53 | def set_input(self, input): 54 | self.input = input[0].to(self.device) 55 | self.label = input[1].to(self.device).float() 56 | 57 | 58 | def forward(self): 59 | self.output = self.model(self.input) 60 | self.output = self.output.view(-1).unsqueeze(1) 61 | 62 | 63 | def get_loss(self): 64 | return self.loss_fn(self.output.squeeze(1), self.label) 65 | 66 | def optimize_parameters(self): 67 | self.forward() 68 | self.loss = self.loss_fn(self.output.squeeze(1), self.label) 69 | self.optimizer.zero_grad() 70 | self.loss.backward() 71 | self.optimizer.step() 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import util 4 | import torch 5 | 6 | 7 | class BaseOptions(): 8 | def __init__(self): 9 | self.initialized = False 10 | 11 | def initialize(self, parser): 12 | parser.add_argument('--mode', default='binary') 13 | parser.add_argument('--arch', type=str, default='res50', help='see my_models/__init__.py') 14 | parser.add_argument('--fix_backbone', action='store_true') 15 | 16 | # data augmentation 17 | parser.add_argument('--rz_interp', default='bilinear') 18 | parser.add_argument('--blur_prob', type=float, default=0.5) 19 | parser.add_argument('--blur_sig', default='0.0,3.0') 20 | parser.add_argument('--jpg_prob', type=float, default=0.5) 21 | parser.add_argument('--jpg_method', default='cv2,pil') 22 | parser.add_argument('--jpg_qual', default='30,100') 23 | 24 | 25 | parser.add_argument('--real_list_path', default=None, help='only used if data_mode==ours: path for the list of real images, which should contain train.pickle and val.pickle') 26 | parser.add_argument('--fake_list_path', default=None, help='only used if data_mode==ours: path for the list of fake images, which should contain train.pickle and val.pickle') 27 | parser.add_argument('--wang2020_data_path', default=None, help='only used if data_mode==wang2020 it should contain train and test folders') 28 | parser.add_argument('--data_mode', default='ours', help='wang2020 or ours') 29 | parser.add_argument('--data_label', default='train', help='label to decide whether train or validation dataset') 30 | parser.add_argument('--weight_decay', type=float, default=0.0, help='loss weight for l2 reg') 31 | 32 | parser.add_argument('--class_bal', action='store_true') # what is this ? 33 | parser.add_argument('--batch_size', type=int, default=256, help='input batch size') 34 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size') 35 | parser.add_argument('--cropSize', type=int, default=224, help='then crop to this size') 36 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 37 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. 
It decides where to store samples and models') 38 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 39 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 40 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 41 | parser.add_argument('--resize_or_crop', type=str, default='scale_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]') 42 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 43 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 44 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 45 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 46 | self.initialized = True 47 | return parser 48 | 49 | def gather_options(self): 50 | # initialize parser with basic options 51 | if not self.initialized: 52 | parser = argparse.ArgumentParser( 53 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 54 | parser = self.initialize(parser) 55 | 56 | # get the basic options 57 | opt, _ = parser.parse_known_args() 58 | self.parser = parser 59 | 60 | return parser.parse_args() 61 | 62 | def print_options(self, opt): 63 | message = '' 64 | message += '----------------- Options ---------------\n' 65 | for k, v in sorted(vars(opt).items()): 66 | comment = '' 67 | default = self.parser.get_default(k) 68 | if v != default: 69 | comment = '\t[default: %s]' % str(default) 70 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 71 | message += '----------------- End -------------------' 72 | print(message) 73 | 74 | # save to the disk 75 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 76 | util.mkdirs(expr_dir) 77 | file_name = os.path.join(expr_dir, 'opt.txt') 78 | with open(file_name, 'wt') as opt_file: 79 | opt_file.write(message) 80 | opt_file.write('\n') 81 | 82 | def parse(self, print_options=True): 83 | 84 | opt = self.gather_options() 85 | opt.isTrain = self.isTrain # train or test 86 | 87 | # process opt.suffix 88 | if opt.suffix: 89 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 90 | opt.name = opt.name + suffix 91 | 92 | if print_options: 93 | self.print_options(opt) 94 | 95 | # set gpu ids 96 | str_ids = opt.gpu_ids.split(',') 97 | opt.gpu_ids = [] 98 | for str_id in str_ids: 99 | id = int(str_id) 100 | if id >= 0: 101 | opt.gpu_ids.append(id) 102 | if len(opt.gpu_ids) > 0: 103 | torch.cuda.set_device(opt.gpu_ids[0]) 104 | 105 | # additional 106 | #opt.classes = opt.classes.split(',') 107 | opt.rz_interp = opt.rz_interp.split(',') 108 | opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')] 109 | opt.jpg_method = opt.jpg_method.split(',') 110 | opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')] 111 | if len(opt.jpg_qual) == 2: 112 | opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1)) 113 | elif len(opt.jpg_qual) > 2: 114 | raise ValueError("Shouldn't have more than 2 values for --jpg_qual.") 115 | 116 | self.opt = opt 117 | return self.opt 118 | -------------------------------------------------------------------------------- 
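`BaseOptions.parse()` above turns the comma-separated CLI strings into typed values before training or testing starts: GPU ids become a list of ints, blur sigmas become floats, and a two-value `--jpg_qual` is expanded into the full quality range. A minimal, self-contained sketch of that post-processing (using a hypothetical `argparse.Namespace` rather than the repository's parser) is shown below.

```python
# Minimal sketch of the option post-processing done in BaseOptions.parse().
# The Namespace values here are hypothetical defaults, not parsed CLI input.
from argparse import Namespace

opt = Namespace(gpu_ids='0,1', blur_sig='0.0,3.0', jpg_method='cv2,pil', jpg_qual='30,100')

# '0,1' -> [0, 1]; negative ids (e.g. -1 for CPU) are dropped
opt.gpu_ids = [int(s) for s in opt.gpu_ids.split(',') if int(s) >= 0]

# '0.0,3.0' -> [0.0, 3.0] (min/max sigma for the Gaussian blur augmentation)
opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')]

# 'cv2,pil' -> ['cv2', 'pil']
opt.jpg_method = opt.jpg_method.split(',')

# '30,100' -> [30, 31, ..., 100]; exactly two values are treated as a range
opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')]
if len(opt.jpg_qual) == 2:
    opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1))

print(opt.gpu_ids, opt.blur_sig, opt.jpg_method, len(opt.jpg_qual))  # [0, 1] [0.0, 3.0] ['cv2', 'pil'] 71
```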
/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--model_path') 8 | parser.add_argument('--no_resize', action='store_true') 9 | parser.add_argument('--no_crop', action='store_true') 10 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 11 | 12 | self.isTrain = False 13 | return parser 14 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--earlystop_epoch', type=int, default=5) 8 | parser.add_argument('--data_aug', action='store_true', help='if specified, perform additional data augmentation (photometric, blurring, jpegging)') 9 | parser.add_argument('--optim', type=str, default='adam', help='optim to use [sgd, adam]') 10 | parser.add_argument('--new_optim', action='store_true', help='new optimizer instead of loading the optim state') 11 | parser.add_argument('--loss_freq', type=int, default=400, help='frequency of showing loss on tensorboard') 12 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 13 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 14 | parser.add_argument('--last_epoch', type=int, default=-1, help='starting epoch count for scheduler initialization') 15 | parser.add_argument('--train_split', type=str, default='train', help='train, val, test, etc') 16 | parser.add_argument('--val_split', type=str, default='val', help='train, val, test, etc') 17 | parser.add_argument('--niter', type=int, default=100, help='total epochs') 18 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam') 19 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 20 | 21 | self.isTrain = True 22 | return parser 23 | -------------------------------------------------------------------------------- /pretrained_weights/fc_weights.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/pretrained_weights/fc_weights.pth -------------------------------------------------------------------------------- /resources/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/resources/teaser.png -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 |
import time 3 | from tensorboardX import SummaryWriter 4 | 5 | from validate import validate 6 | from data import create_dataloader 7 | from earlystop import EarlyStopping 8 | from networks.trainer import Trainer 9 | from options.train_options import TrainOptions 10 | 11 | 12 | """Currently assumes jpg_prob, blur_prob 0 or 1""" 13 | def get_val_opt(): 14 | val_opt = TrainOptions().parse(print_options=False) 15 | val_opt.isTrain = False 16 | val_opt.no_resize = False 17 | val_opt.no_crop = False 18 | val_opt.serial_batches = True 19 | val_opt.data_label = 'val' 20 | val_opt.jpg_method = ['pil'] 21 | if len(val_opt.blur_sig) == 2: 22 | b_sig = val_opt.blur_sig 23 | val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2] 24 | if len(val_opt.jpg_qual) != 1: 25 | j_qual = val_opt.jpg_qual 26 | val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)] 27 | 28 | return val_opt 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | opt = TrainOptions().parse() 34 | val_opt = get_val_opt() 35 | 36 | model = Trainer(opt) 37 | 38 | data_loader = create_dataloader(opt) 39 | val_loader = create_dataloader(val_opt) 40 | 41 | train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train")) 42 | val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val")) 43 | 44 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True) 45 | start_time = time.time() 46 | print ("Length of data loader: %d" %(len(data_loader))) 47 | for epoch in range(opt.niter): 48 | 49 | for i, data in enumerate(data_loader): 50 | model.total_steps += 1 51 | 52 | model.set_input(data) 53 | model.optimize_parameters() 54 | 55 | if model.total_steps % opt.loss_freq == 0: 56 | print("Train loss: {} at step: {}".format(model.loss, model.total_steps)) 57 | train_writer.add_scalar('loss', model.loss, model.total_steps) 58 | print("Iter time: ", ((time.time()-start_time)/model.total_steps) ) 59 | 60 | if model.total_steps in [10,30,50,100,1000,5000,10000] and False: # save models at these iters 61 | model.save_networks('model_iters_%s.pth' % model.total_steps) 62 | 63 | if epoch % opt.save_epoch_freq == 0: 64 | print('saving the model at the end of epoch %d' % (epoch)) 65 | model.save_networks( 'model_epoch_best.pth' ) 66 | model.save_networks( 'model_epoch_%s.pth' % epoch ) 67 | 68 | # Validation 69 | model.eval() 70 | ap, r_acc, f_acc, acc = validate(model.model, val_loader) 71 | val_writer.add_scalar('accuracy', acc, model.total_steps) 72 | val_writer.add_scalar('ap', ap, model.total_steps) 73 | print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap)) 74 | 75 | early_stopping(acc, model) 76 | if early_stopping.early_stop: 77 | cont_train = model.adjust_learning_rate() 78 | if cont_train: 79 | print("Learning rate dropped by 10, continue training...") 80 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.002, verbose=True) 81 | else: 82 | print("Early stopping.") 83 | break 84 | model.train() 85 | 86 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ast import arg 3 | import os 4 | import csv 5 | import torch 6 | import torchvision.transforms as transforms 7 | import torch.utils.data 8 | import numpy as np 9 | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score 10 | from torch.utils.data import Dataset 11 | import sys 12 | from models import get_model 13 | from PIL 
import Image 14 | import pickle 15 | from tqdm import tqdm 16 | from io import BytesIO 17 | from copy import deepcopy 18 | from dataset_paths import DATASET_PATHS 19 | import random 20 | import shutil 21 | from scipy.ndimage.filters import gaussian_filter 22 | 23 | SEED = 0 24 | def set_seed(): 25 | torch.manual_seed(SEED) 26 | torch.cuda.manual_seed(SEED) 27 | np.random.seed(SEED) 28 | random.seed(SEED) 29 | 30 | 31 | MEAN = { 32 | "imagenet":[0.485, 0.456, 0.406], 33 | "clip":[0.48145466, 0.4578275, 0.40821073] 34 | } 35 | 36 | STD = { 37 | "imagenet":[0.229, 0.224, 0.225], 38 | "clip":[0.26862954, 0.26130258, 0.27577711] 39 | } 40 | 41 | 42 | 43 | 44 | 45 | def find_best_threshold(y_true, y_pred): 46 | "We assume first half is real 0, and the second half is fake 1" 47 | 48 | N = y_true.shape[0] 49 | 50 | if y_pred[0:N//2].max() <= y_pred[N//2:N].min(): # perfectly separable case 51 | return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2 52 | 53 | best_acc = 0 54 | best_thres = 0 55 | for thres in y_pred: 56 | temp = deepcopy(y_pred) 57 | temp[temp>=thres] = 1 58 | temp[temp<thres] = 0 59 | 60 | acc = (temp == y_true).sum() / N 61 | if acc >= best_acc: 62 | best_thres = thres 63 | best_acc = acc 64 | 65 | return best_thres 66 | 67 | 68 | 69 | def png2jpg(img, quality): 70 | out = BytesIO() 71 | img.save(out, format='jpeg', quality=quality) # ranging from 0-95, 75 is default 72 | img = Image.open(out) 73 | # load from memory before ByteIO closes 74 | img = np.array(img) 75 | out.close() 76 | return Image.fromarray(img) 77 | 78 | 79 | def gaussian_blur(img, sigma): 80 | img = np.array(img) 81 | 82 | gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) 83 | gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) 84 | gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) 85 | 86 | return Image.fromarray(img) 87 | 88 | 89 | 90 | def calculate_acc(y_true, y_pred, thres): 91 | r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > thres) 92 | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > thres) 93 | acc = accuracy_score(y_true, y_pred > thres) 94 | return r_acc, f_acc, acc 95 | 96 | 97 | def validate(model, loader, find_thres=False): 98 | 99 | with torch.no_grad(): 100 | y_true, y_pred = [], [] 101 | print ("Length of dataset: %d" %(len(loader))) 102 | for img, label in loader: 103 | in_tens = img.cuda() 104 | 105 | y_pred.extend(model(in_tens).sigmoid().flatten().tolist()) 106 | y_true.extend(label.flatten().tolist()) 107 | 108 | y_true, y_pred = np.array(y_true), np.array(y_pred) 109 | 110 | # ================== save this if you want to plot the curves =========== # 111 | # torch.save( torch.stack( [torch.tensor(y_true), torch.tensor(y_pred)] ), 'baseline_predication_for_pr_roc_curve.pth' ) 112 | # exit() 113 | # =================================================================== # 114 | 115 | # Get AP 116 | ap = average_precision_score(y_true, y_pred) 117 | 118 | # Acc based on 0.5 119 | r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5) 120 | if not find_thres: 121 | return ap, r_acc0, f_acc0, acc0 122 | 123 | 124 | # Acc based on the best thres 125 | best_thres = find_best_threshold(y_true, y_pred) 126 | r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres) 127 | 128 | return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres 129 | 130 | 131 | 132 | 133 | 134 | 135 | # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 136 | 137 | 138 | 139 | 140 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", "bmp"]): 141 | out = [] 142 | for r, d, 
f in os.walk(rootdir): 143 | for file in f: 144 | if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): 145 | out.append(os.path.join(r, file)) 146 | return out 147 | 148 | 149 | def get_list(path, must_contain=''): 150 | if ".pickle" in path: 151 | with open(path, 'rb') as f: 152 | image_list = pickle.load(f) 153 | image_list = [ item for item in image_list if must_contain in item ] 154 | else: 155 | image_list = recursively_read(path, must_contain) 156 | return image_list 157 | 158 | 159 | 160 | 161 | 162 | class RealFakeDataset(Dataset): 163 | def __init__(self, real_path, 164 | fake_path, 165 | data_mode, 166 | max_sample, 167 | arch, 168 | jpeg_quality=None, 169 | gaussian_sigma=None): 170 | 171 | assert data_mode in ["wang2020", "ours"] 172 | self.jpeg_quality = jpeg_quality 173 | self.gaussian_sigma = gaussian_sigma 174 | 175 | # = = = = = = data path = = = = = = = = = # 176 | if type(real_path) == str and type(fake_path) == str: 177 | real_list, fake_list = self.read_path(real_path, fake_path, data_mode, max_sample) 178 | else: 179 | real_list = [] 180 | fake_list = [] 181 | for real_p, fake_p in zip(real_path, fake_path): 182 | real_l, fake_l = self.read_path(real_p, fake_p, data_mode, max_sample) 183 | real_list += real_l 184 | fake_list += fake_l 185 | 186 | self.total_list = real_list + fake_list 187 | 188 | 189 | # = = = = = = label = = = = = = = = = # 190 | 191 | self.labels_dict = {} 192 | for i in real_list: 193 | self.labels_dict[i] = 0 194 | for i in fake_list: 195 | self.labels_dict[i] = 1 196 | 197 | stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip" 198 | self.transform = transforms.Compose([ 199 | transforms.CenterCrop(224), 200 | transforms.ToTensor(), 201 | transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ), 202 | ]) 203 | 204 | 205 | def read_path(self, real_path, fake_path, data_mode, max_sample): 206 | 207 | if data_mode == 'wang2020': 208 | real_list = get_list(real_path, must_contain='0_real') 209 | fake_list = get_list(fake_path, must_contain='1_fake') 210 | else: 211 | real_list = get_list(real_path) 212 | fake_list = get_list(fake_path) 213 | 214 | 215 | if max_sample is not None: 216 | if (max_sample > len(real_list)) or (max_sample > len(fake_list)): 217 | max_sample = 100 218 | print("not enough images, max_sample falling to 100") 219 | random.shuffle(real_list) 220 | random.shuffle(fake_list) 221 | real_list = real_list[0:max_sample] 222 | fake_list = fake_list[0:max_sample] 223 | 224 | assert len(real_list) == len(fake_list) 225 | 226 | return real_list, fake_list 227 | 228 | 229 | 230 | def __len__(self): 231 | return len(self.total_list) 232 | 233 | def __getitem__(self, idx): 234 | 235 | img_path = self.total_list[idx] 236 | 237 | label = self.labels_dict[img_path] 238 | img = Image.open(img_path).convert("RGB") 239 | 240 | if self.gaussian_sigma is not None: 241 | img = gaussian_blur(img, self.gaussian_sigma) 242 | if self.jpeg_quality is not None: 243 | img = png2jpg(img, self.jpeg_quality) 244 | 245 | img = self.transform(img) 246 | return img, label 247 | 248 | 249 | 250 | 251 | 252 | if __name__ == '__main__': 253 | 254 | 255 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 256 | parser.add_argument('--real_path', type=str, default=None, help='dir name or a pickle') 257 | parser.add_argument('--fake_path', type=str, default=None, help='dir name or a pickle') 258 | parser.add_argument('--data_mode', type=str, default=None, help='wang2020 or 
ours') 259 | parser.add_argument('--max_sample', type=int, default=1000, help='only check this number of images for both fake/real') 260 | 261 | parser.add_argument('--arch', type=str, default='res50') 262 | parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth') 263 | 264 | parser.add_argument('--result_folder', type=str, default='result', help='') 265 | parser.add_argument('--batch_size', type=int, default=128) 266 | 267 | parser.add_argument('--jpeg_quality', type=int, default=None, help="100, 90, 80, ... 30. Used to test robustness of our model. Not apply if None") 268 | parser.add_argument('--gaussian_sigma', type=int, default=None, help="0,1,2,3,4. Used to test robustness of our model. Not apply if None") 269 | 270 | 271 | opt = parser.parse_args() 272 | 273 | 274 | if os.path.exists(opt.result_folder): 275 | shutil.rmtree(opt.result_folder) 276 | os.makedirs(opt.result_folder) 277 | 278 | model = get_model(opt.arch) 279 | state_dict = torch.load(opt.ckpt, map_location='cpu') 280 | model.fc.load_state_dict(state_dict) 281 | print ("Model loaded..") 282 | model.eval() 283 | model.cuda() 284 | 285 | if (opt.real_path == None) or (opt.fake_path == None) or (opt.data_mode == None): 286 | dataset_paths = DATASET_PATHS 287 | else: 288 | dataset_paths = [ dict(real_path=opt.real_path, fake_path=opt.fake_path, data_mode=opt.data_mode) ] 289 | 290 | 291 | 292 | for dataset_path in (dataset_paths): 293 | set_seed() 294 | 295 | dataset = RealFakeDataset( dataset_path['real_path'], 296 | dataset_path['fake_path'], 297 | dataset_path['data_mode'], 298 | opt.max_sample, 299 | opt.arch, 300 | jpeg_quality=opt.jpeg_quality, 301 | gaussian_sigma=opt.gaussian_sigma, 302 | ) 303 | 304 | loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4) 305 | ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True) 306 | 307 | with open( os.path.join(opt.result_folder,'ap.txt'), 'a') as f: 308 | f.write(dataset_path['key']+': ' + str(round(ap*100, 2))+'\n' ) 309 | 310 | with open( os.path.join(opt.result_folder,'acc0.txt'), 'a') as f: 311 | f.write(dataset_path['key']+': ' + str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n' ) 312 | 313 | --------------------------------------------------------------------------------
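The `__main__` block of `validate.py` above iterates over `DATASET_PATHS` and writes per-domain AP/accuracy files. For a single real/fake split, the same pieces can also be driven directly from Python; the sketch below mirrors that flow, with the dataset directories being placeholders that must point at an existing `0_real`/`1_fake` pair.

```python
# Hedged sketch: evaluate one test split programmatically, mirroring the
# __main__ block of validate.py. The dataset paths below are placeholders.
import torch
from torch.utils.data import DataLoader
from models import get_model
from validate import RealFakeDataset, validate, set_seed

set_seed()
model = get_model('CLIP:ViT-L/14')
state_dict = torch.load('pretrained_weights/fc_weights.pth', map_location='cpu')
model.fc.load_state_dict(state_dict)   # only the linear head is trained/saved
model.eval().cuda()

dataset = RealFakeDataset(
    'datasets/test/progan/0_real',     # placeholder: directory of real images
    'datasets/test/progan/1_fake',     # placeholder: directory of fake images
    data_mode='wang2020',
    max_sample=1000,
    arch='CLIP:ViT-L/14',
)
loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True)
print(f"AP: {ap * 100:.2f}  acc@0.5: {acc0 * 100:.2f}  acc@best-thres: {acc1 * 100:.2f}")
```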