├── .gitignore
├── README.md
├── figures
│   ├── pipeline.png
│   └── results.png
├── networks
│   ├── resnet.py
│   ├── srm_conv.py
│   └── ssp.py
├── options.py
├── requirements.txt
├── test.py
├── test.sh
├── train_val.py
├── train_val.sh
└── utils
    ├── loss.py
    ├── patch.py
    ├── tdataloader.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
__pycache__
*.pth
*.log
*.json
*.npy
snapshot

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SSP-AI-Generated-Image-Detection

This is the official implementation for the following research paper:

> **A Single Simple Patch is All You Need for AI-generated Image Detection** [[arxiv]](https://arxiv.org/pdf/2402.01123.pdf)
>
> Jiaxuan Chen, Jieteng Yao, and Li Niu

Note that in the paper we also proposed Enhanced SSP (ESSP), which improves the robustness of SSP against blur and compression. Currently, we only release the code of SSP (its flowchart is shown in the figure below). The code of ESSP will be released soon.

<div align="center">
<img src="figures/pipeline.png">
</div>
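
## Quick Inference Sketch
The snippet below sketches how the released pieces fit together at inference time: pick the single simplest patch with `utils/patch.py`, then score it with the SRM-plus-ResNet50 classifier in `networks/ssp.py`. This is a minimal illustration, not an official script; the checkpoint path, the example image, and the 0.5 threshold are assumptions, and the preprocessing mirrors `utils/tdataloader.py`.

```python
import torch
from PIL import Image
from torchvision import transforms
from networks.ssp import ssp
from utils.patch import patch_img

model = ssp().cuda().eval()
# hypothetical checkpoint path; use a model produced by train_val.py
model.load_state_dict(torch.load('snapshot/sortnet/Net_epoch_best.pth'))

img = Image.open('example.jpg').convert('RGB')       # any RGB image
patch = patch_img(img, patch_size=32, height=256)    # simplest 32x32 patch
x = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])(patch).unsqueeze(0).cuda()

with torch.no_grad():
    prob = torch.sigmoid(model(x)).item()
# the dataloaders label nature (real) images 1 and ai images 0,
# so a high score means the image looks real
print('real' if prob > 0.5 else 'ai-generated')
```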

## Environment Setup
You can install the required packages by running the command:
```bash
pip install -r requirements.txt
```

## Dataset
The training set and testing set used in the paper can be downloaded from [GenImage](https://github.com/GenImage-Dataset/GenImage). This dataset contains data from eight generators.
After downloading the dataset, you need to specify its root path in the options (`--image_root` in `options.py`). The dataset can be organized as follows:
```bash
GenImage/
├── imagenet_ai_0419_biggan
│   ├── train
│   │   ├── ai
│   │   └── nature
│   └── val
│       ├── ai
│       └── nature
├── imagenet_ai_0419_sdv4
│   ├── train
│   │   ├── ai
│   │   └── nature
│   └── val
│       ├── ai
│       └── nature
├── imagenet_ai_0419_vqdm
│   ...
├── imagenet_ai_0424_sdv5
│   ...
├── imagenet_ai_0424_wukong
│   ...
├── imagenet_ai_0508_adm
│   ...
├── imagenet_glide
│   ...
└── imagenet_midjourney
    ...
```

## Training and Validation
You can simply run the following script to train and validate your model:
```bash
sh train_val.sh
```

## Testing
You can simply run the following script to test your model:
```bash
sh test.sh
```
Our pretrained models on the eight datasets can be downloaded from [Baidu Cloud](https://pan.baidu.com/s/1Wk2Cqeav_wVxPMPNy-zHZQ?pwd=bcmi) (code: bcmi) or [OneDrive](https://1drv.ms/f/s!Aq2pxrmMfMvRh29sp4zHSlbJRlP7?e=C3aHEp).

## Results on GenImage
The figure below reports the results of the ResNet50 baseline and our SSP method with different training and test subsets. In each slot, the left number is the result of ResNet50 and the right number is the result of our SSP.

<div align="center">
<img src="figures/results.png">
</div>
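
## Note on Dataset Selection
Which GenImage subsets are used for training, validation, and testing is controlled by the `--choices` option in `options.py`, with one entry per generator in the order used by `utils/tdataloader.py` (adm, biggan, glide, midjourney, sdv4, sdv5, vqdm, wukong). A value of `1` selects a subset for training, `2` selects it for testing, and subsets marked `0` or `1` are included in validation. Since the option is declared with a list default and no `nargs`, edit the default in place rather than passing it on the command line:

```python
# in options.py: train on SDv1.4 only (the released default)
parser.add_argument('--choices', default=[0, 0, 0, 0, 1, 0, 0, 0])
```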

--------------------------------------------------------------------------------
/figures/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SSP-AI-Generated-Image-Detection/3af21c9a3085364465bd51806d7c8fb956d0726f/figures/pipeline.png

--------------------------------------------------------------------------------
/figures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SSP-AI-Generated-Image-Detection/3af21c9a3085364465bd51806d7c8fb956d0726f/figures/results.png

--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
__all__ = ["ResNet", "resnet18", "resnet34",
           "resnet50", "resnet101", "resnet152"]


model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

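# Note: Bottleneck widens its output by `expansion` (4x): a block built with
# `planes` channels emits `planes * 4` channels. _make_layer below relies on
# this when deciding whether a 1x1 downsample projection is needed.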
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x, *args):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
    return model


if __name__ == '__main__':
    net = resnet50(pretrained=True)
    print(net)

--------------------------------------------------------------------------------
/networks/srm_conv.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class SRMConv2d_simple(nn.Module):

    def __init__(self, inc=3, learnable=False):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)

    def forward(self, x):
        '''
        x: imgs (Batch, 3, H, W)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
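        # These are standard SRM (Spatial Rich Model) high-pass residual
        # filters from steganalysis: KB (a 3x3 "square" predictor), KV (a
        # 5x5 predictor), and a second-order horizontal filter; dividing by
        # 4/12/2 normalizes each kernel's prediction residual.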
        # stack the filters
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        filters = torch.FloatTensor(filters)  # (3,3,5,5)
        return filters


class SRMConv2d_Separate(nn.Module):

    def __init__(self, inc, outc, learnable=False):
        super(SRMConv2d_Separate, self).__init__()
        self.inc = inc
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3*inc, 1, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True)
        )

        for ly in self.out_conv.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)

    def forward(self, x):
        '''
        x: imgs (Batch, 3, H, W)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        out = self.out_conv(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
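        # Unlike SRMConv2d_simple (which repeats along the input-channel
        # axis for a dense conv), here the 3 filters are repeated along the
        # output axis to shape (3*inc, 1, 5, 5) and applied per input
        # channel via groups=inc in forward().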
        # stack the filters
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]
        filters = np.array(filters)
        # filters = np.repeat(filters, inc, axis=1)
        filters = np.repeat(filters, inc, axis=0)
        filters = torch.FloatTensor(filters)  # (3*inc, 1, 5, 5)
        # print(filters.size())
        return filters


if __name__ == '__main__':
    x = torch.rand(1, 3, 224, 224)
    srm = SRMConv2d_simple()
    output = srm(x)
    output = output.numpy()
    print(output.shape)

--------------------------------------------------------------------------------
/networks/ssp.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from networks.resnet import resnet50
from networks.srm_conv import SRMConv2d_simple
import torch.nn.functional as F


class ssp(nn.Module):
    def __init__(self, pretrain=True):
        super().__init__()
        self.srm = SRMConv2d_simple()
        self.disc = resnet50(pretrained=pretrain)
        self.disc.fc = nn.Linear(2048, 1)

    def forward(self, x):
        # upsample the selected patch, extract SRM noise residuals,
        # then classify with the ResNet50 discriminator
        x = F.interpolate(x, (256, 256), mode='bilinear')
        x = self.srm(x)
        x = self.disc(x)
        return x


if __name__ == '__main__':
    model = ssp(pretrain=True)
    print(model)

--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
import argparse
import os
import torch


class TrainOptions():
    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        # data augmentation
        parser.add_argument('--name', type=str, default='experiment_name',
                            help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--rz_interp', default='bilinear')
        parser.add_argument('--blur_prob', type=float, default=0)
        parser.add_argument('--blur_sig', default=[0, 1])
        parser.add_argument('--jpg_prob', type=float, default=0)
        parser.add_argument('--jpg_method', default=['pil', 'cv2'])
        parser.add_argument('--jpg_qual', default=[90, 100])
        parser.add_argument('--CropSize', type=int,
                            default=224, help='scale images to this size')
        # train setting
        parser.add_argument('--batchsize', type=int,
                            default=64, help='input batch size')
        parser.add_argument('--choices', default=[0, 0, 0, 0, 1, 0, 0, 0])
        parser.add_argument('--epoch', type=int, default=30)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--trainsize', type=int, default=256)
        parser.add_argument('--load', type=str,
                            default=None)
        parser.add_argument('--image_root', type=str,
                            default='/data/chenjiaxuan/data/genImage')
        parser.add_argument('--save_path', type=str,
                            default='./snapshot/sortnet/')
        parser.add_argument('--isPatch', action='store_false')
        parser.add_argument('--patch_size', type=int, default=32)
        parser.add_argument('--aug', action='store_false')
        parser.add_argument('--gpu_id', type=str, default='3')
        parser.add_argument('--log_name', default='log3.log',
                            help='rename the logfile', type=str)
        parser.add_argument('--val_interval', default=1,
                            type=int, help='val per interval')
        parser.add_argument('--val_batchsize', default=64, type=int)
        return parser

    def gather_options(self):
        # initialize parser with basic options
        if not self.initialized:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()
        self.parser = parser

        return parser.parse_args()

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

    def parse(self, print_options=True):

        opt = self.gather_options()
        opt.isTrain = True  # train or test
        opt.isVal = False
        # opt.classes = opt.classes.split(',')

        # # result dir, save results and opt
        # opt.results_dir = f"./results/{opt.detect_method}"
        # util.mkdir(opt.results_dir)

        if print_options:
            self.print_options(opt)

        # additional

        # opt.rz_interp = opt.rz_interp.split(',')
        # opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')]
        # opt.jpg_method = opt.jpg_method.split(',')
        # opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')]
        # if len(opt.jpg_qual) == 2:
        #     opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1))
        # elif len(opt.jpg_qual) > 2:
        #     raise ValueError(
        #         "Shouldn't have more than 2 values for --jpg_qual.")

        self.opt = opt
        return self.opt

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.24.4
opencv_python==4.8.1.78
Pillow==10.3.0
scipy==1.10.1
torch==1.12.1
torchvision==0.13.1

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import torch
from utils.util import set_random_seed
from utils.tdataloader import get_val_loader
from options import TrainOptions
from networks.ssp import ssp
"""Currently assumes jpg_prob, blur_prob 0 or 1"""
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def get_val_opt():
    val_opt = TrainOptions().parse(print_options=False)
    val_opt.isTrain = False
    val_opt.isVal = True
    # blur
    val_opt.blur_prob = 0
    val_opt.blur_sig = [1]
    # jpg
    val_opt.jpg_prob = 0
    val_opt.jpg_method = ['pil']
    val_opt.jpg_qual = [90]
    # if len(val_opt.blur_sig) == 2:
    #     b_sig = val_opt.blur_sig
    #     val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
    # if len(val_opt.jpg_qual) != 1:
    #     j_qual = val_opt.jpg_qual
    #     val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]

    return val_opt


def val(val_loader, model, save_path):
    model.eval()
    total_right_image = total_image = 0
    with torch.no_grad():
        for loader in val_loader:
            right_ai_image = right_nature_image = 0
            name, val_ai_loader, ai_size, val_nature_loader, nature_size = loader['name'], loader[
                'val_ai_loader'], loader['ai_size'], loader['val_nature_loader'], loader['nature_size']
            print("val on:", name)
            # for images, labels in tqdm(val_ai_loader, desc='val_ai'):
            for images, labels in val_ai_loader:
                images = images.cuda()
                labels = labels.cuda()
                res = model(images)
                res = torch.sigmoid(res).ravel()
                right_ai_image += (((res > 0.5) & (labels == 1))
                                   | ((res < 0.5) & (labels == 0))).sum()

            print(f'ai accu: {right_ai_image/ai_size}')
            # for images,labels in tqdm(val_nature_loader,desc='val_nature'):
            for images, labels in val_nature_loader:
                images = images.cuda()
                labels = labels.cuda()
                res = model(images)
                res = torch.sigmoid(res).ravel()
                right_nature_image += (((res > 0.5) & (labels == 1))
                                       | ((res < 0.5) & (labels == 0))).sum()
            print(f'nature accu:{right_nature_image/nature_size}')
            accu = (right_ai_image + right_nature_image) / \
                (ai_size + nature_size)
            total_right_image += right_ai_image + right_nature_image
            total_image += ai_size + nature_size
            print(f'val on:{name}, Accuracy:{accu}')
        total_accu = total_right_image / total_image
        print(f'Accuracy:{total_accu}')


if __name__ == '__main__':
    set_random_seed()
    # train and val options
    opt = TrainOptions().parse()
    val_opt = get_val_opt()

    # load data
    print('load data...')

    val_loader = get_val_loader(val_opt)

    # cuda config
    # set the device for testing
    if opt.gpu_id == '0':
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        print('USE GPU 0')
    elif opt.gpu_id == '1':
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        print('USE GPU 1')
    elif opt.gpu_id == '2':
        os.environ["CUDA_VISIBLE_DEVICES"] = "2"
        print('USE GPU 2')
    elif opt.gpu_id == '3':
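        # note: CUDA_VISIBLE_DEVICES only takes effect if it is set before
        # the first CUDA call (the model is moved to the GPU further below)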
os.environ["CUDA_VISIBLE_DEVICES"] = "3" 97 | print('USE GPU 3') 98 | 99 | # load model 100 | model = ssp().cuda() 101 | if opt.load is not None: 102 | model.load_state_dict(torch.load(opt.load)) 103 | print('load model from', opt.load) 104 | optimizer = torch.optim.Adam(model.parameters(), opt.lr) 105 | save_path = opt.save_path 106 | if not os.path.exists(save_path): 107 | os.makedirs(save_path) 108 | print("Start train") 109 | val(val_loader, model, save_path) 110 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py --blur_prob=0 --jpg_prob=0 --val_batchsize=64 --patchsize=32 --load='snapshots/xxx' -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from utils.util import set_random_seed, poly_lr 4 | from utils.tdataloader import get_loader, get_val_loader 5 | from options import TrainOptions 6 | from networks.ssp import ssp 7 | from utils.loss import bceLoss 8 | from datetime import datetime 9 | import numpy as np 10 | """Currently assumes jpg_prob, blur_prob 0 or 1""" 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | def get_val_opt(): 16 | val_opt = TrainOptions().parse(print_options=False) 17 | val_opt.isTrain = False 18 | val_opt.isVal = True 19 | # blur 20 | val_opt.blur_prob = 0 21 | val_opt.blur_sig = [1] 22 | # jpg 23 | val_opt.jpg_prob = 0 24 | val_opt.jpg_method = ['pil'] 25 | val_opt.jpg_qual = [90] 26 | # if len(val_opt.blur_sig) == 2: 27 | # b_sig = val_opt.blur_sig 28 | # val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2] 29 | # if len(val_opt.jpg_qual) != 1: 30 | # j_qual = val_opt.jpg_qual 31 | # val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)] 32 | 33 | return val_opt 34 | 35 | 36 | def train(train_loader, model, optimizer, epoch, save_path): 37 | model.train() 38 | global step 39 | epoch_step = 0 40 | loss_all = 0 41 | try: 42 | for i, (images, labels) in enumerate(train_loader, start=1): 43 | optimizer.zero_grad() 44 | images = images.cuda() 45 | preds = model(images).ravel() 46 | labels = labels.cuda() 47 | loss1 = bceLoss() 48 | loss = loss1(preds, labels) 49 | loss.backward() 50 | optimizer.step() 51 | step += 1 52 | epoch_step += 1 53 | loss_all += loss.data 54 | if i % 200 == 0 or i == total_step or i == 1: 55 | print( 56 | f'{datetime.now()} Epoch [{epoch:03d}/{opt.epoch:03d}], Step [{i:04d}/{total_step:04d}], Total_loss: {loss.data:.4f}') 57 | loss_all /= epoch_step 58 | if epoch % 50 == 0: 59 | torch.save(model.state_dict(), save_path + 60 | f'Net_epoch_{epoch}.pth') 61 | 62 | except KeyboardInterrupt: 63 | print('Keyboard Interrupt: save model and exit.') 64 | 65 | 66 | def val(val_loader, model, epoch, save_path): 67 | model.eval() 68 | global best_epoch, best_accu 69 | total_right_image = total_image = 0 70 | with torch.no_grad(): 71 | for loader in val_loader: 72 | right_ai_image = right_nature_image = 0 73 | name, val_ai_loader, ai_size, val_nature_loader, nature_size = loader['name'], loader[ 74 | 'val_ai_loader'], loader['ai_size'], loader['val_nature_loader'], loader['nature_size'] 75 | print("val on:", name) 76 | # for images, labels in tqdm(val_ai_loader, desc='val_ai'): 77 | for images, labels in val_ai_loader: 78 | images = images.cuda() 79 | labels = labels.cuda() 80 | res = model(images) 81 | res = 
                right_ai_image += (((res > 0.5) & (labels == 1))
                                   | ((res < 0.5) & (labels == 0))).sum()

            print(f'ai accu: {right_ai_image/ai_size}')
            # for images,labels in tqdm(val_nature_loader,desc='val_nature'):
            for images, labels in val_nature_loader:
                images = images.cuda()
                labels = labels.cuda()
                res = model(images)
                res = torch.sigmoid(res).ravel()
                right_nature_image += (((res > 0.5) & (labels == 1))
                                       | ((res < 0.5) & (labels == 0))).sum()
            print(f'nature accu:{right_nature_image/nature_size}')
            accu = (right_ai_image + right_nature_image) / \
                (ai_size + nature_size)
            total_right_image += right_ai_image + right_nature_image
            total_image += ai_size + nature_size
            print(f'val on:{name}, Epoch:{epoch}, Accuracy:{accu}')
        total_accu = total_right_image / total_image
        if epoch == 1:
            best_accu = total_accu
            best_epoch = 1
            torch.save(model.state_dict(), save_path +
                       'Net_epoch_best.pth')
            print(f'Save state_dict successfully! Best epoch:{epoch}.')
        else:
            if total_accu > best_accu:
                best_accu = total_accu
                best_epoch = epoch
                torch.save(model.state_dict(), save_path +
                           'Net_epoch_best.pth')
                print(f'Save state_dict successfully! Best epoch:{epoch}.')
        print(
            f'Epoch:{epoch},Accuracy:{total_accu}, bestEpoch:{best_epoch}, bestAccu:{best_accu}')


if __name__ == '__main__':
    set_random_seed()
    # train and val options
    opt = TrainOptions().parse()
    val_opt = get_val_opt()

    # load data
    print('load data...')
    train_loader = get_loader(opt)
    total_step = len(train_loader)
    val_loader = get_val_loader(val_opt)

    # cuda config
    # set the device for training
    if opt.gpu_id == '0':
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        print('USE GPU 0')
    elif opt.gpu_id == '1':
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        print('USE GPU 1')
    elif opt.gpu_id == '2':
        os.environ["CUDA_VISIBLE_DEVICES"] = "2"
        print('USE GPU 2')
    elif opt.gpu_id == '3':
        os.environ["CUDA_VISIBLE_DEVICES"] = "3"
        print('USE GPU 3')

    # load model
    model = ssp().cuda()
    if opt.load is not None:
        model.load_state_dict(torch.load(opt.load))
        print('load model from', opt.load)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    save_path = opt.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    step = 0
    best_epoch = 0
    best_accu = 0
    print("Start train")
    for epoch in range(1, opt.epoch + 1):
        cur_lr = poly_lr(optimizer, opt.lr, epoch, opt.epoch)
        train(train_loader, model, optimizer, epoch, save_path)
        val(val_loader, model, epoch, save_path)

--------------------------------------------------------------------------------
/train_val.sh:
--------------------------------------------------------------------------------
python train_val.py --blur_prob=0 --jpg_prob=0 --batchsize=64 --epoch=30 --lr=1e-4 --patch_size=32

--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
from torch import nn


def bceLoss():
    return nn.BCEWithLogitsLoss()


def crossEntropyLoss():
    return nn.CrossEntropyLoss()

def mseLoss():
    return nn.MSELoss()

--------------------------------------------------------------------------------
/utils/patch.py:
--------------------------------------------------------------------------------
import numpy as np
from torchvision.transforms import transforms


def compute(patch):
    """Texture diversity of a patch: the sum of absolute differences between
    horizontally, vertically, and diagonally adjacent pixels. A lower value
    means a smoother, 'simpler' patch."""
    patch = np.array(patch).astype(np.int64)
    diff_horizontal = np.sum(np.abs(patch[:, :-1, :] - patch[:, 1:, :]))
    diff_vertical = np.sum(np.abs(patch[:-1, :, :] - patch[1:, :, :]))
    diff_diagonal = np.sum(np.abs(patch[:-1, :-1, :] - patch[1:, 1:, :]))
    diff_diagonal += np.sum(np.abs(patch[1:, :-1, :] - patch[:-1, 1:, :]))
    res = diff_horizontal + diff_vertical + diff_diagonal
    return res.sum()


def patch_img(img, patch_size, height):
    """Randomly crop (height // patch_size) ** 2 candidate patches and return
    the one with the lowest texture diversity, i.e. the single simplest patch."""
    img_width, img_height = img.size
    num_patch = (height // patch_size) * (height // patch_size)
    patch_list = []
    min_len = min(img_height, img_width)
    rz = transforms.Resize((height, height))
    if min_len < patch_size:
        # image too small to crop from directly; resize it up first
        img = rz(img)
    rp = transforms.RandomCrop(patch_size)
    for i in range(num_patch):
        patch_list.append(rp(img))
    patch_list.sort(key=lambda x: compute(x), reverse=False)
    new_img = patch_list[0]

    return new_img

--------------------------------------------------------------------------------
/utils/tdataloader.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
import os
from utils.patch import patch_img
from PIL import Image
import numpy as np
import cv2
import random as rd
from random import random, choice
from scipy.ndimage import gaussian_filter
from io import BytesIO

mp = {0: 'imagenet_ai_0508_adm', 1: 'imagenet_ai_0419_biggan', 2: 'imagenet_glide', 3: 'imagenet_midjourney',
      4: 'imagenet_ai_0419_sdv4', 5: 'imagenet_ai_0424_sdv5', 6: 'imagenet_ai_0419_vqdm', 7: 'imagenet_ai_0424_wukong',
      8: 'imagenet_DALLE2'
      }


def sample_continuous(s):
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        rg = s[1] - s[0]
        return random() * rg + s[0]
    raise ValueError("Length of iterable s should be 1 or 2.")


def sample_discrete(s):
    if len(s) == 1:
        return s[0]
    return choice(s)


def sample_randint(s):
    if len(s) == 1:
        return s[0]
    return rd.randint(s[0], s[1])


def gaussian_blur_gray(img, sigma):
    if len(img.shape) == 3:
        img_blur = np.zeros_like(img)
        for i in range(img.shape[2]):
            img_blur[:, :, i] = gaussian_filter(img[:, :, i], sigma=sigma)
    else:
        img_blur = gaussian_filter(img, sigma=sigma)
    return img_blur


def gaussian_blur(img, sigma):
    gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
    gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
    gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)


def cv2_jpg(img, compress_val):
    img_cv2 = img[:, :, ::-1]
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]


def pil_jpg(img, compress_val):
    out = BytesIO()
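    # encode to JPEG in an in-memory buffer and decode again, so the array
    # picks up real lossy-compression artifacts without touching disk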
    img = Image.fromarray(img)
    img.save(out, format='jpeg', quality=compress_val)
    img = Image.open(out)
    # load from memory before BytesIO closes
    img = np.array(img)
    out.close()
    return img


jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}


def jpeg_from_key(img, compress_val, key):
    method = jpeg_dict[key]
    return method(img, compress_val)


def data_augment(img, opt):
    img = np.array(img)

    if random() < opt.blur_prob:
        sig = sample_continuous(opt.blur_sig)
        gaussian_blur(img, sig)

    if random() < opt.jpg_prob:
        method = sample_discrete(opt.jpg_method)
        qual = sample_randint(opt.jpg_qual)
        img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)


def processing(img, opt):
    if opt.aug:
        aug = transforms.Lambda(
            lambda img: data_augment(img, opt)
        )
    else:
        aug = transforms.Lambda(
            lambda img: img
        )

    if opt.isPatch:
        patch_func = transforms.Lambda(
            lambda img: patch_img(img, opt.patch_size, opt.trainsize))
    else:
        patch_func = transforms.Resize((256, 256))

    trans = transforms.Compose([
        aug,
        patch_func,
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    return trans(img)


class genImageTrainDataset(Dataset):
    def __init__(self, image_root, image_dir, opt):
        super().__init__()
        self.opt = opt
        self.root = os.path.join(image_root, image_dir, "train")
        self.nature_path = os.path.join(self.root, "nature")
        self.nature_list = [os.path.join(self.nature_path, f)
                            for f in os.listdir(self.nature_path)]
        self.nature_size = len(self.nature_list)
        self.ai_path = os.path.join(self.root, "ai")
        self.ai_list = [os.path.join(self.ai_path, f)
                        for f in os.listdir(self.ai_path)]
        self.ai_size = len(self.ai_list)
        self.images = self.nature_list + self.ai_list
        # label convention: nature (real) = 1, ai (generated) = 0
        self.labels = torch.cat(
            (torch.ones(self.nature_size), torch.zeros(self.ai_size)))

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index):
        try:
            image = self.rgb_loader(self.images[index])
            label = self.labels[index]
        except Exception:
            # fall back to the previous image if this one is unreadable
            new_index = index - 1
            image = self.rgb_loader(
                self.images[max(0, new_index)])
            label = self.labels[max(0, new_index)]
        image = processing(image, self.opt)
        return image, label

    def __len__(self):
        return self.nature_size + self.ai_size


class genImageValDataset(Dataset):
    def __init__(self, image_root, image_dir, is_real, opt):
        super().__init__()
        self.opt = opt
        self.root = os.path.join(image_root, image_dir, "val")
        if is_real:
            self.img_path = os.path.join(self.root, 'nature')
            self.img_list = [os.path.join(self.img_path, f)
                             for f in os.listdir(self.img_path)]
            self.img_len = len(self.img_list)
            self.labels = torch.ones(self.img_len)
        else:
            self.img_path = os.path.join(self.root, 'ai')
            self.img_list = [os.path.join(self.img_path, f)
                             for f in os.listdir(self.img_path)]
            self.img_len = len(self.img_list)
            self.labels = torch.zeros(self.img_len)

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index):
        image = self.rgb_loader(self.img_list[index])
        label = self.labels[index]
        image = processing(image, self.opt)
        return image, label

    def __len__(self):
        return self.img_len


class genImageTestDataset(Dataset):
    def __init__(self, image_root, image_dir, opt):
        super().__init__()
        self.opt = opt
        self.root = os.path.join(image_root, image_dir, "val")
        self.nature_path = os.path.join(self.root, "nature")
        self.nature_list = [os.path.join(self.nature_path, f)
                            for f in os.listdir(self.nature_path)]
        self.nature_size = len(self.nature_list)
        self.ai_path = os.path.join(self.root, "ai")
        self.ai_list = [os.path.join(self.ai_path, f)
                        for f in os.listdir(self.ai_path)]
        self.ai_size = len(self.ai_list)
        self.images = self.nature_list + self.ai_list
        self.labels = torch.cat(
            (torch.ones(self.nature_size), torch.zeros(self.ai_size)))

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index):
        try:
            image = self.rgb_loader(self.images[index])
            label = self.labels[index]
        except Exception:
            # fall back to the previous image if this one is unreadable
            new_index = index - 1
            image = self.rgb_loader(
                self.images[max(0, new_index)])
            label = self.labels[max(0, new_index)]
        image = processing(image, self.opt)
        return image, label, self.images[index]

    def __len__(self):
        return self.nature_size + self.ai_size


def get_single_loader(opt, image_dir, is_real):
    val_dataset = genImageValDataset(
        opt.image_root, image_dir=image_dir, is_real=is_real, opt=opt)
    val_loader = DataLoader(val_dataset, batch_size=opt.val_batchsize,
                            shuffle=False, num_workers=4, pin_memory=True)
    return val_loader, len(val_dataset)


def get_val_loader(opt):
    # subsets marked 0 or 1 in opt.choices are included in validation
    choices = opt.choices
    loader = []
    for i, choice in enumerate(choices):
        datainfo = dict()
        if choice == 0 or choice == 1:
            print("val on:", mp[i])
            datainfo['name'] = mp[i]
            datainfo['val_ai_loader'], datainfo['ai_size'] = get_single_loader(
                opt, datainfo['name'], is_real=False)
            datainfo['val_nature_loader'], datainfo['nature_size'] = get_single_loader(
                opt, datainfo['name'], is_real=True)
            loader.append(datainfo)
    return loader


def get_loader(opt):
    # subsets marked 1 in opt.choices are used for training
    choices = opt.choices
    image_root = opt.image_root

    datasets = []
    if choices[0] == 1:
        adm_dataset = genImageTrainDataset(
            image_root, "imagenet_ai_0508_adm", opt=opt)
        datasets.append(adm_dataset)
        print("train on: imagenet_ai_0508_adm")
    if choices[1] == 1:
        biggan_dataset = genImageTrainDataset(
            image_root, "imagenet_ai_0419_biggan", opt=opt)
        datasets.append(biggan_dataset)
        print("train on: imagenet_ai_0419_biggan")
    if choices[2] == 1:
        glide_dataset = genImageTrainDataset(
            image_root, "imagenet_glide", opt=opt)
        datasets.append(glide_dataset)
        print("train on: imagenet_glide")
    if choices[3] == 1:
        midjourney_dataset = genImageTrainDataset(
            image_root, "imagenet_midjourney", opt=opt)
        datasets.append(midjourney_dataset)
        print("train on: imagenet_midjourney")
    if choices[4] == 1:
        sdv14_dataset = genImageTrainDataset(
image_root, "imagenet_ai_0419_sdv4", opt=opt) 286 | datasets.append(sdv14_dataset) 287 | print("train on: imagenet_ai_0419_sdv4") 288 | if choices[5] == 1: 289 | sdv15_dataset = genImageTrainDataset( 290 | image_root, "imagenet_ai_0424_sdv5", opt=opt) 291 | datasets.append(sdv15_dataset) 292 | print("train on: imagenet_ai_0424_sdv5") 293 | if choices[6] == 1: 294 | vqdm_dataset = genImageTrainDataset( 295 | image_root, "imagenet_ai_0419_vqdm", opt=opt) 296 | datasets.append(vqdm_dataset) 297 | print("train on: imagenet_ai_0419_vqdm") 298 | if choices[7] == 1: 299 | wukong_dataset = genImageTrainDataset( 300 | image_root, "imagenet_ai_0424_wukong", opt=opt) 301 | datasets.append(wukong_dataset) 302 | print("train on: imagenet_ai_0424_wukong") 303 | 304 | train_dataset = torch.utils.data.ConcatDataset(datasets) 305 | train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, 306 | shuffle=True, num_workers=4, pin_memory=True) 307 | return train_loader 308 | 309 | 310 | def get_test_loader(opt): 311 | choices = opt.choices 312 | image_root = opt.image_root 313 | datasets = [] 314 | if choices[0] == 2: 315 | adm_dataset = genImageTestDataset( 316 | image_root, "imagenet_ai_0508_adm", opt=opt) 317 | datasets.append(adm_dataset) 318 | print("test on: imagenet_ai_0508_adm") 319 | if choices[1] == 2: 320 | biggan_dataset = genImageTestDataset( 321 | image_root, "imagenet_ai_0419_biggan", opt=opt) 322 | datasets.append(biggan_dataset) 323 | print("test on: imagenet_ai_0419_biggan") 324 | if choices[2] == 2: 325 | glide_dataset = genImageTestDataset( 326 | image_root, "imagenet_glide", opt=opt) 327 | datasets.append(glide_dataset) 328 | print("test on: imagenet_glide") 329 | if choices[3] == 2: 330 | midjourney_dataset = genImageTestDataset( 331 | image_root, "imagenet_midjourney", opt=opt) 332 | datasets.append(midjourney_dataset) 333 | print("test on: imagenet_midjourney") 334 | if choices[4] == 2: 335 | sdv14_dataset = genImageTestDataset( 336 | image_root, "imagenet_ai_0419_sdv4", opt=opt) 337 | datasets.append(sdv14_dataset) 338 | print("test on: imagenet_ai_0419_sdv4") 339 | if choices[5] == 2: 340 | sdv15_dataset = genImageTestDataset( 341 | image_root, "imagenet_ai_0424_sdv5", opt=opt) 342 | datasets.append(sdv15_dataset) 343 | print("test on: imagenet_ai_0424_sdv5") 344 | if choices[6] == 2: 345 | vqdm_dataset = genImageTestDataset( 346 | image_root, "imagenet_ai_0419_vqdm", opt=opt) 347 | datasets.append(vqdm_dataset) 348 | print("test on: imagenet_ai_0419_vqdm") 349 | if choices[7] == 2: 350 | wukong_dataset = genImageTestDataset( 351 | image_root, "imagenet_ai_0424_wukong", opt=opt) 352 | datasets.append(wukong_dataset) 353 | print("test on: imagenet_ai_0424_wukong") 354 | 355 | test_dataset = torch.utils.data.ConcatDataset(datasets) 356 | test_loader = DataLoader(test_dataset, batch_size=opt.batchsize, 357 | shuffle=True, num_workers=4, pin_memory=True) 358 | return test_loader 359 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | 6 | def poly_lr(optimizer, init_lr, curr_iter, max_iter, power=0.9): 7 | lr = init_lr * (1 - float(curr_iter) / max_iter) ** power 8 | for param_group in optimizer.param_groups: 9 | param_group['lr'] = lr 10 | cur_lr = lr 11 | return cur_lr 12 | 13 | 14 | def clip_gradient(optimizer, grad_clip): 15 | """ 16 | For calibrating misalignment gradient via cliping 
    :param optimizer:
    :param grad_clip:
    :return:
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


def set_random_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)

--------------------------------------------------------------------------------
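
As a closing reference, the snippet below reproduces the texture-diversity criterion that `utils/patch.py` uses to pick the single simplest patch, as a self-contained sketch (the function name and the toy inputs are illustrative; the repo's own implementation is `compute()` above):

```python
import numpy as np

def texture_diversity(patch):
    """Sum of absolute differences between horizontally, vertically, and
    diagonally adjacent pixels (mirrors compute() in utils/patch.py)."""
    p = np.asarray(patch).astype(np.int64)
    d = np.abs(p[:, :-1] - p[:, 1:]).sum()        # horizontal neighbours
    d += np.abs(p[:-1, :] - p[1:, :]).sum()       # vertical neighbours
    d += np.abs(p[:-1, :-1] - p[1:, 1:]).sum()    # main diagonal
    d += np.abs(p[1:, :-1] - p[:-1, 1:]).sum()    # anti-diagonal
    return int(d)

flat = np.full((32, 32, 3), 128, dtype=np.uint8)               # textureless patch
noisy = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # random texture
# the simplest patch is the one with the LOWEST diversity score
assert texture_diversity(flat) < texture_diversity(noisy)
```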