├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── datasets.cpython-38.pyc
│   │   └── datasets.cpython-39.pyc
│   └── datasets.py
├── dataset_paths.py
├── models
│   ├── __init__.py
│   ├── clip
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── clip.cpython-310.pyc
│   │   │   ├── clip.cpython-38.pyc
│   │   │   ├── clip.cpython-39.pyc
│   │   │   ├── model.cpython-310.pyc
│   │   │   ├── model.cpython-38.pyc
│   │   │   ├── model.cpython-39.pyc
│   │   │   ├── simple_tokenizer.cpython-310.pyc
│   │   │   ├── simple_tokenizer.cpython-38.pyc
│   │   │   └── simple_tokenizer.cpython-39.pyc
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── clip.py
│   │   ├── model.py
│   │   └── simple_tokenizer.py
│   ├── clip_models.py
│   ├── imagenet_models.py
│   ├── resnet.py
│   ├── vgg.py
│   ├── vision_transformer.py
│   ├── vision_transformer_misc.py
│   └── vision_transformer_utils.py
├── networks
│   ├── __init__.py
│   ├── base_model.py
│   ├── lpf.py
│   ├── resnet_lpf.py
│   └── trainer.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── pretrained_weights
│   └── fc_weights.pth
├── resources
│   └── teaser.png
├── test.sh
├── train.py
└── validate.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Wisconsin AI and Vision Lab (WAIV)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Detecting fake images
2 |
3 | **Towards Universal Fake Image Detectors that Generalize Across Generative Models**
4 | [Utkarsh Ojha*](https://utkarshojha.github.io/), [Yuheng Li*](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
5 | (*Equal contribution)
6 | CVPR 2023
7 |
8 | [[Project Page](https://utkarshojha.github.io/universal-fake-detection/)] [[Paper](https://arxiv.org/abs/2302.10174)]
9 |
10 |
11 |
12 | Using images from one type of generative model (e.g., GAN), detect fake images from other families of generative models (e.g., diffusion models)
13 |
14 |
15 | ## Contents
16 |
17 | - [Setup](#setup)
18 | - [Pretrained model](#weights)
19 | - [Data](#data)
20 | - [Evaluation](#evaluation)
21 | - [Training](#training)
22 |
23 |
24 | ## Setup
25 |
26 | 1. Clone this repository
27 | ```bash
28 | git clone https://github.com/Yuheng-Li/UniversalFakeDetect
29 | cd UniversalFakeDetect
30 | ```
31 |
32 | 2. Install the necessary libraries
33 | ```bash
34 | pip install torch torchvision
35 | ```
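The command above covers the core dependencies. Judging from the imports used elsewhere in this repository (e.g., `data/datasets.py`, `models/clip/clip.py`, and `models/clip/simple_tokenizer.py`), you will likely also need `opencv-python`, `scipy`, `scikit-image`, `ftfy`, `regex`, and `tqdm`. A small, optional check (not part of the repo) is sketched below:
```python
# Optional sanity check (hypothetical helper, not shipped with this repo):
# verify that the packages imported across the codebase can be found.
import importlib.util

required = ["torch", "torchvision", "cv2", "numpy", "scipy",
            "skimage", "PIL", "ftfy", "regex", "tqdm"]
missing = [pkg for pkg in required if importlib.util.find_spec(pkg) is None]
print("Missing packages:", missing if missing else "none")
```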
36 |
37 | ## Data
38 |
39 | - Of the 19 models studied overall (Tables 1 and 2 in the main paper), 11 are taken from a [previous work](https://arxiv.org/abs/1912.11035). Download the test set, i.e., the real/fake images for those 11 models provided by the authors of that work, from [here](https://drive.google.com/file/d/1z_fD3UKgWQyOTZIBbYSaQ-hz4AzUrLC1/view) (dataset size ~19GB).
40 | - Download the file and unzip it in `datasets/test`. You could also use the bash scripts provided by the authors, as described [here](https://github.com/PeterWang512/CNNDetection#download-the-dataset) in their code repository.
41 | - This should create a directory structure as follows:
42 | ```
43 |
44 | datasets
45 | └── test
46 |     ├── progan
47 |     ├── cyclegan
48 |     ├── biggan
49 |     │   .
50 |     │   .
51 |
52 | ```
53 | - Each directory (e.g., progan) will contain real/fake images under `0_real` and `1_fake` folders respectively.
54 | - Dataset for the diffusion models (e.g., LDM/Glide) can be found [here](https://drive.google.com/file/d/1FXlGIRh_Ud3cScMgSVDbEWmPDmjcrm1t/view?usp=drive_link). Note that in the paper (Tables 2 and 3), we reported results over 10k randomly sampled images. Since providing that many images for all the domains would take up too much space, we are only releasing 1k images per domain, i.e., 1k fake images and 1k real images for each domain (e.g., LDM-200).
55 | - Download and unzip the file into the `./diffusion_datasets` directory.
56 |
57 |
58 | ## Evaluation
59 |
60 | - You can evaluate the model on all the datasets at once by running:
61 | ```bash
62 | python validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14
63 | ```
64 |
65 | - You can also evaluate the model on a single generative model by specifying the paths to its real and fake datasets:
66 | ```bash
67 | python validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14 --real_path datasets/test/progan/0_real --fake_path datasets/test/progan/1_fake
68 | ```
69 |
70 | Note that if no arguments are provided for `real_path` and `fake_path`, the script will perform the evaluation on all the domains specified in `dataset_paths.py`.
71 |
72 | - The results will be stored in `results/` in two files: `ap.txt` stores the Average Precision for each of the test domains, and `acc.txt` stores the accuracy (with 0.5 as the threshold) for the same domains. A minimal sketch of how these two metrics are typically computed follows below.
73 |
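This is not the repo's `validate.py`; it only illustrates, under the assumption that you already have per-image labels and sigmoid scores, how the two reported numbers are typically computed:
```python
# Metric sketch (assumed inputs; see validate.py for the actual evaluation code).
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score

y_true = np.array([0, 0, 1, 1, 1])               # 0 = real, 1 = fake
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.9])   # sigmoid outputs of the detector

ap = average_precision_score(y_true, y_score)    # what ap.txt reports per domain
acc = accuracy_score(y_true, y_score > 0.5)      # 0.5-threshold accuracy (acc.txt)
print(f"AP: {ap:.4f}, Acc: {acc:.4f}")
```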
74 | ## Training
75 |
76 | - Our main model is trained on the same dataset used by the authors of [this work](https://arxiv.org/abs/1912.11035). Download the official training dataset provided [here](https://drive.google.com/file/d/1iVNBV0glknyTYGA9bCxT_d0CVTOgGcKh/view) (dataset size ~ 72GB).
77 |
78 | - Download and unzip the dataset into the `datasets/train` directory. The overall structure should look like the following:
79 | ```
80 | datasets
81 | └── train
82 |     └── progan
83 |         ├── airplane
84 |         ├── bird
85 |         ├── boat
86 |         │   .
87 |         │   .
88 | ```
89 | - There are a total of 20 different object categories, and each category folder contains the corresponding real and fake images in `0_real` and `1_fake` subfolders.
90 | - The model can then be trained with the following command:
91 | ```bash
92 | python train.py --name=clip_vitl14 --wang2020_data_path=datasets/ --data_mode=wang2020 --arch=CLIP:ViT-L/14 --fix_backbone
93 | ```
94 | - **Important**: do not forget to use the `--fix_backbone` argument during training, which ensures that only the linear layer's parameters are trained (a minimal sketch of what this amounts to is shown below).
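A minimal sketch of what this amounts to, assuming the `CLIPModel` wrapper from `models/clip_models.py` (whose trainable head is named `fc`); the actual handling lives in the training code:
```python
# Sketch only: freeze everything except the linear head, as --fix_backbone intends.
from models import get_model

model = get_model("CLIP:ViT-L/14")
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc.")

print([n for n, p in model.named_parameters() if p.requires_grad])
# expected: ['fc.weight', 'fc.bias']
```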
95 |
96 | ## Acknowledgement
97 |
98 | We would like to thank [Sheng-Yu Wang](https://github.com/PeterWang512) for releasing the real/fake images from different generative models. Our training pipeline is also inspired by his [open-source code](https://github.com/PeterWang512/CNNDetection). We would also like to thank [CompVis](https://github.com/CompVis) for releasing the pre-trained [LDMs](https://github.com/CompVis/latent-diffusion) and [LAION](https://laion.ai/) for open-sourcing [LAION-400M dataset](https://laion.ai/blog/laion-400-open-dataset/).
99 |
100 | ## Citation
101 |
102 | If you find our work helpful in your research, please cite it using the following:
103 | ```bibtex
104 | @inproceedings{ojha2023fakedetect,
105 | title={Towards Universal Fake Image Detectors that Generalize Across Generative Models},
106 | author={Ojha, Utkarsh and Li, Yuheng and Lee, Yong Jae},
107 | booktitle={CVPR},
108 | year={2023},
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data.sampler import WeightedRandomSampler
4 |
5 | from .datasets import RealFakeDataset
6 |
7 |
8 |
9 | def get_bal_sampler(dataset):
10 | targets = []
11 | for d in dataset.datasets:
12 | targets.extend(d.targets)
13 |
14 | ratio = np.bincount(targets)
15 | w = 1. / torch.tensor(ratio, dtype=torch.float)
16 | sample_weights = w[targets]
17 | sampler = WeightedRandomSampler(weights=sample_weights,
18 | num_samples=len(sample_weights))
19 | return sampler
20 |
21 |
22 | def create_dataloader(opt, preprocess=None):
23 | shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
24 | dataset = RealFakeDataset(opt)
25 | if '2b' in opt.arch:
26 | dataset.transform = preprocess
27 | sampler = get_bal_sampler(dataset) if opt.class_bal else None
28 |
29 | data_loader = torch.utils.data.DataLoader(dataset,
30 | batch_size=opt.batch_size,
31 | shuffle=shuffle,
32 | sampler=sampler,
33 | num_workers=int(opt.num_threads))
34 | return data_loader
35 |
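A minimal usage sketch follows. The attribute names mirror what this file and `data/datasets.py` actually read; the concrete values are placeholder assumptions (the real defaults live under `options/`):
```python
# Sketch: build a training dataloader from a hand-rolled options namespace.
# In the repo these options come from options/train_options.py.
from argparse import Namespace
from data import create_dataloader

opt = Namespace(
    isTrain=True, data_label="train", data_mode="wang2020",
    wang2020_data_path="datasets/", arch="CLIP:ViT-L/14",
    cropSize=224, loadSize=256, no_flip=False, no_crop=False, no_resize=False,
    blur_prob=0.1, blur_sig=[0.0, 3.0], jpg_prob=0.1,
    jpg_method=["cv2", "pil"], jpg_qual=[30, 100], rz_interp=["bilinear"],
    class_bal=False, serial_batches=False, batch_size=32, num_threads=4,
)

loader = create_dataloader(opt)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224]), torch.Size([32])
```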
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/datasets.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/data/__pycache__/datasets.cpython-39.pyc
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torchvision.datasets as datasets
4 | import torchvision.transforms as transforms
5 | import torchvision.transforms.functional as TF
6 | from torch.utils.data import Dataset
7 | from random import random, choice, shuffle
8 | from io import BytesIO
9 | from PIL import Image
10 | from PIL import ImageFile
11 | from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated in newer SciPy
12 | import pickle
13 | import os
14 | from skimage.io import imread
15 | from copy import deepcopy
16 |
17 | ImageFile.LOAD_TRUNCATED_IMAGES = True
18 |
19 |
20 | MEAN = {
21 | "imagenet":[0.485, 0.456, 0.406],
22 | "clip":[0.48145466, 0.4578275, 0.40821073]
23 | }
24 |
25 | STD = {
26 | "imagenet":[0.229, 0.224, 0.225],
27 | "clip":[0.26862954, 0.26130258, 0.27577711]
28 | }
29 |
30 |
31 |
32 |
33 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg"]):
34 | out = []
35 | for r, d, f in os.walk(rootdir):
36 | for file in f:
37 | if (file.split('.')[-1] in exts) and (must_contain in os.path.join(r, file)):  # use the last dot-separated piece as the extension
38 | out.append(os.path.join(r, file))
39 | return out
40 |
41 |
42 | def get_list(path, must_contain=''):
43 | if ".pickle" in path:
44 | with open(path, 'rb') as f:
45 | image_list = pickle.load(f)
46 | image_list = [ item for item in image_list if must_contain in item ]
47 | else:
48 | image_list = recursively_read(path, must_contain)
49 | return image_list
50 |
51 |
52 |
53 |
54 | class RealFakeDataset(Dataset):
55 | def __init__(self, opt):
56 | assert opt.data_label in ["train", "val"]
57 | #assert opt.data_mode in ["ours", "wang2020", "ours_wang2020"]
58 | self.data_label = opt.data_label
59 | if opt.data_mode == 'ours':
60 | pickle_name = "train.pickle" if opt.data_label=="train" else "val.pickle"
61 | real_list = get_list( os.path.join(opt.real_list_path, pickle_name) )
62 | fake_list = get_list( os.path.join(opt.fake_list_path, pickle_name) )
63 | elif opt.data_mode == 'wang2020':
64 | temp = 'train/progan' if opt.data_label == 'train' else 'test/progan'
65 | real_list = get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='0_real' )
66 | fake_list = get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='1_fake' )
67 | elif opt.data_mode == 'ours_wang2020':
68 | pickle_name = "train.pickle" if opt.data_label=="train" else "val.pickle"
69 | real_list = get_list( os.path.join(opt.real_list_path, pickle_name) )
70 | fake_list = get_list( os.path.join(opt.fake_list_path, pickle_name) )
71 | temp = 'train/progan' if opt.data_label == 'train' else 'test/progan'
72 | real_list += get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='0_real' )
73 | fake_list += get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='1_fake' )
74 |
75 |
76 |
77 | # setting the labels for the dataset
78 | self.labels_dict = {}
79 | for i in real_list:
80 | self.labels_dict[i] = 0
81 | for i in fake_list:
82 | self.labels_dict[i] = 1
83 |
84 | self.total_list = real_list + fake_list
85 | shuffle(self.total_list)
86 | if opt.isTrain:
87 | crop_func = transforms.RandomCrop(opt.cropSize)
88 | elif opt.no_crop:
89 | crop_func = transforms.Lambda(lambda img: img)
90 | else:
91 | crop_func = transforms.CenterCrop(opt.cropSize)
92 |
93 | if opt.isTrain and not opt.no_flip:
94 | flip_func = transforms.RandomHorizontalFlip()
95 | else:
96 | flip_func = transforms.Lambda(lambda img: img)
97 | if not opt.isTrain and opt.no_resize:
98 | rz_func = transforms.Lambda(lambda img: img)
99 | else:
100 | rz_func = transforms.Lambda(lambda img: custom_resize(img, opt))
101 |
102 |
103 | stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip"
104 |
105 | print("mean and std stats are from: ", stat_from)
106 | if '2b' not in opt.arch:
107 | print ("using Official CLIP's normalization")
108 | self.transform = transforms.Compose([
109 | rz_func,
110 | transforms.Lambda(lambda img: data_augment(img, opt)),
111 | crop_func,
112 | flip_func,
113 | transforms.ToTensor(),
114 | transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ),
115 | ])
116 | else:
117 | print ("Using CLIP 2B transform")
118 | self.transform = None # will be initialized in trainer.py
119 |
120 |
121 | def __len__(self):
122 | return len(self.total_list)
123 |
124 |
125 | def __getitem__(self, idx):
126 | img_path = self.total_list[idx]
127 | label = self.labels_dict[img_path]
128 | img = Image.open(img_path).convert("RGB")
129 | img = self.transform(img)
130 | return img, label
131 |
132 |
133 | def data_augment(img, opt):
134 | img = np.array(img)
135 | if img.ndim == 2:
136 | img = np.expand_dims(img, axis=2)
137 | img = np.repeat(img, 3, axis=2)
138 |
139 | if random() < opt.blur_prob:
140 | sig = sample_continuous(opt.blur_sig)
141 | gaussian_blur(img, sig)
142 |
143 | if random() < opt.jpg_prob:
144 | method = sample_discrete(opt.jpg_method)
145 | qual = sample_discrete(opt.jpg_qual)
146 | img = jpeg_from_key(img, qual, method)
147 |
148 | return Image.fromarray(img)
149 |
150 |
151 | def sample_continuous(s):
152 | if len(s) == 1:
153 | return s[0]
154 | if len(s) == 2:
155 | rg = s[1] - s[0]
156 | return random() * rg + s[0]
157 | raise ValueError("Length of iterable s should be 1 or 2.")
158 |
159 |
160 | def sample_discrete(s):
161 | if len(s) == 1:
162 | return s[0]
163 | return choice(s)
164 |
165 |
166 | def gaussian_blur(img, sigma):
167 | gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma)
168 | gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma)
169 | gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma)
170 |
171 |
172 | def cv2_jpg(img, compress_val):
173 | img_cv2 = img[:,:,::-1]
174 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
175 | result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
176 | decimg = cv2.imdecode(encimg, 1)
177 | return decimg[:,:,::-1]
178 |
179 |
180 | def pil_jpg(img, compress_val):
181 | out = BytesIO()
182 | img = Image.fromarray(img)
183 | img.save(out, format='jpeg', quality=compress_val)
184 | img = Image.open(out)
185 | # load from memory before ByteIO closes
186 | img = np.array(img)
187 | out.close()
188 | return img
189 |
190 |
191 | jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}
192 | def jpeg_from_key(img, compress_val, key):
193 | method = jpeg_dict[key]
194 | return method(img, compress_val)
195 |
196 |
197 | rz_dict = {'bilinear': Image.BILINEAR,
198 | 'bicubic': Image.BICUBIC,
199 | 'lanczos': Image.LANCZOS,
200 | 'nearest': Image.NEAREST}
201 | def custom_resize(img, opt):
202 | interp = sample_discrete(opt.rz_interp)
203 | return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp])
204 |
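A small, self-contained sketch of the corruption helpers defined above, applied to a synthetic image (parameter values chosen arbitrarily):
```python
# Sketch: exercise the blur and JPEG re-compression helpers on a random RGB image.
import numpy as np
from data.datasets import gaussian_blur, jpeg_from_key

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

gaussian_blur(img, sigma=1.5)         # blurs each channel in place
img = jpeg_from_key(img, 75, "pil")   # round-trip through JPEG at quality 75
print(img.shape, img.dtype)           # (256, 256, 3) uint8
```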
--------------------------------------------------------------------------------
/dataset_paths.py:
--------------------------------------------------------------------------------
1 | DATASET_PATHS = [
2 |
3 |
4 | dict(
5 | real_path='../FAKE_IMAGES/CNN/test/progan',
6 | fake_path='../FAKE_IMAGES/CNN/test/progan',
7 | data_mode='wang2020',
8 | key='progan'
9 | ),
10 |
11 | dict(
12 | real_path='../FAKE_IMAGES/CNN/test/cyclegan',
13 | fake_path='../FAKE_IMAGES/CNN/test/cyclegan',
14 | data_mode='wang2020',
15 | key='cyclegan'
16 | ),
17 |
18 | dict(
19 | real_path='../FAKE_IMAGES/CNN/test/biggan/', # Imagenet
20 | fake_path='../FAKE_IMAGES/CNN/test/biggan/',
21 | data_mode='wang2020',
22 | key='biggan'
23 | ),
24 |
25 |
26 | dict(
27 | real_path='../FAKE_IMAGES/CNN/test/stylegan',
28 | fake_path='../FAKE_IMAGES/CNN/test/stylegan',
29 | data_mode='wang2020',
30 | key='stylegan'
31 | ),
32 |
33 |
34 | dict(
35 | real_path='../FAKE_IMAGES/CNN/test/gaugan', # It is COCO
36 | fake_path='../FAKE_IMAGES/CNN/test/gaugan',
37 | data_mode='wang2020',
38 | key='gaugan'
39 | ),
40 |
41 |
42 | dict(
43 | real_path='../FAKE_IMAGES/CNN/test/stargan',
44 | fake_path='../FAKE_IMAGES/CNN/test/stargan',
45 | data_mode='wang2020',
46 | key='stargan'
47 | ),
48 |
49 |
50 | dict(
51 | real_path='../FAKE_IMAGES/CNN/test/deepfake',
52 | fake_path='../FAKE_IMAGES/CNN/test/deepfake',
53 | data_mode='wang2020',
54 | key='deepfake'
55 | ),
56 |
57 |
58 | dict(
59 | real_path='../FAKE_IMAGES/CNN/test/seeingdark',
60 | fake_path='../FAKE_IMAGES/CNN/test/seeingdark',
61 | data_mode='wang2020',
62 | key='sitd'
63 | ),
64 |
65 |
66 | dict(
67 | real_path='../FAKE_IMAGES/CNN/test/san',
68 | fake_path='../FAKE_IMAGES/CNN/test/san',
69 | data_mode='wang2020',
70 | key='san'
71 | ),
72 |
73 |
74 | dict(
75 | real_path='../FAKE_IMAGES/CNN/test/crn', # Images from some video games
76 | fake_path='../FAKE_IMAGES/CNN/test/crn',
77 | data_mode='wang2020',
78 | key='crn'
79 | ),
80 |
81 |
82 | dict(
83 | real_path='../FAKE_IMAGES/CNN/test/imle', # Images from some video games
84 | fake_path='../FAKE_IMAGES/CNN/test/imle',
85 | data_mode='wang2020',
86 | key='imle'
87 | ),
88 |
89 |
90 | dict(
91 | real_path='./diffusion_datasets/imagenet',
92 | fake_path='./diffusion_datasets/guided',
93 | data_mode='wang2020',
94 | key='guided'
95 | ),
96 |
97 |
98 | dict(
99 | real_path='./diffusion_datasets/laion',
100 | fake_path='./diffusion_datasets/ldm_200',
101 | data_mode='wang2020',
102 | key='ldm_200'
103 | ),
104 |
105 | dict(
106 | real_path='./diffusion_datasets/laion',
107 | fake_path='./diffusion_datasets/ldm_200_cfg',
108 | data_mode='wang2020',
109 | key='ldm_200_cfg'
110 | ),
111 |
112 | dict(
113 | real_path='./diffusion_datasets/laion',
114 | fake_path='./diffusion_datasets/ldm_100',
115 | data_mode='wang2020',
116 | key='ldm_100'
117 | ),
118 |
119 |
120 | dict(
121 | real_path='./diffusion_datasets/laion',
122 | fake_path='./diffusion_datasets/glide_100_27',
123 | data_mode='wang2020',
124 | key='glide_100_27'
125 | ),
126 |
127 |
128 | dict(
129 | real_path='./diffusion_datasets/laion',
130 | fake_path='./diffusion_datasets/glide_50_27',
131 | data_mode='wang2020',
132 | key='glide_50_27'
133 | ),
134 |
135 |
136 | dict(
137 | real_path='./diffusion_datasets/laion',
138 | fake_path='./diffusion_datasets/glide_100_10',
139 | data_mode='wang2020',
140 | key='glide_100_10'
141 | ),
142 |
143 |
144 | dict(
145 | real_path='./diffusion_datasets/laion',
146 | fake_path='./diffusion_datasets/dalle',
147 | data_mode='wang2020',
148 | key='dalle'
149 | ),
150 |
151 |
152 |
153 | ]
154 |
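Per the README, `validate.py` falls back to these entries when `--real_path`/`--fake_path` are not given. A minimal sketch of consuming the list (not the repo's evaluation loop):
```python
# Sketch: walk the evaluation domains declared above.
from dataset_paths import DATASET_PATHS

for entry in DATASET_PATHS:
    print(f"{entry['key']:>14} | real: {entry['real_path']} | fake: {entry['fake_path']}")
```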
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_models import CLIPModel
2 | from .imagenet_models import ImagenetModel
3 |
4 |
5 | VALID_NAMES = [
6 | 'Imagenet:resnet18',
7 | 'Imagenet:resnet34',
8 | 'Imagenet:resnet50',
9 | 'Imagenet:resnet101',
10 | 'Imagenet:resnet152',
11 | 'Imagenet:vgg11',
12 | 'Imagenet:vgg19',
13 | 'Imagenet:swin-b',
14 | 'Imagenet:swin-s',
15 | 'Imagenet:swin-t',
16 | 'Imagenet:vit_b_16',
17 | 'Imagenet:vit_b_32',
18 | 'Imagenet:vit_l_16',
19 | 'Imagenet:vit_l_32',
20 |
21 | 'CLIP:RN50',
22 | 'CLIP:RN101',
23 | 'CLIP:RN50x4',
24 | 'CLIP:RN50x16',
25 | 'CLIP:RN50x64',
26 | 'CLIP:ViT-B/32',
27 | 'CLIP:ViT-B/16',
28 | 'CLIP:ViT-L/14',
29 | 'CLIP:ViT-L/14@336px',
30 | ]
31 |
32 |
33 |
34 |
35 |
36 | def get_model(name):
37 | assert name in VALID_NAMES
38 | if name.startswith("Imagenet:"):
39 | return ImagenetModel(name[9:])
40 | elif name.startswith("CLIP:"):
41 | return CLIPModel(name[5:])
42 | else:
43 | assert False
44 |
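A minimal usage sketch (the CLIP backbone weights are downloaded on first use by `models/clip/clip.py`):
```python
# Sketch: instantiate a backbone by name and get the single real/fake logit.
import torch
from models import get_model

model = get_model("CLIP:ViT-L/14")   # any entry of VALID_NAMES
model.eval()                         # note: the linear head is randomly initialized
                                     # until a checkpoint (e.g. fc_weights.pth) is loaded

x = torch.randn(1, 3, 224, 224)      # ViT-L/14 expects 224x224 inputs
with torch.no_grad():
    logit = model(x)
print(logit.shape)                   # torch.Size([1, 1])
```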
--------------------------------------------------------------------------------
/models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/models/clip/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/clip.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/clip.cpython-310.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/clip.cpython-39.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/models/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/models/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/models/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95 | """Load a CLIP model
96 |
97 | Parameters
98 | ----------
99 | name : str
100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101 |
102 | device : Union[str, torch.device]
103 | The device to put the loaded model
104 |
105 | jit : bool
106 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
107 |
108 | download_root: str
109 | path to download the model files; by default, it uses "~/.cache/clip"
110 |
111 | Returns
112 | -------
113 | model : torch.nn.Module
114 | The CLIP model
115 |
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125 |
126 | with open(model_path, 'rb') as opened_file:
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(opened_file, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict()).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def patch_device(module):
149 | try:
150 | graphs = [module.graph] if hasattr(module, "graph") else []
151 | except RuntimeError:
152 | graphs = []
153 |
154 | if hasattr(module, "forward1"):
155 | graphs.append(module.forward1.graph)
156 |
157 | for graph in graphs:
158 | for node in graph.findAllNodes("prim::Constant"):
159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160 | node.copyAttributes(device_node)
161 |
162 | model.apply(patch_device)
163 | patch_device(model.encode_image)
164 | patch_device(model.encode_text)
165 |
166 | # patch dtype to float32 on CPU
167 | if str(device) == "cpu":
168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170 | float_node = float_input.node()
171 |
172 | def patch_float(module):
173 | try:
174 | graphs = [module.graph] if hasattr(module, "graph") else []
175 | except RuntimeError:
176 | graphs = []
177 |
178 | if hasattr(module, "forward1"):
179 | graphs.append(module.forward1.graph)
180 |
181 | for graph in graphs:
182 | for node in graph.findAllNodes("aten::to"):
183 | inputs = list(node.inputs())
184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185 | if inputs[i].node()["value"] == 5:
186 | inputs[i].node().copyAttributes(float_node)
187 |
188 | model.apply(patch_float)
189 | patch_float(model.encode_image)
190 | patch_float(model.encode_text)
191 |
192 | model.float()
193 |
194 | return model, _transform(model.input_resolution.item())
195 |
196 |
197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198 | """
199 | Returns the tokenized representation of given input string(s)
200 |
201 | Parameters
202 | ----------
203 | texts : Union[str, List[str]]
204 | An input string or a list of input strings to tokenize
205 |
206 | context_length : int
207 | The context length to use; all CLIP models use 77 as the context length
208 |
209 | truncate: bool
210 | Whether to truncate the text in case its encoding is longer than the context length
211 |
212 | Returns
213 | -------
214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216 | """
217 | if isinstance(texts, str):
218 | texts = [texts]
219 |
220 | sot_token = _tokenizer.encoder["<|startoftext|>"]
221 | eot_token = _tokenizer.encoder["<|endoftext|>"]
222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225 | else:
226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227 |
228 | for i, tokens in enumerate(all_tokens):
229 | if len(tokens) > context_length:
230 | if truncate:
231 | tokens = tokens[:context_length]
232 | tokens[-1] = eot_token
233 | else:
234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235 | result[i, :len(tokens)] = torch.tensor(tokens)
236 |
237 | return result
238 |
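For reference, a minimal sketch of the loader defined above; in this repo the text tower goes unused and the detector only calls `encode_image` (the image path is a placeholder):
```python
# Sketch: load a CLIP backbone with the loader above and extract image features.
import torch
from PIL import Image
from models.clip import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
with torch.no_grad():
    feat = model.encode_image(image)
print(feat.shape)  # [1, 768] for ViT-L/14
```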
--------------------------------------------------------------------------------
/models/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.relu1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.relu3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu1(self.bn1(self.conv1(x)))
46 | out = self.relu2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x[:1], key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 | return x.squeeze(0)
92 |
93 |
94 | class ModifiedResNet(nn.Module):
95 | """
96 | A ResNet class that is similar to torchvision's but contains the following changes:
97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99 | - The final pooling layer is a QKV attention instead of an average pool
100 | """
101 |
102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103 | super().__init__()
104 | self.output_dim = output_dim
105 | self.input_resolution = input_resolution
106 |
107 | # the 3-layer stem
108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109 | self.bn1 = nn.BatchNorm2d(width // 2)
110 | self.relu1 = nn.ReLU(inplace=True)
111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112 | self.bn2 = nn.BatchNorm2d(width // 2)
113 | self.relu2 = nn.ReLU(inplace=True)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.relu3 = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(2)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | def stem(x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.avgpool(x)
144 | return x
145 |
146 | x = x.type(self.conv1.weight.dtype)
147 | x = stem(x)
148 | x = self.layer1(x)
149 | x = self.layer2(x)
150 | x = self.layer3(x)
151 | x = self.layer4(x)
152 | x = self.attnpool(x)
153 |
154 | return x
155 |
156 |
157 | class LayerNorm(nn.LayerNorm):
158 | """Subclass torch's LayerNorm to handle fp16."""
159 |
160 | def forward(self, x: torch.Tensor):
161 | orig_type = x.dtype
162 | ret = super().forward(x.type(torch.float32))
163 | return ret.type(orig_type)
164 |
165 |
166 | class QuickGELU(nn.Module):
167 | def forward(self, x: torch.Tensor):
168 | return x * torch.sigmoid(1.702 * x)
169 |
170 |
171 | class ResidualAttentionBlock(nn.Module):
172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173 | super().__init__()
174 |
175 | self.attn = nn.MultiheadAttention(d_model, n_head)
176 | self.ln_1 = LayerNorm(d_model)
177 | self.mlp = nn.Sequential(OrderedDict([
178 | ("c_fc", nn.Linear(d_model, d_model * 4)),
179 | ("gelu", QuickGELU()),
180 | ("c_proj", nn.Linear(d_model * 4, d_model))
181 | ]))
182 | self.ln_2 = LayerNorm(d_model)
183 | self.attn_mask = attn_mask
184 |
185 | def attention(self, x: torch.Tensor):
186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188 |
189 | def forward(self, x: torch.Tensor):
190 | x = x + self.attention(self.ln_1(x))
191 | x = x + self.mlp(self.ln_2(x))
192 | return x
193 |
194 |
195 | class Transformer(nn.Module):
196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197 | super().__init__()
198 | self.width = width
199 | self.layers = layers
200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201 |
202 | def forward(self, x: torch.Tensor):
203 | out = {}
204 | for idx, layer in enumerate(self.resblocks.children()):
205 | x = layer(x)
206 | out['layer'+str(idx)] = x[0] # shape:LND. choose cls token feature
207 | return out, x
208 |
209 | # return self.resblocks(x) # This is the original code
210 |
211 |
212 | class VisionTransformer(nn.Module):
213 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
214 | super().__init__()
215 | self.input_resolution = input_resolution
216 | self.output_dim = output_dim
217 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
218 |
219 | scale = width ** -0.5
220 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
221 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
222 | self.ln_pre = LayerNorm(width)
223 |
224 | self.transformer = Transformer(width, layers, heads)
225 |
226 | self.ln_post = LayerNorm(width)
227 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
228 |
229 |
230 |
231 | def forward(self, x: torch.Tensor):
232 | x = self.conv1(x) # shape = [*, width, grid, grid]
233 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
234 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
235 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
236 | x = x + self.positional_embedding.to(x.dtype)
237 | x = self.ln_pre(x)
238 |
239 | x = x.permute(1, 0, 2) # NLD -> LND
240 | out, x = self.transformer(x)
241 | x = x.permute(1, 0, 2) # LND -> NLD
242 |
243 | x = self.ln_post(x[:, 0, :])
244 |
245 |
246 | out['before_projection'] = x
247 |
248 | if self.proj is not None:
249 | x = x @ self.proj
250 | out['after_projection'] = x
251 |
252 | # Return both intermediate features and final clip feature
253 | # return out
254 |
255 | # This only returns CLIP features
256 | return x
257 |
258 |
259 | class CLIP(nn.Module):
260 | def __init__(self,
261 | embed_dim: int,
262 | # vision
263 | image_resolution: int,
264 | vision_layers: Union[Tuple[int, int, int, int], int],
265 | vision_width: int,
266 | vision_patch_size: int,
267 | # text
268 | context_length: int,
269 | vocab_size: int,
270 | transformer_width: int,
271 | transformer_heads: int,
272 | transformer_layers: int
273 | ):
274 | super().__init__()
275 |
276 | self.context_length = context_length
277 |
278 | if isinstance(vision_layers, (tuple, list)):
279 | vision_heads = vision_width * 32 // 64
280 | self.visual = ModifiedResNet(
281 | layers=vision_layers,
282 | output_dim=embed_dim,
283 | heads=vision_heads,
284 | input_resolution=image_resolution,
285 | width=vision_width
286 | )
287 | else:
288 | vision_heads = vision_width // 64
289 | self.visual = VisionTransformer(
290 | input_resolution=image_resolution,
291 | patch_size=vision_patch_size,
292 | width=vision_width,
293 | layers=vision_layers,
294 | heads=vision_heads,
295 | output_dim=embed_dim
296 | )
297 |
298 | self.transformer = Transformer(
299 | width=transformer_width,
300 | layers=transformer_layers,
301 | heads=transformer_heads,
302 | attn_mask=self.build_attention_mask()
303 | )
304 |
305 | self.vocab_size = vocab_size
306 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
307 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
308 | self.ln_final = LayerNorm(transformer_width)
309 |
310 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
311 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
312 |
313 | self.initialize_parameters()
314 |
315 | def initialize_parameters(self):
316 | nn.init.normal_(self.token_embedding.weight, std=0.02)
317 | nn.init.normal_(self.positional_embedding, std=0.01)
318 |
319 | if isinstance(self.visual, ModifiedResNet):
320 | if self.visual.attnpool is not None:
321 | std = self.visual.attnpool.c_proj.in_features ** -0.5
322 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
323 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
324 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
325 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
326 |
327 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
328 | for name, param in resnet_block.named_parameters():
329 | if name.endswith("bn3.weight"):
330 | nn.init.zeros_(param)
331 |
332 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
333 | attn_std = self.transformer.width ** -0.5
334 | fc_std = (2 * self.transformer.width) ** -0.5
335 | for block in self.transformer.resblocks:
336 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
337 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
338 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
339 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
340 |
341 | if self.text_projection is not None:
342 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
343 |
344 | def build_attention_mask(self):
345 | # lazily create causal attention mask, with full attention between the vision tokens
346 | # pytorch uses additive attention mask; fill with -inf
347 | mask = torch.empty(self.context_length, self.context_length)
348 | mask.fill_(float("-inf"))
349 | mask.triu_(1) # zero out the lower diagonal
350 | return mask
351 |
352 | @property
353 | def dtype(self):
354 | return self.visual.conv1.weight.dtype
355 |
356 | def encode_image(self, image):
357 | return self.visual(image.type(self.dtype))
358 |
359 | def encode_text(self, text):
360 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
361 |
362 | x = x + self.positional_embedding.type(self.dtype)
363 | x = x.permute(1, 0, 2) # NLD -> LND
364 | _, x = self.transformer(x)  # the modified Transformer.forward above returns (per-layer features, final tokens)
365 | x = x.permute(1, 0, 2) # LND -> NLD
366 | x = self.ln_final(x).type(self.dtype)
367 |
368 | # x.shape = [batch_size, n_ctx, transformer.width]
369 | # take features from the eot embedding (eot_token is the highest number in each sequence)
370 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
371 |
372 | return x
373 |
374 | def forward(self, image, text):
375 | image_features = self.encode_image(image)
376 | text_features = self.encode_text(text)
377 |
378 | # normalized features
379 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
380 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
381 |
382 | # cosine similarity as logits
383 | logit_scale = self.logit_scale.exp()
384 | logits_per_image = logit_scale * image_features @ text_features.t()
385 | logits_per_text = logits_per_image.t()
386 |
387 | # shape = [global_batch_size, global_batch_size]
388 | return logits_per_image, logits_per_text
389 |
390 |
391 | def convert_weights(model: nn.Module):
392 | """Convert applicable model parameters to fp16"""
393 |
394 | def _convert_weights_to_fp16(l):
395 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
396 | l.weight.data = l.weight.data.half()
397 | if l.bias is not None:
398 | l.bias.data = l.bias.data.half()
399 |
400 | if isinstance(l, nn.MultiheadAttention):
401 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
402 | tensor = getattr(l, attr)
403 | if tensor is not None:
404 | tensor.data = tensor.data.half()
405 |
406 | for name in ["text_projection", "proj"]:
407 | if hasattr(l, name):
408 | attr = getattr(l, name)
409 | if attr is not None:
410 | attr.data = attr.data.half()
411 |
412 | model.apply(_convert_weights_to_fp16)
413 |
414 |
415 | def build_model(state_dict: dict):
416 | vit = "visual.proj" in state_dict
417 |
418 | if vit:
419 | vision_width = state_dict["visual.conv1.weight"].shape[0]
420 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
421 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
422 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
423 | image_resolution = vision_patch_size * grid_size
424 | else:
425 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
426 | vision_layers = tuple(counts)
427 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
428 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
429 | vision_patch_size = None
430 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
431 | image_resolution = output_width * 32
432 |
433 | embed_dim = state_dict["text_projection"].shape[1]
434 | context_length = state_dict["positional_embedding"].shape[0]
435 | vocab_size = state_dict["token_embedding.weight"].shape[0]
436 | transformer_width = state_dict["ln_final.weight"].shape[0]
437 | transformer_heads = transformer_width // 64
438 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
439 |
440 | model = CLIP(
441 | embed_dim,
442 | image_resolution, vision_layers, vision_width, vision_patch_size,
443 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
444 | )
445 |
446 | for key in ["input_resolution", "context_length", "vocab_size"]:
447 | if key in state_dict:
448 | del state_dict[key]
449 |
450 | convert_weights(model)
451 | model.load_state_dict(state_dict)
452 | return model.eval()
453 |
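A small sketch exercising the vision tower defined above with toy hyper-parameters (chosen only to keep the example light; real configurations are derived in `build_model`):
```python
# Sketch: toy VisionTransformer forward pass to inspect the output shape.
import torch
from models.clip.model import VisionTransformer

vit = VisionTransformer(input_resolution=224, patch_size=32,
                        width=64, layers=2, heads=2, output_dim=128)
x = torch.randn(1, 3, 224, 224)
feat = vit(x)
print(feat.shape)  # torch.Size([1, 128]) -- the projected CLS-token feature
```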
--------------------------------------------------------------------------------
/models/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except ValueError:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
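A brief usage sketch for `SimpleTokenizer` (assuming this repository's `models.clip` package is on the import path): `encode` lowercases and whitespace-normalizes the text, applies byte-level BPE with the `</w>` end-of-word marker, and returns raw token ids without the `<|startoftext|>`/`<|endoftext|>` wrappers, which `clip.tokenize` adds separately.

```python
# Minimal sketch; the import path assumes this repository's package layout.
from models.clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()               # loads bpe_simple_vocab_16e6.txt.gz next to the module
ids = tokenizer.encode("A photo of a DOG")  # text is lowercased and whitespace-normalized first
print(ids)                                  # plain BPE ids, no start/end tokens
print(tokenizer.decode(ids).strip())        # "a photo of a dog"
```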
/models/clip_models.py:
--------------------------------------------------------------------------------
1 | from .clip import clip
2 | from PIL import Image
3 | import torch.nn as nn
4 |
5 |
6 | CHANNELS = {
7 | "RN50" : 1024,
8 | "ViT-L/14" : 768
9 | }
10 |
11 | class CLIPModel(nn.Module):
12 | def __init__(self, name, num_classes=1):
13 | super(CLIPModel, self).__init__()
14 |
15 | self.model, self.preprocess = clip.load(name, device="cpu") # self.preprocess is not used during training; preprocessing is handled in the Dataset class
16 | self.fc = nn.Linear( CHANNELS[name], num_classes )
17 |
18 |
19 | def forward(self, x, return_feature=False):
20 | features = self.model.encode_image(x)
21 | if return_feature:
22 | return features
23 | return self.fc(features)
24 |
25 |
--------------------------------------------------------------------------------
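A hedged usage sketch for `CLIPModel`: `clip.load` builds the backbone on CPU, `encode_image` produces the image feature vector, and the single `fc` layer maps it to one real/fake logit. The random tensor below is only a stand-in for a CLIP-preprocessed 224x224 batch.

```python
# Minimal sketch (not a training/eval script); random data is used only to show shapes.
import torch
from models.clip_models import CLIPModel

model = CLIPModel("ViT-L/14").eval()       # downloads the CLIP checkpoint on first use
x = torch.randn(2, 3, 224, 224)            # stand-in for CLIP-normalized crops
with torch.no_grad():
    feats = model(x, return_feature=True)  # (2, 768) CLIP image features
    logits = model(x)                      # (2, 1) output of the linear head
print(feats.shape, logits.shape)
```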
/models/imagenet_models.py:
--------------------------------------------------------------------------------
1 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
2 | from .vision_transformer import vit_b_16, vit_b_32, vit_l_16, vit_l_32
3 |
4 | from torchvision import transforms
5 | from PIL import Image
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | model_dict = {
11 | 'resnet18': resnet18,
12 | 'resnet34': resnet34,
13 | 'resnet50': resnet50,
14 | 'resnet101': resnet101,
15 | 'resnet152': resnet152,
16 | 'vit_b_16': vit_b_16,
17 | 'vit_b_32': vit_b_32,
18 | 'vit_l_16': vit_l_16,
19 | 'vit_l_32': vit_l_32
20 | }
21 |
22 |
23 | CHANNELS = {
24 | "resnet50" : 2048,
25 | "vit_b_16" : 768,
26 | }
27 |
28 |
29 |
30 | class ImagenetModel(nn.Module):
31 | def __init__(self, name, num_classes=1):
32 | super(ImagenetModel, self).__init__()
33 |
34 | self.model = model_dict[name](pretrained=True)
35 | self.fc = nn.Linear(CHANNELS[name], num_classes)  # manually define an fc layer here
36 |
37 |
38 | def forward(self, x):
39 | feature = self.model(x)["penultimate"]
40 | return self.fc(feature)
41 |
--------------------------------------------------------------------------------
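`ImagenetModel` follows the same pattern with an ImageNet-pretrained backbone: the wrapped network (defined in `resnet.py` / `vision_transformer.py` below) returns a feature dictionary, and only its `'penultimate'` entry is passed to the new linear head. A minimal sketch, again with random data as a shape placeholder:

```python
# Minimal sketch; "resnet50" and "vit_b_16" are the two names covered by CHANNELS above.
import torch
from models.imagenet_models import ImagenetModel

model = ImagenetModel("resnet50").eval()  # downloads the torchvision weights on first use
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(x)                     # penultimate (2, 2048) -> fc -> (2, 1)
print(logits.shape)
```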
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | from typing import Type, Any, Callable, Union, List, Optional
5 |
6 | try:
7 | from torch.hub import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
22 | }
23 |
24 |
25 |
26 |
27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
28 | """3x3 convolution with padding"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=dilation, groups=groups, bias=False, dilation=dilation)
31 |
32 |
33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
34 | """1x1 convolution"""
35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion: int = 1
40 |
41 | def __init__(
42 | self,
43 | inplanes: int,
44 | planes: int,
45 | stride: int = 1,
46 | downsample: Optional[nn.Module] = None,
47 | groups: int = 1,
48 | base_width: int = 64,
49 | dilation: int = 1,
50 | norm_layer: Optional[Callable[..., nn.Module]] = None
51 | ) -> None:
52 | super(BasicBlock, self).__init__()
53 | if norm_layer is None:
54 | norm_layer = nn.BatchNorm2d
55 | if groups != 1 or base_width != 64:
56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
57 | if dilation > 1:
58 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
60 | self.conv1 = conv3x3(inplanes, planes, stride)
61 | self.bn1 = norm_layer(planes)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.conv2 = conv3x3(planes, planes)
64 | self.bn2 = norm_layer(planes)
65 | self.downsample = downsample
66 | self.stride = stride
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | identity = x
70 |
71 | out = self.conv1(x)
72 | out = self.bn1(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 |
78 | if self.downsample is not None:
79 | identity = self.downsample(x)
80 |
81 | out += identity
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class Bottleneck(nn.Module):
88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
90 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
91 | # This variant is also known as ResNet V1.5 and improves accuracy according to
92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
93 |
94 | expansion: int = 4
95 |
96 | def __init__(
97 | self,
98 | inplanes: int,
99 | planes: int,
100 | stride: int = 1,
101 | downsample: Optional[nn.Module] = None,
102 | groups: int = 1,
103 | base_width: int = 64,
104 | dilation: int = 1,
105 | norm_layer: Optional[Callable[..., nn.Module]] = None
106 | ) -> None:
107 | super(Bottleneck, self).__init__()
108 | if norm_layer is None:
109 | norm_layer = nn.BatchNorm2d
110 | width = int(planes * (base_width / 64.)) * groups
111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
112 | self.conv1 = conv1x1(inplanes, width)
113 | self.bn1 = norm_layer(width)
114 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
115 | self.bn2 = norm_layer(width)
116 | self.conv3 = conv1x1(width, planes * self.expansion)
117 | self.bn3 = norm_layer(planes * self.expansion)
118 | self.relu = nn.ReLU(inplace=True)
119 | self.downsample = downsample
120 | self.stride = stride
121 |
122 | def forward(self, x: Tensor) -> Tensor:
123 | identity = x
124 |
125 | out = self.conv1(x)
126 | out = self.bn1(out)
127 | out = self.relu(out)
128 |
129 | out = self.conv2(out)
130 | out = self.bn2(out)
131 | out = self.relu(out)
132 |
133 | out = self.conv3(out)
134 | out = self.bn3(out)
135 |
136 | if self.downsample is not None:
137 | identity = self.downsample(x)
138 |
139 | out += identity
140 | out = self.relu(out)
141 |
142 | return out
143 |
144 |
145 | class ResNet(nn.Module):
146 |
147 | def __init__(
148 | self,
149 | block: Type[Union[BasicBlock, Bottleneck]],
150 | layers: List[int],
151 | num_classes: int = 1000,
152 | zero_init_residual: bool = False,
153 | groups: int = 1,
154 | width_per_group: int = 64,
155 | replace_stride_with_dilation: Optional[List[bool]] = None,
156 | norm_layer: Optional[Callable[..., nn.Module]] = None
157 | ) -> None:
158 | super(ResNet, self).__init__()
159 | if norm_layer is None:
160 | norm_layer = nn.BatchNorm2d
161 | self._norm_layer = norm_layer
162 |
163 | self.inplanes = 64
164 | self.dilation = 1
165 | if replace_stride_with_dilation is None:
166 | # each element in the tuple indicates if we should replace
167 | # the 2x2 stride with a dilated convolution instead
168 | replace_stride_with_dilation = [False, False, False]
169 | if len(replace_stride_with_dilation) != 3:
170 | raise ValueError("replace_stride_with_dilation should be None "
171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
172 | self.groups = groups
173 | self.base_width = width_per_group
174 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
175 | bias=False)
176 | self.bn1 = norm_layer(self.inplanes)
177 | self.relu = nn.ReLU(inplace=True)
178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179 | self.layer1 = self._make_layer(block, 64, layers[0])
180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
181 | dilate=replace_stride_with_dilation[0])
182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
183 | dilate=replace_stride_with_dilation[1])
184 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
185 | dilate=replace_stride_with_dilation[2])
186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
187 | self.fc = nn.Linear(512 * block.expansion, num_classes)
188 |
189 | for m in self.modules():
190 | if isinstance(m, nn.Conv2d):
191 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
192 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
193 | nn.init.constant_(m.weight, 1)
194 | nn.init.constant_(m.bias, 0)
195 |
196 | # Zero-initialize the last BN in each residual branch,
197 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
198 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
199 | if zero_init_residual:
200 | for m in self.modules():
201 | if isinstance(m, Bottleneck):
202 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
203 | elif isinstance(m, BasicBlock):
204 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
205 |
206 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
207 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
208 | norm_layer = self._norm_layer
209 | downsample = None
210 | previous_dilation = self.dilation
211 | if dilate:
212 | self.dilation *= stride
213 | stride = 1
214 | if stride != 1 or self.inplanes != planes * block.expansion:
215 | downsample = nn.Sequential(
216 | conv1x1(self.inplanes, planes * block.expansion, stride),
217 | norm_layer(planes * block.expansion),
218 | )
219 |
220 | layers = []
221 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
222 | self.base_width, previous_dilation, norm_layer))
223 | self.inplanes = planes * block.expansion
224 | for _ in range(1, blocks):
225 | layers.append(block(self.inplanes, planes, groups=self.groups,
226 | base_width=self.base_width, dilation=self.dilation,
227 | norm_layer=norm_layer))
228 |
229 | return nn.Sequential(*layers)
230 |
231 | def _forward_impl(self, x):
232 | # The feature-map sizes in the comments below assume a 224*224 ImageNet input; channel counts are for BasicBlock (multiply by block.expansion for Bottleneck variants)
233 | out = {}
234 | x = self.conv1(x)
235 | x = self.bn1(x)
236 | x = self.relu(x)
237 | x = self.maxpool(x)
238 | out['f0'] = x # N*64*56*56
239 |
240 | x = self.layer1(x)
241 | out['f1'] = x # N*64*56*56
242 |
243 | x = self.layer2(x)
244 | out['f2'] = x # N*128*28*28
245 |
246 | x = self.layer3(x)
247 | out['f3'] = x # N*256*14*14
248 |
249 | x = self.layer4(x)
250 | out['f4'] = x # N*512*7*7
251 |
252 | x = self.avgpool(x)
253 | x = torch.flatten(x, 1)
254 | out['penultimate'] = x # N*512
255 |
256 | x = self.fc(x)
257 | out['logits'] = x # N*1000
258 |
259 | # return all features
260 | return out
261 |
262 | # return final classification result
263 | # return x
264 |
265 | def forward(self, x):
266 | return self._forward_impl(x)
267 |
268 |
269 | def _resnet(
270 | arch: str,
271 | block: Type[Union[BasicBlock, Bottleneck]],
272 | layers: List[int],
273 | pretrained: bool,
274 | progress: bool,
275 | **kwargs: Any
276 | ) -> ResNet:
277 | model = ResNet(block, layers, **kwargs)
278 | if pretrained:
279 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
280 | model.load_state_dict(state_dict)
281 | return model
282 |
283 |
284 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
285 | r"""ResNet-18 model from
286 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
287 |
288 | Args:
289 | pretrained (bool): If True, returns a model pre-trained on ImageNet
290 | progress (bool): If True, displays a progress bar of the download to stderr
291 | """
292 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
293 |
294 |
295 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
296 | r"""ResNet-34 model from
297 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
298 |
299 | Args:
300 | pretrained (bool): If True, returns a model pre-trained on ImageNet
301 | progress (bool): If True, displays a progress bar of the download to stderr
302 | """
303 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
304 |
305 |
306 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
307 | r"""ResNet-50 model from
308 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
309 |
310 | Args:
311 | pretrained (bool): If True, returns a model pre-trained on ImageNet
312 | progress (bool): If True, displays a progress bar of the download to stderr
313 | """
314 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
315 |
316 |
317 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
318 | r"""ResNet-101 model from
319 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
320 |
321 | Args:
322 | pretrained (bool): If True, returns a model pre-trained on ImageNet
323 | progress (bool): If True, displays a progress bar of the download to stderr
324 | """
325 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
326 |
327 |
328 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
329 | r"""ResNet-152 model from
330 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
331 |
332 | Args:
333 | pretrained (bool): If True, returns a model pre-trained on ImageNet
334 | progress (bool): If True, displays a progress bar of the download to stderr
335 | """
336 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
337 |
338 |
--------------------------------------------------------------------------------
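Unlike stock torchvision, this ResNet's `forward` returns a dictionary of intermediate features rather than a single logits tensor, which is what lets `ImagenetModel` above pick out `'penultimate'`. A short sketch of the keys and their shapes for `resnet50` on a 224x224 input:

```python
# Minimal sketch; shapes are for resnet50 (Bottleneck, expansion=4) on a 224x224 input.
import torch
from models.resnet import resnet50

net = resnet50(pretrained=False).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
for key in ("f0", "f1", "f2", "f3", "f4", "penultimate", "logits"):
    print(key, tuple(out[key].shape))
# f0 is (1, 64, 56, 56), f4 is (1, 2048, 7, 7), 'penultimate' is the pooled
# (1, 2048) vector used by ImagenetModel, and 'logits' is the (1, 1000) ImageNet head.
```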
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Union, List, Dict, Any, cast
4 | import torchvision
5 | import torch.nn.functional as F
6 |
7 |
8 |
9 |
10 |
11 | class VGG(torch.nn.Module):
12 | def __init__(self, arch_type, pretrained, progress):
13 | super().__init__()
14 |
15 | self.layer1 = torch.nn.Sequential()
16 | self.layer2 = torch.nn.Sequential()
17 | self.layer3 = torch.nn.Sequential()
18 | self.layer4 = torch.nn.Sequential()
19 | self.layer5 = torch.nn.Sequential()
20 |
21 | if arch_type == 'vgg11':
22 | official_vgg = torchvision.models.vgg11(pretrained=pretrained, progress=progress)
23 | blocks = [ [0,2], [2,5], [5,10], [10,15], [15,20] ]
24 | last_idx = 20
25 | elif arch_type == 'vgg19':
26 | official_vgg = torchvision.models.vgg19(pretrained=pretrained, progress=progress)
27 | blocks = [ [0,4], [4,9], [9,18], [18,27], [27,36] ]
28 | last_idx = 36
29 | else:
30 | raise NotImplementedError
31 |
32 |
33 | for x in range( *blocks[0] ):
34 | self.layer1.add_module(str(x), official_vgg.features[x])
35 | for x in range( *blocks[1] ):
36 | self.layer2.add_module(str(x), official_vgg.features[x])
37 | for x in range( *blocks[2] ):
38 | self.layer3.add_module(str(x), official_vgg.features[x])
39 | for x in range( *blocks[3] ):
40 | self.layer4.add_module(str(x), official_vgg.features[x])
41 | for x in range( *blocks[4] ):
42 | self.layer5.add_module(str(x), official_vgg.features[x])
43 |
44 | self.max_pool = official_vgg.features[last_idx]
45 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
46 |
47 | self.fc1 = official_vgg.classifier[0]
48 | self.fc2 = official_vgg.classifier[3]
49 | self.fc3 = official_vgg.classifier[6]
50 | self.dropout = nn.Dropout()
51 |
52 |
53 | def forward(self, x):
54 | out = {}
55 |
56 | x = self.layer1(x)
57 | out['f0'] = x
58 |
59 | x = self.layer2(x)
60 | out['f1'] = x
61 |
62 | x = self.layer3(x)
63 | out['f2'] = x
64 |
65 | x = self.layer4(x)
66 | out['f3'] = x
67 |
68 | x = self.layer5(x)
69 | out['f4'] = x
70 |
71 | x = self.max_pool(x)
72 | x = self.avgpool(x)
73 | x = x.view(-1,512*7*7)
74 |
75 | x = self.fc1(x)
76 | x = F.relu(x)
77 | x = self.dropout(x)
78 | x = self.fc2(x)
79 | x = F.relu(x)
80 | out['penultimate'] = x
81 | x = self.dropout(x)
82 | x = self.fc3(x)
83 | out['logits'] = x
84 |
85 | return out
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 | def vgg11(pretrained=False, progress=True):
97 | r"""VGG 11-layer model (configuration "A") from
98 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
99 |
100 | Args:
101 | pretrained (bool): If True, returns a model pre-trained on ImageNet
102 | progress (bool): If True, displays a progress bar of the download to stderr
103 | """
104 | return VGG('vgg11', pretrained, progress)
105 |
106 |
107 |
108 | def vgg19(pretrained=False, progress=True):
109 | r"""VGG 19-layer model (configuration "E")
110 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
111 |
112 | Args:
113 | pretrained (bool): If True, returns a model pre-trained on ImageNet
114 | progress (bool): If True, displays a progress bar of the download to stderr
115 | """
116 | return VGG('vgg19', pretrained, progress)
117 |
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
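The VGG wrapper exposes the same dictionary interface by splitting torchvision's sequential `features` into five named stages; `'penultimate'` here is the 4096-d activation after the second classifier layer. A minimal sketch:

```python
# Minimal sketch; vgg19 with random weights and a random 224x224 input, used only for shapes.
import torch
from models.vgg import vgg19

net = vgg19(pretrained=False).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
print(out["f4"].shape)           # torch.Size([1, 512, 14, 14]) before the final max pool
print(out["penultimate"].shape)  # torch.Size([1, 4096])
print(out["logits"].shape)       # torch.Size([1, 1000])
```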
/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 | from functools import partial
4 | from typing import Any, Callable, List, NamedTuple, Optional
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | # from .._internally_replaced_utils import load_state_dict_from_url
10 | from .vision_transformer_misc import ConvNormActivation
11 | from .vision_transformer_utils import _log_api_usage_once
12 |
13 | try:
14 | from torch.hub import load_state_dict_from_url
15 | except ImportError:
16 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
17 |
18 | # __all__ = [
19 | # "VisionTransformer",
20 | # "vit_b_16",
21 | # "vit_b_32",
22 | # "vit_l_16",
23 | # "vit_l_32",
24 | # ]
25 |
26 | model_urls = {
27 | "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
28 | "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
29 | "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
30 | "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
31 | }
32 |
33 |
34 | class ConvStemConfig(NamedTuple):
35 | out_channels: int
36 | kernel_size: int
37 | stride: int
38 | norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
39 | activation_layer: Callable[..., nn.Module] = nn.ReLU
40 |
41 |
42 | class MLPBlock(nn.Sequential):
43 | """Transformer MLP block."""
44 |
45 | def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
46 | super().__init__()
47 | self.linear_1 = nn.Linear(in_dim, mlp_dim)
48 | self.act = nn.GELU()
49 | self.dropout_1 = nn.Dropout(dropout)
50 | self.linear_2 = nn.Linear(mlp_dim, in_dim)
51 | self.dropout_2 = nn.Dropout(dropout)
52 |
53 | nn.init.xavier_uniform_(self.linear_1.weight)
54 | nn.init.xavier_uniform_(self.linear_2.weight)
55 | nn.init.normal_(self.linear_1.bias, std=1e-6)
56 | nn.init.normal_(self.linear_2.bias, std=1e-6)
57 |
58 |
59 | class EncoderBlock(nn.Module):
60 | """Transformer encoder block."""
61 |
62 | def __init__(
63 | self,
64 | num_heads: int,
65 | hidden_dim: int,
66 | mlp_dim: int,
67 | dropout: float,
68 | attention_dropout: float,
69 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
70 | ):
71 | super().__init__()
72 | self.num_heads = num_heads
73 |
74 | # Attention block
75 | self.ln_1 = norm_layer(hidden_dim)
76 | self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
77 | self.dropout = nn.Dropout(dropout)
78 |
79 | # MLP block
80 | self.ln_2 = norm_layer(hidden_dim)
81 | self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
82 |
83 | def forward(self, input: torch.Tensor):
84 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
85 | x = self.ln_1(input)
86 | x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
87 | x = self.dropout(x)
88 | x = x + input
89 |
90 | y = self.ln_2(x)
91 | y = self.mlp(y)
92 | return x + y
93 |
94 |
95 | class Encoder(nn.Module):
96 | """Transformer Model Encoder for sequence to sequence translation."""
97 |
98 | def __init__(
99 | self,
100 | seq_length: int,
101 | num_layers: int,
102 | num_heads: int,
103 | hidden_dim: int,
104 | mlp_dim: int,
105 | dropout: float,
106 | attention_dropout: float,
107 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
108 | ):
109 | super().__init__()
110 | # Note that batch_size is on the first dim because
111 | # batch_first=True is passed to nn.MultiheadAttention in EncoderBlock
112 | self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
113 | self.dropout = nn.Dropout(dropout)
114 | layers: OrderedDict[str, nn.Module] = OrderedDict()
115 | for i in range(num_layers):
116 | layers[f"encoder_layer_{i}"] = EncoderBlock(
117 | num_heads,
118 | hidden_dim,
119 | mlp_dim,
120 | dropout,
121 | attention_dropout,
122 | norm_layer,
123 | )
124 | self.layers = nn.Sequential(layers)
125 | self.ln = norm_layer(hidden_dim)
126 |
127 | def forward(self, input: torch.Tensor):
128 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
129 | input = input + self.pos_embedding
130 | return self.ln(self.layers(self.dropout(input)))
131 |
132 |
133 | class VisionTransformer(nn.Module):
134 | """Vision Transformer as per https://arxiv.org/abs/2010.11929."""
135 |
136 | def __init__(
137 | self,
138 | image_size: int,
139 | patch_size: int,
140 | num_layers: int,
141 | num_heads: int,
142 | hidden_dim: int,
143 | mlp_dim: int,
144 | dropout: float = 0.0,
145 | attention_dropout: float = 0.0,
146 | num_classes: int = 1000,
147 | representation_size: Optional[int] = None,
148 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
149 | conv_stem_configs: Optional[List[ConvStemConfig]] = None,
150 | ):
151 | super().__init__()
152 | _log_api_usage_once(self)
153 | torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
154 | self.image_size = image_size
155 | self.patch_size = patch_size
156 | self.hidden_dim = hidden_dim
157 | self.mlp_dim = mlp_dim
158 | self.attention_dropout = attention_dropout
159 | self.dropout = dropout
160 | self.num_classes = num_classes
161 | self.representation_size = representation_size
162 | self.norm_layer = norm_layer
163 |
164 | if conv_stem_configs is not None:
165 | # As per https://arxiv.org/abs/2106.14881
166 | seq_proj = nn.Sequential()
167 | prev_channels = 3
168 | for i, conv_stem_layer_config in enumerate(conv_stem_configs):
169 | seq_proj.add_module(
170 | f"conv_bn_relu_{i}",
171 | ConvNormActivation(
172 | in_channels=prev_channels,
173 | out_channels=conv_stem_layer_config.out_channels,
174 | kernel_size=conv_stem_layer_config.kernel_size,
175 | stride=conv_stem_layer_config.stride,
176 | norm_layer=conv_stem_layer_config.norm_layer,
177 | activation_layer=conv_stem_layer_config.activation_layer,
178 | ),
179 | )
180 | prev_channels = conv_stem_layer_config.out_channels
181 | seq_proj.add_module(
182 | "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
183 | )
184 | self.conv_proj: nn.Module = seq_proj
185 | else:
186 | self.conv_proj = nn.Conv2d(
187 | in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
188 | )
189 |
190 | seq_length = (image_size // patch_size) ** 2
191 |
192 | # Add a class token
193 | self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
194 | seq_length += 1
195 |
196 | self.encoder = Encoder(
197 | seq_length,
198 | num_layers,
199 | num_heads,
200 | hidden_dim,
201 | mlp_dim,
202 | dropout,
203 | attention_dropout,
204 | norm_layer,
205 | )
206 | self.seq_length = seq_length
207 |
208 | heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
209 | if representation_size is None:
210 | heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
211 | else:
212 | heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
213 | heads_layers["act"] = nn.Tanh()
214 | heads_layers["head"] = nn.Linear(representation_size, num_classes)
215 |
216 | self.heads = nn.Sequential(heads_layers)
217 |
218 | if isinstance(self.conv_proj, nn.Conv2d):
219 | # Init the patchify stem
220 | fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
221 | nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
222 | if self.conv_proj.bias is not None:
223 | nn.init.zeros_(self.conv_proj.bias)
224 | elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
225 | # Init the last 1x1 conv of the conv stem
226 | nn.init.normal_(
227 | self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
228 | )
229 | if self.conv_proj.conv_last.bias is not None:
230 | nn.init.zeros_(self.conv_proj.conv_last.bias)
231 |
232 | if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
233 | fan_in = self.heads.pre_logits.in_features
234 | nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
235 | nn.init.zeros_(self.heads.pre_logits.bias)
236 |
237 | if isinstance(self.heads.head, nn.Linear):
238 | nn.init.zeros_(self.heads.head.weight)
239 | nn.init.zeros_(self.heads.head.bias)
240 |
241 | def _process_input(self, x: torch.Tensor) -> torch.Tensor:
242 | n, c, h, w = x.shape
243 | p = self.patch_size
244 | torch._assert(h == self.image_size, "Wrong image height!")
245 | torch._assert(w == self.image_size, "Wrong image width!")
246 | n_h = h // p
247 | n_w = w // p
248 |
249 | # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
250 | x = self.conv_proj(x)
251 | # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
252 | x = x.reshape(n, self.hidden_dim, n_h * n_w)
253 |
254 | # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
255 | # The self attention layer expects inputs in the format (N, S, E)
256 | # where S is the source sequence length, N is the batch size, E is the
257 | # embedding dimension
258 | x = x.permute(0, 2, 1)
259 |
260 | return x
261 |
262 | def forward(self, x: torch.Tensor):
263 | out = {}
264 |
265 | # Reshape and permute the input tensor
266 | x = self._process_input(x)
267 | n = x.shape[0]
268 |
269 | # Expand the class token to the full batch
270 | batch_class_token = self.class_token.expand(n, -1, -1)
271 | x = torch.cat([batch_class_token, x], dim=1)
272 |
273 |
274 | x = self.encoder(x)
275 | img_feature = x[:,1:]
276 | H = W = int(self.image_size / self.patch_size)
277 | out['f4'] = img_feature.view(n, H, W, self.hidden_dim).permute(0,3,1,2)
278 |
279 | # Classifier "token" as used by standard language architectures
280 | x = x[:, 0]
281 | out['penultimate'] = x
282 |
283 | x = self.heads(x) # I checked that for all pretrained ViT, this is just a fc
284 | out['logits'] = x
285 |
286 | return out
287 |
288 |
289 | def _vision_transformer(
290 | arch: str,
291 | patch_size: int,
292 | num_layers: int,
293 | num_heads: int,
294 | hidden_dim: int,
295 | mlp_dim: int,
296 | pretrained: bool,
297 | progress: bool,
298 | **kwargs: Any,
299 | ) -> VisionTransformer:
300 | image_size = kwargs.pop("image_size", 224)
301 |
302 | model = VisionTransformer(
303 | image_size=image_size,
304 | patch_size=patch_size,
305 | num_layers=num_layers,
306 | num_heads=num_heads,
307 | hidden_dim=hidden_dim,
308 | mlp_dim=mlp_dim,
309 | **kwargs,
310 | )
311 |
312 | if pretrained:
313 | if arch not in model_urls:
314 | raise ValueError(f"No checkpoint is available for model type '{arch}'!")
315 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
316 | model.load_state_dict(state_dict)
317 |
318 | return model
319 |
320 |
321 | def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
322 | """
323 | Constructs a vit_b_16 architecture from
324 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
325 |
326 | Args:
327 | pretrained (bool): If True, returns a model pre-trained on ImageNet
328 | progress (bool): If True, displays a progress bar of the download to stderr
329 | """
330 | return _vision_transformer(
331 | arch="vit_b_16",
332 | patch_size=16,
333 | num_layers=12,
334 | num_heads=12,
335 | hidden_dim=768,
336 | mlp_dim=3072,
337 | pretrained=pretrained,
338 | progress=progress,
339 | **kwargs,
340 | )
341 |
342 |
343 | def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
344 | """
345 | Constructs a vit_b_32 architecture from
346 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
347 |
348 | Args:
349 | pretrained (bool): If True, returns a model pre-trained on ImageNet
350 | progress (bool): If True, displays a progress bar of the download to stderr
351 | """
352 | return _vision_transformer(
353 | arch="vit_b_32",
354 | patch_size=32,
355 | num_layers=12,
356 | num_heads=12,
357 | hidden_dim=768,
358 | mlp_dim=3072,
359 | pretrained=pretrained,
360 | progress=progress,
361 | **kwargs,
362 | )
363 |
364 |
365 | def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
366 | """
367 | Constructs a vit_l_16 architecture from
368 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
369 |
370 | Args:
371 | pretrained (bool): If True, returns a model pre-trained on ImageNet
372 | progress (bool): If True, displays a progress bar of the download to stderr
373 | """
374 | return _vision_transformer(
375 | arch="vit_l_16",
376 | patch_size=16,
377 | num_layers=24,
378 | num_heads=16,
379 | hidden_dim=1024,
380 | mlp_dim=4096,
381 | pretrained=pretrained,
382 | progress=progress,
383 | **kwargs,
384 | )
385 |
386 |
387 | def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
388 | """
389 | Constructs a vit_l_32 architecture from
390 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
391 |
392 | Args:
393 | pretrained (bool): If True, returns a model pre-trained on ImageNet
394 | progress (bool): If True, displays a progress bar of the download to stderr
395 | """
396 | return _vision_transformer(
397 | arch="vit_l_32",
398 | patch_size=32,
399 | num_layers=24,
400 | num_heads=16,
401 | hidden_dim=1024,
402 | mlp_dim=4096,
403 | pretrained=pretrained,
404 | progress=progress,
405 | **kwargs,
406 | )
407 |
408 |
409 | def interpolate_embeddings(
410 | image_size: int,
411 | patch_size: int,
412 | model_state: "OrderedDict[str, torch.Tensor]",
413 | interpolation_mode: str = "bicubic",
414 | reset_heads: bool = False,
415 | ) -> "OrderedDict[str, torch.Tensor]":
416 | """This function helps interpolating positional embeddings during checkpoint loading,
417 | especially when you want to apply a pre-trained model on images with different resolution.
418 |
419 | Args:
420 | image_size (int): Image size of the new model.
421 | patch_size (int): Patch size of the new model.
422 | model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
423 | interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
424 | reset_heads (bool): If true, not copying the state of heads. Default: False.
425 |
426 | Returns:
427 | OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
428 | """
429 | # Shape of pos_embedding is (1, seq_length, hidden_dim)
430 | pos_embedding = model_state["encoder.pos_embedding"]
431 | n, seq_length, hidden_dim = pos_embedding.shape
432 | if n != 1:
433 | raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
434 |
435 | new_seq_length = (image_size // patch_size) ** 2 + 1
436 |
437 | # Need to interpolate the weights for the position embedding.
438 | # We do this by reshaping the positions embeddings to a 2d grid, performing
439 | # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
440 | if new_seq_length != seq_length:
441 | # The class token embedding shouldn't be interpolated so we split it up.
442 | seq_length -= 1
443 | new_seq_length -= 1
444 | pos_embedding_token = pos_embedding[:, :1, :]
445 | pos_embedding_img = pos_embedding[:, 1:, :]
446 |
447 | # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
448 | pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
449 | seq_length_1d = int(math.sqrt(seq_length))
450 | torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
451 |
452 | # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
453 | pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
454 | new_seq_length_1d = image_size // patch_size
455 |
456 | # Perform interpolation.
457 | # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
458 | new_pos_embedding_img = nn.functional.interpolate(
459 | pos_embedding_img,
460 | size=new_seq_length_1d,
461 | mode=interpolation_mode,
462 | align_corners=True,
463 | )
464 |
465 | # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
466 | new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
467 |
468 | # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
469 | new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
470 | new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
471 |
472 | model_state["encoder.pos_embedding"] = new_pos_embedding
473 |
474 | if reset_heads:
475 | model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
476 | for k, v in model_state.items():
477 | if not k.startswith("heads"):
478 | model_state_copy[k] = v
479 | model_state = model_state_copy
480 |
481 | return model_state
482 |
--------------------------------------------------------------------------------
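This ViT variant mirrors the ResNet convention: `'penultimate'` is the class token after the encoder, and `'f4'` reshapes the patch tokens back into a spatial grid. A minimal sketch for `vit_b_16`:

```python
# Minimal sketch; random weights and a random 224x224 input, shown only for shapes.
import torch
from models.vision_transformer import vit_b_16

net = vit_b_16(pretrained=False).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
print(out["f4"].shape)           # torch.Size([1, 768, 14, 14]) patch-token grid
print(out["penultimate"].shape)  # torch.Size([1, 768]) class token
print(out["logits"].shape)       # torch.Size([1, 1000])
```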
/models/vision_transformer_misc.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from .vision_transformer_utils import _log_api_usage_once
7 |
8 |
9 | interpolate = torch.nn.functional.interpolate
10 |
11 |
12 | # This is not in nn
13 | class FrozenBatchNorm2d(torch.nn.Module):
14 | """
15 | BatchNorm2d where the batch statistics and the affine parameters are fixed
16 |
17 | Args:
18 | num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
19 | eps (float): a value added to the denominator for numerical stability. Default: 1e-5
20 | """
21 |
22 | def __init__(
23 | self,
24 | num_features: int,
25 | eps: float = 1e-5,
26 | ):
27 | super().__init__()
28 | _log_api_usage_once(self)
29 | self.eps = eps
30 | self.register_buffer("weight", torch.ones(num_features))
31 | self.register_buffer("bias", torch.zeros(num_features))
32 | self.register_buffer("running_mean", torch.zeros(num_features))
33 | self.register_buffer("running_var", torch.ones(num_features))
34 |
35 | def _load_from_state_dict(
36 | self,
37 | state_dict: dict,
38 | prefix: str,
39 | local_metadata: dict,
40 | strict: bool,
41 | missing_keys: List[str],
42 | unexpected_keys: List[str],
43 | error_msgs: List[str],
44 | ):
45 | num_batches_tracked_key = prefix + "num_batches_tracked"
46 | if num_batches_tracked_key in state_dict:
47 | del state_dict[num_batches_tracked_key]
48 |
49 | super()._load_from_state_dict(
50 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
51 | )
52 |
53 | def forward(self, x: Tensor) -> Tensor:
54 | # move reshapes to the beginning
55 | # to make it fuser-friendly
56 | w = self.weight.reshape(1, -1, 1, 1)
57 | b = self.bias.reshape(1, -1, 1, 1)
58 | rv = self.running_var.reshape(1, -1, 1, 1)
59 | rm = self.running_mean.reshape(1, -1, 1, 1)
60 | scale = w * (rv + self.eps).rsqrt()
61 | bias = b - rm * scale
62 | return x * scale + bias
63 |
64 | def __repr__(self) -> str:
65 | return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
66 |
67 |
68 | class ConvNormActivation(torch.nn.Sequential):
69 | """
70 | Configurable block used for Convolution-Normalization-Activation blocks.
71 |
72 | Args:
73 | in_channels (int): Number of channels in the input image
74 | out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
75 | kernel_size (int, optional): Size of the convolving kernel. Default: 3
76 | stride (int, optional): Stride of the convolution. Default: 1
77 | padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
78 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
79 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
80 | activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
81 | dilation (int): Spacing between kernel elements. Default: 1
82 | inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
83 | bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
84 |
85 | """
86 |
87 | def __init__(
88 | self,
89 | in_channels: int,
90 | out_channels: int,
91 | kernel_size: int = 3,
92 | stride: int = 1,
93 | padding: Optional[int] = None,
94 | groups: int = 1,
95 | norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
96 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
97 | dilation: int = 1,
98 | inplace: Optional[bool] = True,
99 | bias: Optional[bool] = None,
100 | ) -> None:
101 | if padding is None:
102 | padding = (kernel_size - 1) // 2 * dilation
103 | if bias is None:
104 | bias = norm_layer is None
105 | layers = [
106 | torch.nn.Conv2d(
107 | in_channels,
108 | out_channels,
109 | kernel_size,
110 | stride,
111 | padding,
112 | dilation=dilation,
113 | groups=groups,
114 | bias=bias,
115 | )
116 | ]
117 | if norm_layer is not None:
118 | layers.append(norm_layer(out_channels))
119 | if activation_layer is not None:
120 | params = {} if inplace is None else {"inplace": inplace}
121 | layers.append(activation_layer(**params))
122 | super().__init__(*layers)
123 | _log_api_usage_once(self)
124 | self.out_channels = out_channels
125 |
126 |
127 | class SqueezeExcitation(torch.nn.Module):
128 | """
129 | This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
130 | Parameters ``activation`` and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
131 |
132 | Args:
133 | input_channels (int): Number of channels in the input image
134 | squeeze_channels (int): Number of squeeze channels
135 | activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
136 | scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
137 | """
138 |
139 | def __init__(
140 | self,
141 | input_channels: int,
142 | squeeze_channels: int,
143 | activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
144 | scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
145 | ) -> None:
146 | super().__init__()
147 | _log_api_usage_once(self)
148 | self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
149 | self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
150 | self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
151 | self.activation = activation()
152 | self.scale_activation = scale_activation()
153 |
154 | def _scale(self, input: Tensor) -> Tensor:
155 | scale = self.avgpool(input)
156 | scale = self.fc1(scale)
157 | scale = self.activation(scale)
158 | scale = self.fc2(scale)
159 | return self.scale_activation(scale)
160 |
161 | def forward(self, input: Tensor) -> Tensor:
162 | scale = self._scale(input)
163 | return scale * input
164 |
--------------------------------------------------------------------------------
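`ConvNormActivation` is what the optional convolutional stem in `vision_transformer.py` is built from: it stacks a Conv2d, an optional norm layer, and an optional activation, computing "same"-style padding from the kernel size and dilation. A minimal sketch:

```python
# Minimal sketch: defaults are BatchNorm2d and ReLU; padding = (3 - 1) // 2 = 1.
import torch
from models.vision_transformer_misc import ConvNormActivation

block = ConvNormActivation(in_channels=3, out_channels=64, kernel_size=3, stride=2)
print(block(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])
```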
/models/vision_transformer_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pathlib
3 | import warnings
4 | from types import FunctionType
5 | from typing import Any, BinaryIO, List, Optional, Tuple, Union
6 |
7 | import numpy as np
8 | import torch
9 | from PIL import Image, ImageColor, ImageDraw, ImageFont
10 |
11 | __all__ = [
12 | "make_grid",
13 | "save_image",
14 | "draw_bounding_boxes",
15 | "draw_segmentation_masks",
16 | "draw_keypoints",
17 | "flow_to_image",
18 | ]
19 |
20 |
21 | @torch.no_grad()
22 | def make_grid(
23 | tensor: Union[torch.Tensor, List[torch.Tensor]],
24 | nrow: int = 8,
25 | padding: int = 2,
26 | normalize: bool = False,
27 | value_range: Optional[Tuple[int, int]] = None,
28 | scale_each: bool = False,
29 | pad_value: float = 0.0,
30 | **kwargs,
31 | ) -> torch.Tensor:
32 | """
33 | Make a grid of images.
34 |
35 | Args:
36 | tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
37 | or a list of images all of the same size.
38 | nrow (int, optional): Number of images displayed in each row of the grid.
39 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
40 | padding (int, optional): amount of padding. Default: ``2``.
41 | normalize (bool, optional): If True, shift the image to the range (0, 1),
42 | by the min and max values specified by ``value_range``. Default: ``False``.
43 | value_range (tuple, optional): tuple (min, max) where min and max are numbers,
44 | then these numbers are used to normalize the image. By default, min and max
45 | are computed from the tensor.
46 | range (tuple, optional):
47 | .. warning::
48 | This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
49 | instead.
50 | scale_each (bool, optional): If ``True``, scale each image in the batch of
51 | images separately rather than the (min, max) over all images. Default: ``False``.
52 | pad_value (float, optional): Value for the padded pixels. Default: ``0``.
53 |
54 | Returns:
55 | grid (Tensor): the tensor containing grid of images.
56 | """
57 | if not torch.jit.is_scripting() and not torch.jit.is_tracing():
58 | _log_api_usage_once(make_grid)
59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
61 |
62 | if "range" in kwargs.keys():
63 | warnings.warn(
64 | "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
65 | "Please use 'value_range' instead."
66 | )
67 | value_range = kwargs["range"]
68 |
69 | # if list of tensors, convert to a 4D mini-batch Tensor
70 | if isinstance(tensor, list):
71 | tensor = torch.stack(tensor, dim=0)
72 |
73 | if tensor.dim() == 2: # single image H x W
74 | tensor = tensor.unsqueeze(0)
75 | if tensor.dim() == 3: # single image
76 | if tensor.size(0) == 1: # if single-channel, convert to 3-channel
77 | tensor = torch.cat((tensor, tensor, tensor), 0)
78 | tensor = tensor.unsqueeze(0)
79 |
80 | if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
81 | tensor = torch.cat((tensor, tensor, tensor), 1)
82 |
83 | if normalize is True:
84 | tensor = tensor.clone() # avoid modifying tensor in-place
85 | if value_range is not None:
86 | assert isinstance(
87 | value_range, tuple
88 | ), "value_range has to be a tuple (min, max) if specified. min and max are numbers"
89 |
90 | def norm_ip(img, low, high):
91 | img.clamp_(min=low, max=high)
92 | img.sub_(low).div_(max(high - low, 1e-5))
93 |
94 | def norm_range(t, value_range):
95 | if value_range is not None:
96 | norm_ip(t, value_range[0], value_range[1])
97 | else:
98 | norm_ip(t, float(t.min()), float(t.max()))
99 |
100 | if scale_each is True:
101 | for t in tensor: # loop over mini-batch dimension
102 | norm_range(t, value_range)
103 | else:
104 | norm_range(tensor, value_range)
105 |
106 | assert isinstance(tensor, torch.Tensor)
107 | if tensor.size(0) == 1:
108 | return tensor.squeeze(0)
109 |
110 | # make the mini-batch of images into a grid
111 | nmaps = tensor.size(0)
112 | xmaps = min(nrow, nmaps)
113 | ymaps = int(math.ceil(float(nmaps) / xmaps))
114 | height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
115 | num_channels = tensor.size(1)
116 | grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
117 | k = 0
118 | for y in range(ymaps):
119 | for x in range(xmaps):
120 | if k >= nmaps:
121 | break
122 | # Tensor.copy_() is a valid method but seems to be missing from the stubs
123 | # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
124 | grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
125 | 2, x * width + padding, width - padding
126 | ).copy_(tensor[k])
127 | k = k + 1
128 | return grid
129 |
130 |
131 | @torch.no_grad()
132 | def save_image(
133 | tensor: Union[torch.Tensor, List[torch.Tensor]],
134 | fp: Union[str, pathlib.Path, BinaryIO],
135 | format: Optional[str] = None,
136 | **kwargs,
137 | ) -> None:
138 | """
139 | Save a given Tensor into an image file.
140 |
141 | Args:
142 | tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
143 | saves the tensor as a grid of images by calling ``make_grid``.
144 | fp (string or file object): A filename or a file object
145 | format(Optional): If omitted, the format to use is determined from the filename extension.
146 | If a file object was used instead of a filename, this parameter should always be used.
147 | **kwargs: Other arguments are documented in ``make_grid``.
148 | """
149 |
150 | if not torch.jit.is_scripting() and not torch.jit.is_tracing():
151 | _log_api_usage_once(save_image)
152 | grid = make_grid(tensor, **kwargs)
153 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
154 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
155 | im = Image.fromarray(ndarr)
156 | im.save(fp, format=format)
157 |
158 |
159 | @torch.no_grad()
160 | def draw_bounding_boxes(
161 | image: torch.Tensor,
162 | boxes: torch.Tensor,
163 | labels: Optional[List[str]] = None,
164 | colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
165 | fill: Optional[bool] = False,
166 | width: int = 1,
167 | font: Optional[str] = None,
168 | font_size: int = 10,
169 | ) -> torch.Tensor:
170 |
171 | """
172 | Draws bounding boxes on given image.
173 | The values of the input image should be uint8 between 0 and 255.
174 | If fill is True, the resulting Tensor should be saved as a PNG image.
175 |
176 | Args:
177 | image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
178 | boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
179 | the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
180 | `0 <= ymin < ymax < H`.
181 | labels (List[str]): List containing the labels of bounding boxes.
182 | colors (color or list of colors, optional): List containing the colors
183 | of the boxes or single color for all boxes. The color can be represented as
184 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
185 | By default, random colors are generated for boxes.
186 | fill (bool): If `True` fills the bounding box with specified color.
187 | width (int): Width of bounding box.
188 | font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
189 | also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
190 | `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
191 | font_size (int): The requested font size in points.
192 |
193 | Returns:
194 | img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
195 | """
196 |
197 | if not torch.jit.is_scripting() and not torch.jit.is_tracing():
198 | _log_api_usage_once(draw_bounding_boxes)
199 | if not isinstance(image, torch.Tensor):
200 | raise TypeError(f"Tensor expected, got {type(image)}")
201 | elif image.dtype != torch.uint8:
202 | raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
203 | elif image.dim() != 3:
204 | raise ValueError("Pass individual images, not batches")
205 | elif image.size(0) not in {1, 3}:
206 | raise ValueError("Only grayscale and RGB images are supported")
207 |
208 | num_boxes = boxes.shape[0]
209 |
210 | if labels is None:
211 | labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
212 | elif len(labels) != num_boxes:
213 | raise ValueError(
214 | f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
215 | )
216 |
217 | if colors is None:
218 | colors = _generate_color_palette(num_boxes)
219 | elif isinstance(colors, list):
220 | if len(colors) < num_boxes:
221 | raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
222 | else: # colors specifies a single color for all boxes
223 | colors = [colors] * num_boxes
224 |
225 | colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
226 |
227 | # Handle Grayscale images
228 | if image.size(0) == 1:
229 | image = torch.tile(image, (3, 1, 1))
230 |
231 | ndarr = image.permute(1, 2, 0).cpu().numpy()
232 | img_to_draw = Image.fromarray(ndarr)
233 | img_boxes = boxes.to(torch.int64).tolist()
234 |
235 | if fill:
236 | draw = ImageDraw.Draw(img_to_draw, "RGBA")
237 | else:
238 | draw = ImageDraw.Draw(img_to_draw)
239 |
240 | txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
241 |
242 | for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
243 | if fill:
244 | fill_color = color + (100,)
245 | draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
246 | else:
247 | draw.rectangle(bbox, width=width, outline=color)
248 |
249 | if label is not None:
250 | margin = width + 1
251 | draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
252 |
253 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
254 |
255 |
256 | @torch.no_grad()
257 | def draw_segmentation_masks(
258 | image: torch.Tensor,
259 | masks: torch.Tensor,
260 | alpha: float = 0.8,
261 | colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
262 | ) -> torch.Tensor:
263 |
264 | """
265 | Draws segmentation masks on given RGB image.
266 | The values of the input image should be uint8 between 0 and 255.
267 |
268 | Args:
269 | image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
270 | masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
271 | alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
272 | 0 means full transparency, 1 means no transparency.
273 | colors (color or list of colors, optional): List containing the colors
274 | of the masks or single color for all masks. The color can be represented as
275 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
276 | By default, random colors are generated for each mask.
277 |
278 | Returns:
279 | img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
280 | """
281 |
282 | if not torch.jit.is_scripting() and not torch.jit.is_tracing():
283 | _log_api_usage_once(draw_segmentation_masks)
284 | if not isinstance(image, torch.Tensor):
285 | raise TypeError(f"The image must be a tensor, got {type(image)}")
286 | elif image.dtype != torch.uint8:
287 | raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
288 | elif image.dim() != 3:
289 | raise ValueError("Pass individual images, not batches")
290 | elif image.size()[0] != 3:
291 | raise ValueError("Pass an RGB image. Other Image formats are not supported")
292 | if masks.ndim == 2:
293 | masks = masks[None, :, :]
294 | if masks.ndim != 3:
295 | raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
296 | if masks.dtype != torch.bool:
297 | raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
298 | if masks.shape[-2:] != image.shape[-2:]:
299 | raise ValueError("The image and the masks must have the same height and width")
300 |
301 | num_masks = masks.size()[0]
302 | if colors is not None and num_masks > len(colors):
303 | raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
304 |
305 | if colors is None:
306 | colors = _generate_color_palette(num_masks)
307 |
308 | if not isinstance(colors, list):
309 | colors = [colors]
310 | if not isinstance(colors[0], (tuple, str)):
311 | raise ValueError("colors must be a tuple or a string, or a list thereof")
312 | if isinstance(colors[0], tuple) and len(colors[0]) != 3:
313 | raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
314 |
315 | out_dtype = torch.uint8
316 |
317 | colors_ = []
318 | for color in colors:
319 | if isinstance(color, str):
320 | color = ImageColor.getrgb(color)
321 | colors_.append(torch.tensor(color, dtype=out_dtype))
322 |
323 | img_to_draw = image.detach().clone()
324 | # TODO: There might be a way to vectorize this
325 | for mask, color in zip(masks, colors_):
326 | img_to_draw[:, mask] = color[:, None]
327 |
328 | out = image * (1 - alpha) + img_to_draw * alpha
329 | return out.to(out_dtype)
330 |
331 |
332 | @torch.no_grad()
333 | def draw_keypoints(
334 | image: torch.Tensor,
335 | keypoints: torch.Tensor,
336 | connectivity: Optional[List[Tuple[int, int]]] = None,
337 | colors: Optional[Union[str, Tuple[int, int, int]]] = None,
338 | radius: int = 2,
339 | width: int = 3,
340 | ) -> torch.Tensor:
341 |
342 | """
343 | Draws Keypoints on given RGB image.
344 | The values of the input image should be uint8 between 0 and 255.
345 |
346 | Args:
347 | image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
348 |         keypoints (Tensor): Tensor of shape (num_instances, K, 2), giving the K keypoint
349 |             locations for each of the num_instances instances, in the format [x, y].
350 |         connectivity (List[Tuple[int, int]]): A list of tuples, where each tuple
351 |             contains a pair of keypoint indices to be connected.
352 | colors (str, Tuple): The color can be represented as
353 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
354 | radius (int): Integer denoting radius of keypoint.
355 | width (int): Integer denoting width of line connecting keypoints.
356 |
357 | Returns:
358 | img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
359 | """
360 |
361 | if not torch.jit.is_scripting() and not torch.jit.is_tracing():
362 | _log_api_usage_once(draw_keypoints)
363 | if not isinstance(image, torch.Tensor):
364 | raise TypeError(f"The image must be a tensor, got {type(image)}")
365 | elif image.dtype != torch.uint8:
366 | raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
367 | elif image.dim() != 3:
368 | raise ValueError("Pass individual images, not batches")
369 | elif image.size()[0] != 3:
370 | raise ValueError("Pass an RGB image. Other Image formats are not supported")
371 |
372 | if keypoints.ndim != 3:
373 | raise ValueError("keypoints must be of shape (num_instances, K, 2)")
374 |
375 | ndarr = image.permute(1, 2, 0).cpu().numpy()
376 | img_to_draw = Image.fromarray(ndarr)
377 | draw = ImageDraw.Draw(img_to_draw)
378 | img_kpts = keypoints.to(torch.int64).tolist()
379 |
380 | for kpt_id, kpt_inst in enumerate(img_kpts):
381 | for inst_id, kpt in enumerate(kpt_inst):
382 | x1 = kpt[0] - radius
383 | x2 = kpt[0] + radius
384 | y1 = kpt[1] - radius
385 | y2 = kpt[1] + radius
386 | draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
387 |
388 | if connectivity:
389 | for connection in connectivity:
390 | start_pt_x = kpt_inst[connection[0]][0]
391 | start_pt_y = kpt_inst[connection[0]][1]
392 |
393 | end_pt_x = kpt_inst[connection[1]][0]
394 | end_pt_y = kpt_inst[connection[1]][1]
395 |
396 | draw.line(
397 | ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
398 | width=width,
399 | )
400 |
401 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
402 |
403 |
404 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
405 | @torch.no_grad()
406 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
407 |
408 | """
409 | Converts a flow to an RGB image.
410 |
411 | Args:
412 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
413 |
414 | Returns:
415 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds
416 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
417 | """
418 |
419 | if flow.dtype != torch.float:
420 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
421 |
422 | orig_shape = flow.shape
423 | if flow.ndim == 3:
424 | flow = flow[None] # Add batch dim
425 |
426 | if flow.ndim != 4 or flow.shape[1] != 2:
427 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
428 |
429 | max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
430 | epsilon = torch.finfo((flow).dtype).eps
431 | normalized_flow = flow / (max_norm + epsilon)
432 | img = _normalized_flow_to_image(normalized_flow)
433 |
434 | if len(orig_shape) == 3:
435 | img = img[0] # Remove batch dim
436 | return img
437 |
438 |
439 | @torch.no_grad()
440 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
441 |
442 | """
443 | Converts a batch of normalized flow to an RGB image.
444 |
445 | Args:
446 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
447 | Returns:
448 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
449 | """
450 |
451 | N, _, H, W = normalized_flow.shape
452 | device = normalized_flow.device
453 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
454 | colorwheel = _make_colorwheel().to(device) # shape [55x3]
455 | num_cols = colorwheel.shape[0]
456 | norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
457 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
458 | fk = (a + 1) / 2 * (num_cols - 1)
459 | k0 = torch.floor(fk).to(torch.long)
460 | k1 = k0 + 1
461 | k1[k1 == num_cols] = 0
462 | f = fk - k0
463 |
464 | for c in range(colorwheel.shape[1]):
465 | tmp = colorwheel[:, c]
466 | col0 = tmp[k0] / 255.0
467 | col1 = tmp[k1] / 255.0
468 | col = (1 - f) * col0 + f * col1
469 | col = 1 - norm * (1 - col)
470 | flow_image[:, c, :, :] = torch.floor(255 * col)
471 | return flow_image
472 |
473 |
474 | def _make_colorwheel() -> torch.Tensor:
475 | """
476 | Generates a color wheel for optical flow visualization as presented in:
477 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
478 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
479 |
480 | Returns:
481 | colorwheel (Tensor[55, 3]): Colorwheel Tensor.
482 | """
483 |
484 | RY = 15
485 | YG = 6
486 | GC = 4
487 | CB = 11
488 | BM = 13
489 | MR = 6
490 |
491 | ncols = RY + YG + GC + CB + BM + MR
492 | colorwheel = torch.zeros((ncols, 3))
493 | col = 0
494 |
495 | # RY
496 | colorwheel[0:RY, 0] = 255
497 | colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
498 | col = col + RY
499 | # YG
500 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
501 | colorwheel[col : col + YG, 1] = 255
502 | col = col + YG
503 | # GC
504 | colorwheel[col : col + GC, 1] = 255
505 | colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
506 | col = col + GC
507 | # CB
508 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
509 | colorwheel[col : col + CB, 2] = 255
510 | col = col + CB
511 | # BM
512 | colorwheel[col : col + BM, 2] = 255
513 | colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
514 | col = col + BM
515 | # MR
516 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
517 | colorwheel[col : col + MR, 0] = 255
518 | return colorwheel
519 |
520 |
521 | def _generate_color_palette(num_objects: int):
522 | palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
523 | return [tuple((i * palette) % 255) for i in range(num_objects)]
524 |
525 |
526 | def _log_api_usage_once(obj: Any) -> None:
527 |
528 | """
529 |     Logs API usage (module and name) within an organization.
530 | In a large ecosystem, it's often useful to track the PyTorch and
531 | TorchVision APIs usage. This API provides the similar functionality to the
532 | logging module in the Python stdlib. It can be used for debugging purpose
533 | to log which methods are used and by default it is inactive, unless the user
534 |     manually subscribes a logger via the ``SetAPIUsageLogger`` method.
535 | Please note it is triggered only once for the same API call within a process.
536 | It does not collect any data from open-source users since it is no-op by default.
537 | For more information, please refer to
538 | * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
539 | * Logging policy: https://github.com/pytorch/vision/issues/5052;
540 |
541 | Args:
542 | obj (class instance or method): an object to extract info from.
543 | """
544 | if not obj.__module__.startswith("torchvision"):
545 | return
546 | name = obj.__class__.__name__
547 | if isinstance(obj, FunctionType):
548 | name = obj.__name__
549 | torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
550 |
--------------------------------------------------------------------------------
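The drawing and flow-visualization helpers above mirror the utilities shipped in `torchvision.utils`. A minimal usage sketch follows; it imports from `torchvision.utils` so it runs standalone, and assumes a recent torchvision (>= 0.13, where `flow_to_image` is available). All tensors below are dummies.

```python
# Minimal sketch: exercise the drawing/flow helpers on dummy tensors.
# Assumes torchvision >= 0.13, where these utilities are available.
import torch
from torchvision.utils import draw_segmentation_masks, draw_keypoints, flow_to_image

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)           # uint8 RGB image
mask = torch.zeros(1, 64, 64, dtype=torch.bool)
mask[0, 16:48, 16:48] = True                                          # one square mask
overlay = draw_segmentation_masks(img, mask, alpha=0.6, colors="red") # (3, 64, 64) uint8

kpts = torch.tensor([[[10.0, 10.0], [50.0, 50.0]]])                   # (num_instances=1, K=2, 2)
with_kpts = draw_keypoints(img, kpts, connectivity=[(0, 1)], colors="blue", radius=3)

flow = torch.randn(2, 64, 64)                                         # (2, H, W) float flow field
flow_rgb = flow_to_image(flow)                                        # (3, 64, 64) uint8 visualization
print(overlay.shape, with_kpts.shape, flow_rgb.shape)
```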
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/networks/__init__.py
--------------------------------------------------------------------------------
/networks/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | from torch.optim import lr_scheduler
6 |
7 |
8 | class BaseModel(nn.Module):
9 | def __init__(self, opt):
10 | super(BaseModel, self).__init__()
11 | self.opt = opt
12 | self.total_steps = 0
13 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
14 | self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
15 |
16 | def save_networks(self, save_filename):
17 | save_path = os.path.join(self.save_dir, save_filename)
18 |
19 | # serialize model and optimizer to dict
20 | state_dict = {
21 | 'model': self.model.state_dict(),
22 | 'optimizer' : self.optimizer.state_dict(),
23 | 'total_steps' : self.total_steps,
24 | }
25 |
26 | torch.save(state_dict, save_path)
27 |
28 |
29 | def eval(self):
30 | self.model.eval()
31 |
32 | def test(self):
33 | with torch.no_grad():
34 | self.forward()
35 |
36 |
37 | def init_weights(net, init_type='normal', gain=0.02):
38 | def init_func(m):
39 | classname = m.__class__.__name__
40 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
41 | if init_type == 'normal':
42 | init.normal_(m.weight.data, 0.0, gain)
43 | elif init_type == 'xavier':
44 | init.xavier_normal_(m.weight.data, gain=gain)
45 | elif init_type == 'kaiming':
46 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
47 | elif init_type == 'orthogonal':
48 | init.orthogonal_(m.weight.data, gain=gain)
49 | else:
50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51 | if hasattr(m, 'bias') and m.bias is not None:
52 | init.constant_(m.bias.data, 0.0)
53 | elif classname.find('BatchNorm2d') != -1:
54 | init.normal_(m.weight.data, 1.0, gain)
55 | init.constant_(m.bias.data, 0.0)
56 |
57 | print('initialize network with %s' % init_type)
58 | net.apply(init_func)
59 |
--------------------------------------------------------------------------------
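`init_weights` in `networks/base_model.py` applies the chosen initializer to every Conv/Linear layer and resets BatchNorm affine parameters. A minimal sketch of calling it, assuming the repo root is on `PYTHONPATH`; the toy `nn.Sequential` is a hypothetical stand-in, not a model from this repo.

```python
# Minimal sketch: apply init_weights from networks/base_model.py to a toy network.
# The nn.Sequential below is a hypothetical stand-in, not a model from this repo.
import torch.nn as nn
from networks.base_model import init_weights

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # matched by the 'Conv' branch
    nn.BatchNorm2d(16),                          # matched by the 'BatchNorm2d' branch
    nn.ReLU(inplace=True),
    nn.Flatten(),
    nn.Linear(16 * 8 * 8, 1),                    # matched by the 'Linear' branch
)
init_weights(net, init_type='xavier', gain=0.02)  # prints "initialize network with xavier"
```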
/networks/lpf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, Adobe Inc. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
4 | # 4.0 International Public License. To view a copy of this license, visit
5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
6 |
7 | import torch
8 | import torch.nn.parallel
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from IPython import embed
13 |
14 | class Downsample(nn.Module):
15 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
16 | super(Downsample, self).__init__()
17 | self.filt_size = filt_size
18 | self.pad_off = pad_off
19 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
20 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
21 | self.stride = stride
22 | self.off = int((self.stride-1)/2.)
23 | self.channels = channels
24 |
25 | # print('Filter size [%i]'%filt_size)
26 | if(self.filt_size==1):
27 | a = np.array([1.,])
28 | elif(self.filt_size==2):
29 | a = np.array([1., 1.])
30 | elif(self.filt_size==3):
31 | a = np.array([1., 2., 1.])
32 | elif(self.filt_size==4):
33 | a = np.array([1., 3., 3., 1.])
34 | elif(self.filt_size==5):
35 | a = np.array([1., 4., 6., 4., 1.])
36 | elif(self.filt_size==6):
37 | a = np.array([1., 5., 10., 10., 5., 1.])
38 | elif(self.filt_size==7):
39 | a = np.array([1., 6., 15., 20., 15., 6., 1.])
40 |
41 | filt = torch.Tensor(a[:,None]*a[None,:])
42 | filt = filt/torch.sum(filt)
43 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
44 |
45 | self.pad = get_pad_layer(pad_type)(self.pad_sizes)
46 |
47 | def forward(self, inp):
48 | if(self.filt_size==1):
49 | if(self.pad_off==0):
50 | return inp[:,:,::self.stride,::self.stride]
51 | else:
52 | return self.pad(inp)[:,:,::self.stride,::self.stride]
53 | else:
54 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
55 |
56 | def get_pad_layer(pad_type):
57 | if(pad_type in ['refl','reflect']):
58 | PadLayer = nn.ReflectionPad2d
59 | elif(pad_type in ['repl','replicate']):
60 | PadLayer = nn.ReplicationPad2d
61 | elif(pad_type=='zero'):
62 | PadLayer = nn.ZeroPad2d
63 | else:
64 | print('Pad type [%s] not recognized'%pad_type)
65 | return PadLayer
66 |
67 |
68 | class Downsample1D(nn.Module):
69 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
70 | super(Downsample1D, self).__init__()
71 | self.filt_size = filt_size
72 | self.pad_off = pad_off
73 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
74 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
75 | self.stride = stride
76 | self.off = int((self.stride - 1) / 2.)
77 | self.channels = channels
78 |
79 | # print('Filter size [%i]' % filt_size)
80 | if(self.filt_size == 1):
81 | a = np.array([1., ])
82 | elif(self.filt_size == 2):
83 | a = np.array([1., 1.])
84 | elif(self.filt_size == 3):
85 | a = np.array([1., 2., 1.])
86 | elif(self.filt_size == 4):
87 | a = np.array([1., 3., 3., 1.])
88 | elif(self.filt_size == 5):
89 | a = np.array([1., 4., 6., 4., 1.])
90 | elif(self.filt_size == 6):
91 | a = np.array([1., 5., 10., 10., 5., 1.])
92 | elif(self.filt_size == 7):
93 | a = np.array([1., 6., 15., 20., 15., 6., 1.])
94 |
95 | filt = torch.Tensor(a)
96 | filt = filt / torch.sum(filt)
97 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
98 |
99 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
100 |
101 | def forward(self, inp):
102 | if(self.filt_size == 1):
103 | if(self.pad_off == 0):
104 | return inp[:, :, ::self.stride]
105 | else:
106 | return self.pad(inp)[:, :, ::self.stride]
107 | else:
108 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
109 |
110 |
111 | def get_pad_layer_1d(pad_type):
112 | if(pad_type in ['refl', 'reflect']):
113 | PadLayer = nn.ReflectionPad1d
114 | elif(pad_type in ['repl', 'replicate']):
115 | PadLayer = nn.ReplicationPad1d
116 | elif(pad_type == 'zero'):
117 | PadLayer = nn.ZeroPad1d
118 | else:
119 | print('Pad type [%s] not recognized' % pad_type)
120 | return PadLayer
121 |
--------------------------------------------------------------------------------
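`Downsample` above is a blur-then-subsample ("BlurPool") layer: it pads, low-pass filters with a fixed binomial kernel, and then strides. A small sketch on a dummy feature map, assuming the repo root is importable (and that IPython is installed, since `lpf.py` imports it at the top):

```python
# Minimal sketch: anti-aliased downsampling of a dummy feature map with networks/lpf.py.
import torch
from networks.lpf import Downsample

x = torch.randn(1, 16, 32, 32)                    # (N, C, H, W) feature map
blurpool = Downsample(filt_size=3, stride=2, channels=16)
y = blurpool(x)                                   # low-pass filter, then stride-2 subsample
print(y.shape)                                    # torch.Size([1, 16, 16, 16])
```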
/networks/resnet_lpf.py:
--------------------------------------------------------------------------------
1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
2 | # Copyright (c) 2017 Torch Contributors.
3 | # The Pytorch examples are available under the BSD 3-Clause License.
4 | #
5 | # ==========================================================================================
6 | #
7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
11 | #
12 | # ==========================================================================================
13 | #
14 | # BSD-3 License
15 | #
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are met:
18 | #
19 | # * Redistributions of source code must retain the above copyright notice, this
20 | # list of conditions and the following disclaimer.
21 | #
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 | #
26 | # * Neither the name of the copyright holder nor the names of its
27 | # contributors may be used to endorse or promote products derived from
28 | # this software without specific prior written permission.
29 | #
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 |
40 | import torch.nn as nn
41 | import torch.utils.model_zoo as model_zoo
42 | from .lpf import *
43 |
44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
45 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
46 |
47 |
48 | # model_urls = {
49 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
50 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
51 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
52 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
53 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
54 | # }
55 |
56 |
57 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
58 | """3x3 convolution with padding"""
59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
60 | padding=1, groups=groups, bias=False)
61 |
62 | def conv1x1(in_planes, out_planes, stride=1):
63 | """1x1 convolution"""
64 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
65 |
66 | class BasicBlock(nn.Module):
67 | expansion = 1
68 |
69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
70 | super(BasicBlock, self).__init__()
71 | if norm_layer is None:
72 | norm_layer = nn.BatchNorm2d
73 | if groups != 1:
74 | raise ValueError('BasicBlock only supports groups=1')
75 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
76 | self.conv1 = conv3x3(inplanes, planes)
77 | self.bn1 = norm_layer(planes)
78 | self.relu = nn.ReLU(inplace=True)
79 | if(stride==1):
80 | self.conv2 = conv3x3(planes,planes)
81 | else:
82 | self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
83 | conv3x3(planes, planes),)
84 | self.bn2 = norm_layer(planes)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | identity = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 |
98 | if self.downsample is not None:
99 | identity = self.downsample(x)
100 |
101 | out += identity
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 |
107 | class Bottleneck(nn.Module):
108 | expansion = 4
109 |
110 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
111 | super(Bottleneck, self).__init__()
112 | if norm_layer is None:
113 | norm_layer = nn.BatchNorm2d
114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
115 | self.conv1 = conv1x1(inplanes, planes)
116 | self.bn1 = norm_layer(planes)
117 | self.conv2 = conv3x3(planes, planes, groups) # stride moved
118 | self.bn2 = norm_layer(planes)
119 | if(stride==1):
120 | self.conv3 = conv1x1(planes, planes * self.expansion)
121 | else:
122 | self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
123 | conv1x1(planes, planes * self.expansion))
124 | self.bn3 = norm_layer(planes * self.expansion)
125 | self.relu = nn.ReLU(inplace=True)
126 | self.downsample = downsample
127 | self.stride = stride
128 |
129 | def forward(self, x):
130 | identity = x
131 |
132 | out = self.conv1(x)
133 | out = self.bn1(out)
134 | out = self.relu(out)
135 |
136 | out = self.conv2(out)
137 | out = self.bn2(out)
138 | out = self.relu(out)
139 |
140 | out = self.conv3(out)
141 | out = self.bn3(out)
142 |
143 | if self.downsample is not None:
144 | identity = self.downsample(x)
145 |
146 | out += identity
147 | out = self.relu(out)
148 |
149 | return out
150 |
151 |
152 | class ResNet(nn.Module):
153 |
154 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
155 | groups=1, width_per_group=64, norm_layer=None, filter_size=1, pool_only=True):
156 | super(ResNet, self).__init__()
157 | if norm_layer is None:
158 | norm_layer = nn.BatchNorm2d
159 | planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
160 | self.inplanes = planes[0]
161 |
162 | if(pool_only):
163 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False)
164 | else:
165 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False)
166 | self.bn1 = norm_layer(planes[0])
167 | self.relu = nn.ReLU(inplace=True)
168 |
169 | if(pool_only):
170 | self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1),
171 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])])
172 | else:
173 | self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]),
174 | nn.MaxPool2d(kernel_size=2, stride=1),
175 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])])
176 |
177 | self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer)
178 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
179 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
180 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
182 | self.fc = nn.Linear(planes[3] * block.expansion, num_classes)
183 |
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv2d):
186 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None):
187 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics
188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
189 | else:
190 | print('Not initializing')
191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192 | nn.init.constant_(m.weight, 1)
193 | nn.init.constant_(m.bias, 0)
194 |
195 | # Zero-initialize the last BN in each residual branch,
196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
198 | if zero_init_residual:
199 | for m in self.modules():
200 | if isinstance(m, Bottleneck):
201 | nn.init.constant_(m.bn3.weight, 0)
202 | elif isinstance(m, BasicBlock):
203 | nn.init.constant_(m.bn2.weight, 0)
204 |
205 | def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1):
206 | if norm_layer is None:
207 | norm_layer = nn.BatchNorm2d
208 | downsample = None
209 | if stride != 1 or self.inplanes != planes * block.expansion:
210 | # downsample = nn.Sequential(
211 | # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size),
212 | # norm_layer(planes * block.expansion),
213 | # )
214 |
215 | downsample = [Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes),] if(stride !=1) else []
216 | downsample += [conv1x1(self.inplanes, planes * block.expansion, 1),
217 | norm_layer(planes * block.expansion)]
218 | # print(downsample)
219 | downsample = nn.Sequential(*downsample)
220 |
221 | layers = []
222 | layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer, filter_size=filter_size))
223 | self.inplanes = planes * block.expansion
224 | for _ in range(1, blocks):
225 | layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer, filter_size=filter_size))
226 |
227 | return nn.Sequential(*layers)
228 |
229 | def forward(self, x):
230 | x = self.conv1(x)
231 | x = self.bn1(x)
232 | x = self.relu(x)
233 | x = self.maxpool(x)
234 |
235 | x = self.layer1(x)
236 | x = self.layer2(x)
237 | x = self.layer3(x)
238 | x = self.layer4(x)
239 |
240 | x = self.avgpool(x)
241 | x = x.view(x.size(0), -1)
242 | x = self.fc(x)
243 |
244 | return x
245 |
246 |
247 | def resnet18(pretrained=False, filter_size=1, pool_only=True, **kwargs):
248 | """Constructs a ResNet-18 model.
249 | Args:
250 | pretrained (bool): If True, returns a model pre-trained on ImageNet
251 | """
252 | model = ResNet(BasicBlock, [2, 2, 2, 2], filter_size=filter_size, pool_only=pool_only, **kwargs)
253 | if pretrained:
254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
255 | return model
256 |
257 |
258 | def resnet34(pretrained=False, filter_size=1, pool_only=True, **kwargs):
259 | """Constructs a ResNet-34 model.
260 | Args:
261 | pretrained (bool): If True, returns a model pre-trained on ImageNet
262 | """
263 | model = ResNet(BasicBlock, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
264 | if pretrained:
265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
266 | return model
267 |
268 |
269 | def resnet50(pretrained=False, filter_size=1, pool_only=True, **kwargs):
270 | """Constructs a ResNet-50 model.
271 | Args:
272 | pretrained (bool): If True, returns a model pre-trained on ImageNet
273 | """
274 | model = ResNet(Bottleneck, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
275 | if pretrained:
276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
277 | return model
278 |
279 |
280 | def resnet101(pretrained=False, filter_size=1, pool_only=True, **kwargs):
281 | """Constructs a ResNet-101 model.
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on ImageNet
284 | """
285 | model = ResNet(Bottleneck, [3, 4, 23, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
286 | if pretrained:
287 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
288 | return model
289 |
290 |
291 | def resnet152(pretrained=False, filter_size=1, pool_only=True, **kwargs):
292 | """Constructs a ResNet-152 model.
293 | Args:
294 | pretrained (bool): If True, returns a model pre-trained on ImageNet
295 | """
296 | model = ResNet(Bottleneck, [3, 8, 36, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
297 | if pretrained:
298 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
299 | return model
300 |
301 |
302 | def resnext50_32x4d(pretrained=False, filter_size=1, pool_only=True, **kwargs):
303 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs)
304 | # if pretrained:
305 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
306 | return model
307 |
308 |
309 | def resnext101_32x8d(pretrained=False, filter_size=1, pool_only=True, **kwargs):
310 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs)
311 | # if pretrained:
312 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
313 | return model
314 |
--------------------------------------------------------------------------------
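A quick sketch that builds the anti-aliased ResNet-50 defined above and runs a forward pass on a dummy batch. Note that `pretrained=True` would currently raise a `NameError`, because the `model_urls` dict is commented out in this file, so the sketch uses random weights.

```python
# Minimal sketch: instantiate the anti-aliased ResNet-50 from networks/resnet_lpf.py
# (random weights; pretrained=True would fail since model_urls is commented out).
import torch
from networks.resnet_lpf import resnet50

model = resnet50(pretrained=False, filter_size=3)  # 3x3 blur kernel before every subsampling
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([2, 1000]) with the default num_classes
```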
/networks/trainer.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | from networks.base_model import BaseModel, init_weights
5 | import sys
6 | from models import get_model
7 |
8 | class Trainer(BaseModel):
9 | def name(self):
10 | return 'Trainer'
11 |
12 | def __init__(self, opt):
13 | super(Trainer, self).__init__(opt)
14 | self.opt = opt
15 | self.model = get_model(opt.arch)
16 | torch.nn.init.normal_(self.model.fc.weight.data, 0.0, opt.init_gain)
17 |
18 | if opt.fix_backbone:
19 | params = []
20 | for name, p in self.model.named_parameters():
21 | if name=="fc.weight" or name=="fc.bias":
22 | params.append(p)
23 | else:
24 | p.requires_grad = False
25 | else:
26 |             print("Your backbone is not fixed. Are you sure you want to proceed? If this is a mistake, pass the --fix_backbone flag during training and rerun.")
27 | import time
28 | time.sleep(3)
29 | params = self.model.parameters()
30 |
31 |
32 |
33 | if opt.optim == 'adam':
34 | self.optimizer = torch.optim.AdamW(params, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
35 | elif opt.optim == 'sgd':
36 | self.optimizer = torch.optim.SGD(params, lr=opt.lr, momentum=0.0, weight_decay=opt.weight_decay)
37 | else:
38 | raise ValueError("optim should be [adam, sgd]")
39 |
40 | self.loss_fn = nn.BCEWithLogitsLoss()
41 |
42 | self.model.to(opt.gpu_ids[0])
43 |
44 |
45 | def adjust_learning_rate(self, min_lr=1e-6):
46 | for param_group in self.optimizer.param_groups:
47 | param_group['lr'] /= 10.
48 | if param_group['lr'] < min_lr:
49 | return False
50 | return True
51 |
52 |
53 | def set_input(self, input):
54 | self.input = input[0].to(self.device)
55 | self.label = input[1].to(self.device).float()
56 |
57 |
58 | def forward(self):
59 | self.output = self.model(self.input)
60 | self.output = self.output.view(-1).unsqueeze(1)
61 |
62 |
63 | def get_loss(self):
64 | return self.loss_fn(self.output.squeeze(1), self.label)
65 |
66 | def optimize_parameters(self):
67 | self.forward()
68 | self.loss = self.loss_fn(self.output.squeeze(1), self.label)
69 | self.optimizer.zero_grad()
70 | self.loss.backward()
71 | self.optimizer.step()
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
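With `--fix_backbone`, `Trainer` is effectively doing linear probing: only `fc.weight` and `fc.bias` receive gradients while the rest of the backbone is frozen. Below is a self-contained sketch of that setup, using a plain torchvision ResNet-50 as a stand-in for `get_model(opt.arch)` and mirroring the option defaults from `options/` (lr=1e-4, beta1=0.9, weight_decay=0.0, init_gain=0.02).

```python
# Minimal sketch of the fix_backbone (linear-probing) setup in networks/trainer.py.
# torchvision's ResNet-50 is a stand-in for get_model(opt.arch), not the repo's model.
import torch
import torch.nn as nn
from torchvision.models import resnet50

model = resnet50()                                   # randomly initialized stand-in backbone
model.fc = nn.Linear(model.fc.in_features, 1)        # single real(0)/fake(1) logit
nn.init.normal_(model.fc.weight, 0.0, 0.02)          # mirrors opt.init_gain

params = []
for name, p in model.named_parameters():
    if name in ("fc.weight", "fc.bias"):
        params.append(p)                             # only the linear head is trained
    else:
        p.requires_grad = False                      # freeze the backbone

optimizer = torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999), weight_decay=0.0)
loss_fn = nn.BCEWithLogitsLoss()

# one optimization step on a dummy batch, mirroring set_input / forward / optimize_parameters
x = torch.randn(4, 3, 224, 224)
y = torch.tensor([0., 1., 0., 1.])
loss = loss_fn(model(x).view(-1), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(float(loss))
```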
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import util
4 | import torch
5 |
6 |
7 | class BaseOptions():
8 | def __init__(self):
9 | self.initialized = False
10 |
11 | def initialize(self, parser):
12 | parser.add_argument('--mode', default='binary')
13 |         parser.add_argument('--arch', type=str, default='res50', help='see models/__init__.py')
14 | parser.add_argument('--fix_backbone', action='store_true')
15 |
16 | # data augmentation
17 | parser.add_argument('--rz_interp', default='bilinear')
18 | parser.add_argument('--blur_prob', type=float, default=0.5)
19 | parser.add_argument('--blur_sig', default='0.0,3.0')
20 | parser.add_argument('--jpg_prob', type=float, default=0.5)
21 | parser.add_argument('--jpg_method', default='cv2,pil')
22 | parser.add_argument('--jpg_qual', default='30,100')
23 |
24 |
25 | parser.add_argument('--real_list_path', default=None, help='only used if data_mode==ours: path for the list of real images, which should contain train.pickle and val.pickle')
26 | parser.add_argument('--fake_list_path', default=None, help='only used if data_mode==ours: path for the list of fake images, which should contain train.pickle and val.pickle')
27 |         parser.add_argument('--wang2020_data_path', default=None, help='only used if data_mode==wang2020; it should contain train and test folders')
28 | parser.add_argument('--data_mode', default='ours', help='wang2020 or ours')
29 | parser.add_argument('--data_label', default='train', help='label to decide whether train or validation dataset')
30 | parser.add_argument('--weight_decay', type=float, default=0.0, help='loss weight for l2 reg')
31 |
32 | parser.add_argument('--class_bal', action='store_true') # what is this ?
33 | parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
34 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
35 | parser.add_argument('--cropSize', type=int, default=224, help='then crop to this size')
36 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
37 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
38 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
39 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
40 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
41 | parser.add_argument('--resize_or_crop', type=str, default='scale_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]')
42 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
43 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
44 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
45 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')
46 | self.initialized = True
47 | return parser
48 |
49 | def gather_options(self):
50 | # initialize parser with basic options
51 | if not self.initialized:
52 | parser = argparse.ArgumentParser(
53 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
54 | parser = self.initialize(parser)
55 |
56 | # get the basic options
57 | opt, _ = parser.parse_known_args()
58 | self.parser = parser
59 |
60 | return parser.parse_args()
61 |
62 | def print_options(self, opt):
63 | message = ''
64 | message += '----------------- Options ---------------\n'
65 | for k, v in sorted(vars(opt).items()):
66 | comment = ''
67 | default = self.parser.get_default(k)
68 | if v != default:
69 | comment = '\t[default: %s]' % str(default)
70 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
71 | message += '----------------- End -------------------'
72 | print(message)
73 |
74 | # save to the disk
75 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
76 | util.mkdirs(expr_dir)
77 | file_name = os.path.join(expr_dir, 'opt.txt')
78 | with open(file_name, 'wt') as opt_file:
79 | opt_file.write(message)
80 | opt_file.write('\n')
81 |
82 | def parse(self, print_options=True):
83 |
84 | opt = self.gather_options()
85 | opt.isTrain = self.isTrain # train or test
86 |
87 | # process opt.suffix
88 | if opt.suffix:
89 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
90 | opt.name = opt.name + suffix
91 |
92 | if print_options:
93 | self.print_options(opt)
94 |
95 | # set gpu ids
96 | str_ids = opt.gpu_ids.split(',')
97 | opt.gpu_ids = []
98 | for str_id in str_ids:
99 | id = int(str_id)
100 | if id >= 0:
101 | opt.gpu_ids.append(id)
102 | if len(opt.gpu_ids) > 0:
103 | torch.cuda.set_device(opt.gpu_ids[0])
104 |
105 | # additional
106 | #opt.classes = opt.classes.split(',')
107 | opt.rz_interp = opt.rz_interp.split(',')
108 | opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')]
109 | opt.jpg_method = opt.jpg_method.split(',')
110 | opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')]
111 | if len(opt.jpg_qual) == 2:
112 | opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1))
113 | elif len(opt.jpg_qual) > 2:
114 | raise ValueError("Shouldn't have more than 2 values for --jpg_qual.")
115 |
116 | self.opt = opt
117 | return self.opt
118 |
--------------------------------------------------------------------------------
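`BaseOptions.parse()` turns the comma-separated augmentation strings into lists before training. A standalone sketch of just that post-processing, using the default option values from above (nothing is written to the checkpoints directory here):

```python
# Minimal sketch of the string-to-list post-processing in BaseOptions.parse(),
# shown standalone with the default option values from above.
blur_sig = "0.0,3.0"
jpg_qual = "30,100"
gpu_ids = "0"

blur_sig = [float(s) for s in blur_sig.split(',')]             # [0.0, 3.0]
jpg_qual = [int(s) for s in jpg_qual.split(',')]
if len(jpg_qual) == 2:
    jpg_qual = list(range(jpg_qual[0], jpg_qual[1] + 1))       # 30, 31, ..., 100
gpu_ids = [int(i) for i in gpu_ids.split(',') if int(i) >= 0]  # [0]

print(blur_sig, jpg_qual[:3], gpu_ids)
```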
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | parser.add_argument('--model_path')
8 | parser.add_argument('--no_resize', action='store_true')
9 | parser.add_argument('--no_crop', action='store_true')
10 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
11 |
12 | self.isTrain = False
13 | return parser
14 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | parser.add_argument('--earlystop_epoch', type=int, default=5)
8 | parser.add_argument('--data_aug', action='store_true', help='if specified, perform additional data augmentation (photometric, blurring, jpegging)')
9 | parser.add_argument('--optim', type=str, default='adam', help='optim to use [sgd, adam]')
10 | parser.add_argument('--new_optim', action='store_true', help='new optimizer instead of loading the optim state')
11 | parser.add_argument('--loss_freq', type=int, default=400, help='frequency of showing loss on tensorboard')
12 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
13 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
14 |         parser.add_argument('--last_epoch', type=int, default=-1, help='starting epoch count for scheduler initialization')
15 | parser.add_argument('--train_split', type=str, default='train', help='train, val, test, etc')
16 | parser.add_argument('--val_split', type=str, default='val', help='train, val, test, etc')
17 | parser.add_argument('--niter', type=int, default=100, help='total epoches')
18 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
19 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
20 |
21 | self.isTrain = True
22 | return parser
23 |
--------------------------------------------------------------------------------
/pretrained_weights/fc_weights.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/pretrained_weights/fc_weights.pth
--------------------------------------------------------------------------------
/resources/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisconsinAIVision/UniversalFakeDetect/76a0e3e60a8a06458707a625d269ba815a2e5919/resources/teaser.png
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python3 validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14
2 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from tensorboardX import SummaryWriter
4 |
5 | from validate import validate
6 | from data import create_dataloader
7 | from earlystop import EarlyStopping
8 | from networks.trainer import Trainer
9 | from options.train_options import TrainOptions
10 |
11 |
12 | """Currently assumes jpg_prob, blur_prob 0 or 1"""
13 | def get_val_opt():
14 | val_opt = TrainOptions().parse(print_options=False)
15 | val_opt.isTrain = False
16 | val_opt.no_resize = False
17 | val_opt.no_crop = False
18 | val_opt.serial_batches = True
19 | val_opt.data_label = 'val'
20 | val_opt.jpg_method = ['pil']
21 | if len(val_opt.blur_sig) == 2:
22 | b_sig = val_opt.blur_sig
23 | val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
24 | if len(val_opt.jpg_qual) != 1:
25 | j_qual = val_opt.jpg_qual
26 | val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
27 |
28 | return val_opt
29 |
30 |
31 |
32 | if __name__ == '__main__':
33 | opt = TrainOptions().parse()
34 | val_opt = get_val_opt()
35 |
36 | model = Trainer(opt)
37 |
38 | data_loader = create_dataloader(opt)
39 | val_loader = create_dataloader(val_opt)
40 |
41 | train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train"))
42 | val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val"))
43 |
44 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True)
45 | start_time = time.time()
46 | print ("Length of data loader: %d" %(len(data_loader)))
47 | for epoch in range(opt.niter):
48 |
49 | for i, data in enumerate(data_loader):
50 | model.total_steps += 1
51 |
52 | model.set_input(data)
53 | model.optimize_parameters()
54 |
55 | if model.total_steps % opt.loss_freq == 0:
56 | print("Train loss: {} at step: {}".format(model.loss, model.total_steps))
57 | train_writer.add_scalar('loss', model.loss, model.total_steps)
58 | print("Iter time: ", ((time.time()-start_time)/model.total_steps) )
59 |
60 | if model.total_steps in [10,30,50,100,1000,5000,10000] and False: # save models at these iters
61 | model.save_networks('model_iters_%s.pth' % model.total_steps)
62 |
63 | if epoch % opt.save_epoch_freq == 0:
64 | print('saving the model at the end of epoch %d' % (epoch))
65 | model.save_networks( 'model_epoch_best.pth' )
66 | model.save_networks( 'model_epoch_%s.pth' % epoch )
67 |
68 | # Validation
69 | model.eval()
70 | ap, r_acc, f_acc, acc = validate(model.model, val_loader)
71 | val_writer.add_scalar('accuracy', acc, model.total_steps)
72 | val_writer.add_scalar('ap', ap, model.total_steps)
73 | print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
74 |
75 | early_stopping(acc, model)
76 | if early_stopping.early_stop:
77 | cont_train = model.adjust_learning_rate()
78 | if cont_train:
79 | print("Learning rate dropped by 10, continue training...")
80 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.002, verbose=True)
81 | else:
82 | print("Early stopping.")
83 | break
84 | model.train()
85 |
86 |
--------------------------------------------------------------------------------
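When validation accuracy plateaus, `train.py` asks `Trainer.adjust_learning_rate()` to divide the learning rate by 10 and stops training once it would fall below `min_lr`. A toy, self-contained sketch of that schedule (the optimizer and loop below are stand-ins, not the repo's training loop):

```python
# Toy sketch of the plateau-driven LR schedule used in train.py / Trainer.adjust_learning_rate.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

def adjust_learning_rate(optimizer, min_lr=1e-6):
    for group in optimizer.param_groups:
        group['lr'] /= 10.
        if group['lr'] < min_lr:
            return False      # signals train.py to stop ("Early stopping.")
    return True

while adjust_learning_rate(optimizer):
    print("lr dropped to", optimizer.param_groups[0]['lr'])
# drops to ~1e-5, then ~1e-6; the next drop would fall below min_lr, so training stops
```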
/validate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from ast import arg
3 | import os
4 | import csv
5 | import torch
6 | import torchvision.transforms as transforms
7 | import torch.utils.data
8 | import numpy as np
9 | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
10 | from torch.utils.data import Dataset
11 | import sys
12 | from models import get_model
13 | from PIL import Image
14 | import pickle
15 | from tqdm import tqdm
16 | from io import BytesIO
17 | from copy import deepcopy
18 | from dataset_paths import DATASET_PATHS
19 | import random
20 | import shutil
21 | from scipy.ndimage.filters import gaussian_filter
22 |
23 | SEED = 0
24 | def set_seed():
25 | torch.manual_seed(SEED)
26 | torch.cuda.manual_seed(SEED)
27 | np.random.seed(SEED)
28 | random.seed(SEED)
29 |
30 |
31 | MEAN = {
32 | "imagenet":[0.485, 0.456, 0.406],
33 | "clip":[0.48145466, 0.4578275, 0.40821073]
34 | }
35 |
36 | STD = {
37 | "imagenet":[0.229, 0.224, 0.225],
38 | "clip":[0.26862954, 0.26130258, 0.27577711]
39 | }
40 |
41 |
42 |
43 |
44 |
45 | def find_best_threshold(y_true, y_pred):
46 | "We assume first half is real 0, and the second half is fake 1"
47 |
48 | N = y_true.shape[0]
49 |
50 | if y_pred[0:N//2].max() <= y_pred[N//2:N].min(): # perfectly separable case
51 | return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2
52 |
53 | best_acc = 0
54 | best_thres = 0
55 | for thres in y_pred:
56 | temp = deepcopy(y_pred)
57 | temp[temp>=thres] = 1
58 |         temp[temp<thres] = 0
59 |
60 |         acc = (temp == y_true).sum() / N
61 |         if acc >= best_acc:
62 | best_thres = thres
63 | best_acc = acc
64 |
65 | return best_thres
66 |
67 |
68 |
69 | def png2jpg(img, quality):
70 | out = BytesIO()
71 | img.save(out, format='jpeg', quality=quality) # ranging from 0-95, 75 is default
72 | img = Image.open(out)
73 |     # load from memory before the BytesIO buffer closes
74 | img = np.array(img)
75 | out.close()
76 | return Image.fromarray(img)
77 |
78 |
79 | def gaussian_blur(img, sigma):
80 | img = np.array(img)
81 |
82 | gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma)
83 | gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma)
84 | gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma)
85 |
86 | return Image.fromarray(img)
87 |
88 |
89 |
90 | def calculate_acc(y_true, y_pred, thres):
91 | r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > thres)
92 | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > thres)
93 | acc = accuracy_score(y_true, y_pred > thres)
94 | return r_acc, f_acc, acc
95 |
96 |
97 | def validate(model, loader, find_thres=False):
98 |
99 | with torch.no_grad():
100 | y_true, y_pred = [], []
101 | print ("Length of dataset: %d" %(len(loader)))
102 | for img, label in loader:
103 | in_tens = img.cuda()
104 |
105 | y_pred.extend(model(in_tens).sigmoid().flatten().tolist())
106 | y_true.extend(label.flatten().tolist())
107 |
108 | y_true, y_pred = np.array(y_true), np.array(y_pred)
109 |
110 | # ================== save this if you want to plot the curves =========== #
111 | # torch.save( torch.stack( [torch.tensor(y_true), torch.tensor(y_pred)] ), 'baseline_predication_for_pr_roc_curve.pth' )
112 | # exit()
113 | # =================================================================== #
114 |
115 | # Get AP
116 | ap = average_precision_score(y_true, y_pred)
117 |
118 | # Acc based on 0.5
119 | r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5)
120 | if not find_thres:
121 | return ap, r_acc0, f_acc0, acc0
122 |
123 |
124 | # Acc based on the best thres
125 | best_thres = find_best_threshold(y_true, y_pred)
126 | r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres)
127 |
128 | return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres
129 |
130 |
131 |
132 |
133 |
134 |
135 | # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
136 |
137 |
138 |
139 |
140 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", "bmp"]):
141 | out = []
142 | for r, d, f in os.walk(rootdir):
143 | for file in f:
144 | if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)):
145 | out.append(os.path.join(r, file))
146 | return out
147 |
148 |
149 | def get_list(path, must_contain=''):
150 | if ".pickle" in path:
151 | with open(path, 'rb') as f:
152 | image_list = pickle.load(f)
153 | image_list = [ item for item in image_list if must_contain in item ]
154 | else:
155 | image_list = recursively_read(path, must_contain)
156 | return image_list
157 |
158 |
159 |
160 |
161 |
162 | class RealFakeDataset(Dataset):
163 | def __init__(self, real_path,
164 | fake_path,
165 | data_mode,
166 | max_sample,
167 | arch,
168 | jpeg_quality=None,
169 | gaussian_sigma=None):
170 |
171 | assert data_mode in ["wang2020", "ours"]
172 | self.jpeg_quality = jpeg_quality
173 | self.gaussian_sigma = gaussian_sigma
174 |
175 | # = = = = = = data path = = = = = = = = = #
176 | if type(real_path) == str and type(fake_path) == str:
177 | real_list, fake_list = self.read_path(real_path, fake_path, data_mode, max_sample)
178 | else:
179 | real_list = []
180 | fake_list = []
181 | for real_p, fake_p in zip(real_path, fake_path):
182 | real_l, fake_l = self.read_path(real_p, fake_p, data_mode, max_sample)
183 | real_list += real_l
184 | fake_list += fake_l
185 |
186 | self.total_list = real_list + fake_list
187 |
188 |
189 | # = = = = = = label = = = = = = = = = #
190 |
191 | self.labels_dict = {}
192 | for i in real_list:
193 | self.labels_dict[i] = 0
194 | for i in fake_list:
195 | self.labels_dict[i] = 1
196 |
197 | stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip"
198 | self.transform = transforms.Compose([
199 | transforms.CenterCrop(224),
200 | transforms.ToTensor(),
201 | transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ),
202 | ])
203 |
204 |
205 | def read_path(self, real_path, fake_path, data_mode, max_sample):
206 |
207 | if data_mode == 'wang2020':
208 | real_list = get_list(real_path, must_contain='0_real')
209 | fake_list = get_list(fake_path, must_contain='1_fake')
210 | else:
211 | real_list = get_list(real_path)
212 | fake_list = get_list(fake_path)
213 |
214 |
215 | if max_sample is not None:
216 | if (max_sample > len(real_list)) or (max_sample > len(fake_list)):
217 | max_sample = 100
218 |                 print("not enough images, max_sample falling back to 100")
219 | random.shuffle(real_list)
220 | random.shuffle(fake_list)
221 | real_list = real_list[0:max_sample]
222 | fake_list = fake_list[0:max_sample]
223 |
224 | assert len(real_list) == len(fake_list)
225 |
226 | return real_list, fake_list
227 |
228 |
229 |
230 | def __len__(self):
231 | return len(self.total_list)
232 |
233 | def __getitem__(self, idx):
234 |
235 | img_path = self.total_list[idx]
236 |
237 | label = self.labels_dict[img_path]
238 | img = Image.open(img_path).convert("RGB")
239 |
240 | if self.gaussian_sigma is not None:
241 | img = gaussian_blur(img, self.gaussian_sigma)
242 | if self.jpeg_quality is not None:
243 | img = png2jpg(img, self.jpeg_quality)
244 |
245 | img = self.transform(img)
246 | return img, label
247 |
248 |
249 |
250 |
251 |
252 | if __name__ == '__main__':
253 |
254 |
255 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
256 | parser.add_argument('--real_path', type=str, default=None, help='dir name or a pickle')
257 | parser.add_argument('--fake_path', type=str, default=None, help='dir name or a pickle')
258 | parser.add_argument('--data_mode', type=str, default=None, help='wang2020 or ours')
259 | parser.add_argument('--max_sample', type=int, default=1000, help='only check this number of images for both fake/real')
260 |
261 | parser.add_argument('--arch', type=str, default='res50')
262 | parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth')
263 |
264 | parser.add_argument('--result_folder', type=str, default='result', help='')
265 | parser.add_argument('--batch_size', type=int, default=128)
266 |
267 | parser.add_argument('--jpeg_quality', type=int, default=None, help="100, 90, 80, ... 30. Used to test robustness of our model. Not apply if None")
268 | parser.add_argument('--gaussian_sigma', type=int, default=None, help="0,1,2,3,4. Used to test robustness of our model. Not apply if None")
269 |
270 |
271 | opt = parser.parse_args()
272 |
273 |
274 | if os.path.exists(opt.result_folder):
275 | shutil.rmtree(opt.result_folder)
276 | os.makedirs(opt.result_folder)
277 |
278 | model = get_model(opt.arch)
279 | state_dict = torch.load(opt.ckpt, map_location='cpu')
280 | model.fc.load_state_dict(state_dict)
281 | print ("Model loaded..")
282 | model.eval()
283 | model.cuda()
284 |
285 | if (opt.real_path == None) or (opt.fake_path == None) or (opt.data_mode == None):
286 | dataset_paths = DATASET_PATHS
287 | else:
288 | dataset_paths = [ dict(real_path=opt.real_path, fake_path=opt.fake_path, data_mode=opt.data_mode) ]
289 |
290 |
291 |
292 | for dataset_path in (dataset_paths):
293 | set_seed()
294 |
295 | dataset = RealFakeDataset( dataset_path['real_path'],
296 | dataset_path['fake_path'],
297 | dataset_path['data_mode'],
298 | opt.max_sample,
299 | opt.arch,
300 | jpeg_quality=opt.jpeg_quality,
301 | gaussian_sigma=opt.gaussian_sigma,
302 | )
303 |
304 | loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4)
305 | ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True)
306 |
307 | with open( os.path.join(opt.result_folder,'ap.txt'), 'a') as f:
308 | f.write(dataset_path['key']+': ' + str(round(ap*100, 2))+'\n' )
309 |
310 | with open( os.path.join(opt.result_folder,'acc0.txt'), 'a') as f:
311 | f.write(dataset_path['key']+': ' + str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n' )
312 |
313 |
--------------------------------------------------------------------------------
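A toy sanity check of the thresholding helpers above (`find_best_threshold`, `calculate_acc`) on hand-made scores. This assumes the repo root is on `PYTHONPATH` and that `validate.py`'s module-level imports (torch, the `models` package, `dataset_paths.py`, scikit-learn, scipy) resolve, since they run on import.

```python
# Toy check of the thresholding helpers in validate.py on hand-made scores.
import numpy as np
from validate import find_best_threshold, calculate_acc

# first half real (label 0), second half fake (label 1), as find_best_threshold assumes
y_true = np.array([0, 0, 0, 1, 1, 1])
y_pred = np.array([0.10, 0.20, 0.40, 0.30, 0.70, 0.90])

thres = find_best_threshold(y_true, y_pred)
r_acc, f_acc, acc = calculate_acc(y_true, y_pred, thres)
print(f"best threshold: {thres:.2f}  real acc: {r_acc:.2f}  fake acc: {f_acc:.2f}  overall: {acc:.2f}")
```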