├── .gitignore
├── README.md
├── figures
├── pipeline.png
└── results.png
├── networks
├── resnet.py
├── srm_conv.py
└── ssp.py
├── options.py
├── requirements.txt
├── test.py
├── test.sh
├── train_val.py
├── train_val.sh
└── utils
├── loss.py
├── patch.py
├── tdataloader.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 | *.pth
4 | *.log
5 | *.json
6 | *.npy
7 | snapshot
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSP-AI-Generated-Image-Detection
2 |
3 | This is the official implementation for the following research paper:
4 |
5 | > **A Single Simple Patch is All You Need for AI-generated Image Detection** [[arxiv]](https://arxiv.org/pdf/2402.01123.pdf)
6 | >
7 | > Jiaxuan Chen, Jieteng Yao, and Li Niu
8 |
9 | Note that in the paper, we proposed Enhanced SSP (ESSP) to improve its robustness against blur and compression. Currently, we only release the code of SSP (the flowchart is shown in the following figure). The code of ESSP will be released soon.
10 |
11 |
12 |

13 |
14 |
15 | ## Environment Setup
16 | You can install the required packages by running the command:
17 | ```bash
18 | pip install -r requirements.txt
19 | ```
20 | ## Dataset
21 | The training set and testing set used in the paper can be downloaded from [GenImage](https://github.com/GenImage-Dataset/GenImage). This dataset contains data from eight generators.
22 | After downloading the dataset, set its root path with the `--image_root` option (defined in `options.py`). The dataset should be organized as follows:
23 | ```bash
24 | GenImage/
25 | ├── imagenet_ai_0419_biggan
26 | ├── train
27 | ├── ai
28 | ├── nature
29 | ├── val
30 | ├── ai
31 | ├── nature
32 | └── imagenet_ai_0419_sdv4
33 | ├── train
34 | ├── ai
35 | ├── nature
36 | ├── val
37 | ├── ai
38 | ├── nature
39 | └── imagenet_ai_0419_vqdm
40 | ...
41 | └── imagenet_ai_0424_sdv5
42 | ...
43 | └── imagenet_ai_0424_wukong
44 | ...
45 | └── imagenet_ai_0508_adm
46 | ...
47 | └── imagenet_glide
48 | ...
49 | └── imagenet_midjourney
50 | ...
51 | ```
52 | ## Training and Validation
53 | You can simply run the following script to train and validate your model:
54 | ```bash
55 | sh train_val.sh
56 | ```
57 | ## Testing
58 | You can simply run the following script to test your model:
59 | ```bash
60 | sh test.sh
61 | ```
62 | Our models pretrained on each of the eight subsets can be downloaded from [Baidu Cloud](https://pan.baidu.com/s/1Wk2Cqeav_wVxPMPNy-zHZQ?pwd=bcmi) (code:bcmi) or [OneDrive](https://1drv.ms/f/s!Aq2pxrmMfMvRh29sp4zHSlbJRlP7?e=C3aHEp).
63 |
64 | ## Results on GenImage
65 | The table below reports the results of the ResNet50 baseline and our SSP method for different training and test subsets. In each cell, the left number is the result of ResNet50 and the right number is the result of SSP.
66 |
67 |

68 |
69 |
--------------------------------------------------------------------------------
/figures/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SSP-AI-Generated-Image-Detection/3af21c9a3085364465bd51806d7c8fb956d0726f/figures/pipeline.png
--------------------------------------------------------------------------------
/figures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SSP-AI-Generated-Image-Detection/3af21c9a3085364465bd51806d7c8fb956d0726f/figures/results.png
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | import torch
# Public API of this module: the generic ResNet class and its standard factories.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
]
6 |
7 |
# torchvision ImageNet checkpoint URLs, keyed by the factory function that loads them.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
)
15 |
16 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
20 |
21 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``, used for channel projection."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
25 |
26 |
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34).

    Residual path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: ``x`` itself, or ``downsample(x)`` when a projection is given.
    The two are summed and passed through a final ReLU.
    """

    # Ratio of output channels to ``planes``; a BasicBlock keeps the width.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names and creation order are kept: they define the
        # state_dict keys of the pretrained checkpoints.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)
57 |
58 |
class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block (ResNet-50/101/152).

    Residual path: 1x1 reduce -> BN -> ReLU -> 3x3 (carries the stride) -> BN
    -> ReLU -> 1x1 expand to ``planes * expansion`` -> BN.
    Shortcut: ``x`` or ``downsample(x)``; the sum goes through a final ReLU.
    """

    # The last 1x1 conv widens the output to 4 * planes channels.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        # Attribute names and creation order define the checkpoint layout.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        skip = x if self.downsample is None else self.downsample(x)
        y += skip
        return self.relu(y)
95 |
96 |
class ResNet(nn.Module):
    """ImageNet-style ResNet: 7x7 stem, four residual stages, global pooling, linear head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute determines each stage's output width.
        layers: number of blocks in each of the four stages (``layers[0..3]``).
        num_classes: output size of the final ``nn.Linear``.
        zero_init_residual: if True, zero the last BN weight of every residual
            branch so each block initially behaves like an identity.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.inplanes = 64
        # Submodules are created in the original order so that the state_dict
        # layout and the (RNG-consuming) initialisation order are unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._init_weights(zero_init_residual)

    def _init_weights(self, zero_init_residual):
        """He-initialise convolutions, reset BN to (weight 1, bias 0), optionally zero residual BNs."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if not zero_init_residual:
            return
        # Zero-initialize the last BN in each residual branch so that each block
        # starts as an identity; reported to gain 0.2~0.3%
        # (https://arxiv.org/abs/1706.02677).
        for module in self.modules():
            if isinstance(module, Bottleneck):
                nn.init.constant_(module.bn3.weight, 0)
            elif isinstance(module, BasicBlock):
                nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first may stride/project."""
        out_planes = planes * block.expansion
        downsample = None
        # A projection shortcut is needed whenever the shape changes.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))
        return nn.Sequential(*stage)

    def forward(self, x, *args):
        # Extra positional arguments are accepted and ignored.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(x.view(x.size(0), -1))
160 |
161 |
def _build_resnet(arch, block, layers, pretrained, **kwargs):
    """Instantiate ``ResNet(block, layers, **kwargs)`` and optionally load ``model_urls[arch]``."""
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, **kwargs)
215 |
216 |
if __name__ == '__main__':
    # Smoke test: build ResNet-50 with ImageNet weights and print its module tree.
    print(resnet50(pretrained=True))
220 |
--------------------------------------------------------------------------------
/networks/srm_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
class SRMConv2d_simple(nn.Module):
    """Fixed bank of three SRM high-pass filters applied across all input channels.

    The three 5x5 kernels (KB, KV and a horizontal 2nd-order difference) are
    each replicated over the ``inc`` input channels. Every output channel is
    therefore one filter's response summed over the input channels. Responses
    are clipped to [-3, 3].

    Args:
        inc: number of input channels.
        learnable: whether the kernel receives gradients (frozen by default).
    """

    def __init__(self, inc=3, learnable=False):
        super().__init__()
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3, inc, 5, 5)
        # A Parameter so the kernel is part of state_dict; frozen unless learnable.
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)

    def forward(self, x):
        '''
        x: images of shape (Batch, inc, H, W), i.e. NCHW as required by F.conv2d.
        Returns a (Batch, 3, H, W) tensor; padding=2 keeps the spatial size.
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        return self.truc(out)

    def _build_kernel(self, inc):
        """Return the float32 SRM kernel of shape (3, inc, 5, 5)."""
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order difference
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters: (3, 1, 5, 5)
        filters = np.array([[filter1], [filter2], [filter3]])
        # Same filter for every input channel: (3, inc, 5, 5).
        filters = np.repeat(filters, inc, axis=1)
        return torch.tensor(filters, dtype=torch.float32)
56 |
57 |
class SRMConv2d_Separate(nn.Module):
    """Per-channel SRM filtering followed by a learnable 1x1 Conv-BN-ReLU projection.

    Each input channel is filtered independently (grouped convolution) by all
    three fixed SRM kernels. This gives 3*inc residual maps, which are clipped
    to [-3, 3] and mixed by a 1x1 Conv-BN-ReLU into ``outc`` channels.

    Args:
        inc: number of input channels.
        outc: number of output channels of the 1x1 projection.
        learnable: whether the SRM kernel receives gradients (frozen by default).
    """

    def __init__(self, inc, outc, learnable=False):
        super().__init__()
        self.inc = inc
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3*inc, 1, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3 * inc, outc, 1, 1, 0, 1, 1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True),
        )

        for ly in self.out_conv.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)

    def forward(self, x):
        '''
        x: images of shape (Batch, inc, H, W), i.e. NCHW as required by F.conv2d.
        Returns a (Batch, outc, H, W) tensor; padding=2 keeps the spatial size.
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        return self.out_conv(out)

    def _build_kernel(self, inc):
        """Return the float32 grouped-conv kernel of shape (3*inc, 1, 5, 5)."""
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order difference
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters: (3, 1, 5, 5)
        filters = np.array([[filter1], [filter2], [filter3]])
        # With groups=inc, output channels [3*g, 3*g + 3) only see input
        # channel g. The bank must therefore be tiled as (f1, f2, f3, f1, f2, f3, ...).
        # np.repeat along axis 0 would give (f1, f1, f1, f2, ...) and feed
        # each input channel three copies of a single filter.
        filters = np.tile(filters, (inc, 1, 1, 1))  # (3*inc, 1, 5, 5)
        return torch.tensor(filters, dtype=torch.float32)
120 |
121 |
if __name__ == '__main__':
    # Smoke test: run the simple SRM layer on one random 224x224 RGB image.
    layer = SRMConv2d_simple()
    result = layer(torch.rand(1, 3, 224, 224))
    print(np.array(result).shape)
--------------------------------------------------------------------------------
/networks/ssp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from networks.resnet import resnet50
4 | from networks.srm_conv import SRMConv2d_simple
5 | import torch.nn.functional as F
6 |
7 |
class ssp(nn.Module):
    """SSP detector: resize to 256x256 -> fixed SRM high-pass filtering -> ResNet-50 -> one logit.

    Args:
        pretrain (bool): load ImageNet weights into the ResNet-50 backbone.
            Defaults to True.
    """

    def __init__(self, pretrain=True):
        super().__init__()
        self.srm = SRMConv2d_simple()
        # Honour ``pretrain`` instead of always downloading ImageNet weights.
        self.disc = resnet50(pretrained=pretrain)
        # Replace the 1000-way ImageNet head with a single real/fake logit.
        self.disc.fc = nn.Linear(2048, 1)

    def forward(self, x):
        """Return raw logits of shape (Batch, 1) for an NCHW 3-channel image batch ``x``."""
        # align_corners=False is the default; passing it explicitly keeps the
        # result unchanged and avoids the PyTorch warning.
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x = self.srm(x)
        return self.disc(x)
20 |
21 |
if __name__ == '__main__':
    # Smoke test: build the detector (with pretrained backbone) and print it.
    print(ssp(pretrain=True))
25 |
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 |
5 |
class TrainOptions():
    """Command-line options shared by the training, validation and test scripts."""

    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        """Register every option on ``parser`` and return it."""
        # data augmentation
        parser.add_argument('--name', type=str, default='experiment_name',
                            help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--rz_interp', default='bilinear')
        parser.add_argument('--blur_prob', type=float, default=0)
        # NOTE(review): the list-valued options below (blur_sig, jpg_method,
        # jpg_qual, choices) declare no type/nargs, so only their defaults are
        # lists; a value given on the command line arrives as a plain string.
        parser.add_argument('--blur_sig', default=[0, 1])
        parser.add_argument('--jpg_prob', type=float, default=0)
        parser.add_argument('--jpg_method', default=['pil', 'cv2'])
        parser.add_argument('--jpg_qual', default=[90, 100])
        parser.add_argument('--CropSize', type=int,
                            default=224, help='scale images to this size')
        # train setting
        parser.add_argument('--batchsize', type=int,
                            default=64, help='input batch size')
        parser.add_argument('--choices', default=[0, 0, 0, 0, 1, 0, 0, 0])
        parser.add_argument('--epoch', type=int, default=30)
        # type=float: otherwise "--lr=1e-4" (as in train_val.sh) arrives as the
        # string '1e-4' and the optimizer rejects it.
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--trainsize', type=int, default=256)
        parser.add_argument('--load', type=str,
                            default=None)
        parser.add_argument('--image_root', type=str,
                            default='/data/chenjiaxuan/data/genImage')
        parser.add_argument('--save_path', type=str,
                            default='./snapshot/sortnet/')
        # store_false: patching is ON by default; passing --isPatch turns it off.
        parser.add_argument('--isPatch', action='store_false')
        # '--patchsize' is an alias because test.sh / train_val.sh spell it that
        # way; the parsed value is still stored as opt.patch_size.
        parser.add_argument('--patch_size', '--patchsize', type=int, default=32)
        # store_false: augmentation is ON by default; passing --aug turns it off.
        parser.add_argument('--aug', action='store_false')
        parser.add_argument('--gpu_id', type=str, default='3')
        parser.add_argument('--log_name', default='log3.log',
                            help='rename the logfile', type=str)
        parser.add_argument('--val_interval', default=1,
                            type=int, help='val per interval')
        parser.add_argument('--val_batchsize', default=64, type=int)
        return parser

    def gather_options(self):
        """Build the parser on first use (kept in ``self.parser``) and parse ``sys.argv``."""
        if not self.initialized:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            self.parser = self.initialize(parser)
            self.initialized = True
        return self.parser.parse_args()

    def print_options(self, opt):
        """Print every option, annotating values that differ from their default."""
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

    def parse(self, print_options=True):
        """Parse the command line and return the options namespace.

        Adds ``isTrain=True`` and ``isVal=False``. Callers such as
        ``get_val_opt`` override these after parsing.
        """
        opt = self.gather_options()
        opt.isTrain = True  # train or test
        opt.isVal = False

        if print_options:
            self.print_options(opt)

        self.opt = opt
        return self.opt
99 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.4
2 | opencv_python==4.8.1.78
3 | opencv_python_headless==4.9.0.80
4 | Pillow==10.3.0
5 | scipy==1.10.1
6 | torch==1.12.1
7 | torchvision==0.13.1
8 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from utils.util import set_random_seed, poly_lr
4 | from utils.tdataloader import get_loader, get_val_loader
5 | from options import TrainOptions
6 | from networks.ssp import ssp
7 | from utils.loss import bceLoss
8 | from datetime import datetime
9 | import numpy as np
10 | """Currently assumes jpg_prob, blur_prob 0 or 1"""
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 |
def get_val_opt():
    """Return parsed options reconfigured for clean evaluation (no blur, no JPEG)."""
    opt = TrainOptions().parse(print_options=False)
    opt.isTrain, opt.isVal = False, True
    # Blur augmentation off; single-value sigma setting.
    opt.blur_prob, opt.blur_sig = 0, [1]
    # JPEG augmentation off; single-value method/quality settings.
    opt.jpg_prob, opt.jpg_method, opt.jpg_qual = 0, ['pil'], [90]
    return opt
34 |
35 |
def _count_correct(loader, model, device):
    """Return how many samples from ``loader`` the model classifies correctly.

    A sample is correct when sigmoid(logit) > 0.5 with label 1, or < 0.5 with
    label 0. An output of exactly 0.5 counts as wrong.
    """
    correct = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        probs = torch.sigmoid(model(images)).ravel()
        hits = ((probs > 0.5) & (labels == 1)) | ((probs < 0.5) & (labels == 0))
        correct += int(hits.sum().item())
    return correct


def _ratio(num, den):
    """Return ``num / den`` as a float, or NaN when ``den`` is 0 (empty subset)."""
    return num / den if den else float('nan')


def val(val_loader, model, save_path):
    """Evaluate ``model`` on every subset in ``val_loader`` and print accuracies.

    Args:
        val_loader: iterable of dicts with keys 'name', 'val_ai_loader',
            'ai_size', 'val_nature_loader' and 'nature_size'.
        model: nn.Module returning one logit per image.
        save_path: unused; kept for call compatibility.

    Returns:
        float: overall accuracy across all subsets (NaN if there were no samples).
    """
    model.eval()
    # Run on the model's device (CUDA in this script's main). Fall back to CUDA,
    # the previously hard-coded device, for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    total_right_image = total_image = 0
    with torch.no_grad():
        for loader in val_loader:
            name = loader['name']
            ai_size, nature_size = loader['ai_size'], loader['nature_size']
            print("val on:", name)
            right_ai_image = _count_correct(loader['val_ai_loader'], model, device)
            print(f'ai accu: {_ratio(right_ai_image, ai_size)}')
            right_nature_image = _count_correct(loader['val_nature_loader'], model, device)
            print(f'nature accu:{_ratio(right_nature_image, nature_size)}')
            subset_right = right_ai_image + right_nature_image
            subset_size = ai_size + nature_size
            total_right_image += subset_right
            total_image += subset_size
            print(f'val on:{name}, Accuracy:{_ratio(subset_right, subset_size)}')
    total_accu = _ratio(total_right_image, total_image)
    print(f'Accuracy:{total_accu}')
    return total_accu
71 |
72 |
if __name__ == '__main__':
    set_random_seed()
    # Test options (checkpoint path, GPU id) and clean evaluation options.
    opt = TrainOptions().parse()
    val_opt = get_val_opt()

    # Select the GPU(s) before any CUDA context is created (model.cuda() below).
    # Any id string (e.g. '0' or '0,1') is forwarded, not only '0'-'3'.
    if opt.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
        print(f'USE GPU {opt.gpu_id}')

    # load data
    print('load data...')
    val_loader = get_val_loader(val_opt)

    # load model
    model = ssp().cuda()
    if opt.load is not None:
        # torch.load unpickles the file: only load trusted checkpoints.
        # map_location='cpu' lets checkpoints saved on any GPU index load;
        # load_state_dict then copies the weights onto the CUDA model.
        model.load_state_dict(torch.load(opt.load, map_location='cpu'))
        print('load model from', opt.load)
    save_path = opt.save_path
    os.makedirs(save_path, exist_ok=True)
    print("Start test")
    val(val_loader, model, save_path)
110 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
# Evaluate the best checkpoint written by train_val.py (default --save_path).
python test.py --blur_prob=0 --jpg_prob=0 --val_batchsize=64 --patch_size=32 --load='./snapshot/sortnet/Net_epoch_best.pth'
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from utils.util import set_random_seed, poly_lr
4 | from utils.tdataloader import get_loader, get_val_loader
5 | from options import TrainOptions
6 | from networks.ssp import ssp
7 | from utils.loss import bceLoss
8 | from datetime import datetime
9 | import numpy as np
10 | """Currently assumes jpg_prob, blur_prob 0 or 1"""
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 |
def get_val_opt():
    """Return parsed options reconfigured for clean validation (no blur, no JPEG)."""
    opt = TrainOptions().parse(print_options=False)
    opt.isTrain, opt.isVal = False, True
    # Blur augmentation off; single-value sigma setting.
    opt.blur_prob, opt.blur_sig = 0, [1]
    # JPEG augmentation off; single-value method/quality settings.
    opt.jpg_prob, opt.jpg_method, opt.jpg_qual = 0, ['pil'], [90]
    return opt
34 |
35 |
def train(train_loader, model, optimizer, epoch, save_path):
    """Run one training epoch with BCE-with-logits loss.

    Reads module-level globals set in ``__main__``:
    - ``step``: global iteration counter, incremented here.
    - ``total_step``: batches per epoch, used for logging.
    - ``opt``: ``opt.epoch`` is used in the log line.

    On Ctrl-C the current weights are saved to ``save_path`` and the
    KeyboardInterrupt is re-raised so the program actually exits.

    Returns:
        float: mean training loss over the processed batches (0.0 if none).
    """
    global step
    model.train()
    criterion = bceLoss()  # stateless; build once rather than once per batch
    epoch_step = 0
    loss_all = 0.0
    try:
        for i, (images, labels) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            images = images.cuda()
            preds = model(images).ravel()
            # BCEWithLogitsLoss needs floating targets matching the logits' dtype.
            labels = labels.cuda().to(dtype=preds.dtype)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            step += 1
            epoch_step += 1
            loss_value = loss.item()
            loss_all += loss_value
            if i % 200 == 0 or i == total_step or i == 1:
                print(
                    f'{datetime.now()} Epoch [{epoch:03d}/{opt.epoch:03d}], Step [{i:04d}/{total_step:04d}], Total_loss: {loss_value:.4f}')
        loss_all /= max(epoch_step, 1)
        if epoch % 50 == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'Net_epoch_{epoch}.pth'))

    except KeyboardInterrupt:
        print('Keyboard Interrupt: save model and exit.')
        torch.save(model.state_dict(),
                   os.path.join(save_path, f'Net_epoch_{epoch}.pth'))
        raise
    return loss_all
64 |
65 |
def _count_correct(loader, model, device):
    """Return how many samples from ``loader`` the model classifies correctly.

    A sample is correct when sigmoid(logit) > 0.5 with label 1, or < 0.5 with
    label 0. An output of exactly 0.5 counts as wrong.
    """
    correct = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        probs = torch.sigmoid(model(images)).ravel()
        hits = ((probs > 0.5) & (labels == 1)) | ((probs < 0.5) & (labels == 0))
        correct += int(hits.sum().item())
    return correct


def _ratio(num, den):
    """Return ``num / den`` as a float, or NaN when ``den`` is 0 (empty subset)."""
    return num / den if den else float('nan')


def val(val_loader, model, epoch, save_path):
    """Validate ``model`` on every subset and keep the best checkpoint.

    Args:
        val_loader: iterable of dicts with keys 'name', 'val_ai_loader',
            'ai_size', 'val_nature_loader' and 'nature_size'.
        model: nn.Module returning one logit per image.
        epoch: current epoch number. Epoch 1 always seeds the best score.
        save_path: directory that receives 'Net_epoch_best.pth'.

    Updates the module-level globals ``best_epoch`` and ``best_accu``.

    Returns:
        float: overall accuracy across all subsets (NaN if there were no samples).
    """
    global best_epoch, best_accu
    model.eval()
    # Run on the model's device. Fall back to CUDA, the previously hard-coded
    # device, for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    total_right_image = total_image = 0
    with torch.no_grad():
        for loader in val_loader:
            name = loader['name']
            ai_size, nature_size = loader['ai_size'], loader['nature_size']
            print("val on:", name)
            right_ai_image = _count_correct(loader['val_ai_loader'], model, device)
            print(f'ai accu: {_ratio(right_ai_image, ai_size)}')
            right_nature_image = _count_correct(loader['val_nature_loader'], model, device)
            print(f'nature accu:{_ratio(right_nature_image, nature_size)}')
            subset_right = right_ai_image + right_nature_image
            subset_size = ai_size + nature_size
            total_right_image += subset_right
            total_image += subset_size
            print(f'val on:{name}, Epoch:{epoch}, Accuracy:{_ratio(subset_right, subset_size)}')
    total_accu = _ratio(total_right_image, total_image)
    if epoch == 1 or total_accu > best_accu:
        best_accu = total_accu
        best_epoch = epoch
        # os.path.join works whether or not save_path ends with a separator.
        torch.save(model.state_dict(), os.path.join(save_path, 'Net_epoch_best.pth'))
        print(f'Save state_dict successfully! Best epoch:{epoch}.')
    print(
        f'Epoch:{epoch},Accuracy:{total_accu}, bestEpoch:{best_epoch}, bestAccu:{best_accu}')
    return total_accu
116 |
117 |
if __name__ == '__main__':
    set_random_seed()
    # train and val options
    opt = TrainOptions().parse()
    val_opt = get_val_opt()
    # Coerce in case --lr arrived as a string from the command line.
    base_lr = float(opt.lr)

    # Select the GPU(s) before any CUDA context is created.
    # Any id string (e.g. '0' or '0,1') is forwarded, not only '0'-'3'.
    if opt.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
        print(f'USE GPU {opt.gpu_id}')

    # load data
    print('load data...')
    train_loader = get_loader(opt)
    total_step = len(train_loader)  # read as a global by train()
    val_loader = get_val_loader(val_opt)

    # load model
    model = ssp().cuda()
    if opt.load is not None:
        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(opt.load, map_location='cpu'))
        print('load model from', opt.load)
    optimizer = torch.optim.Adam(model.parameters(), base_lr)
    save_path = opt.save_path
    os.makedirs(save_path, exist_ok=True)

    # Globals shared with train() (step) and val() (best_epoch, best_accu).
    step = 0
    best_epoch = 0
    best_accu = 0
    print("Start train")
    for epoch in range(1, opt.epoch + 1):
        poly_lr(optimizer, base_lr, epoch, opt.epoch)
        train(train_loader, model, optimizer, epoch, save_path)
        val(val_loader, model, epoch, save_path)
163 |
--------------------------------------------------------------------------------
/train_val.sh:
--------------------------------------------------------------------------------
# Train and validate SSP; options are defined in options.py (note: --patch_size).
python train_val.py --blur_prob=0 --jpg_prob=0 --batchsize=64 --epoch=30 --lr=1e-4 --patch_size=32
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
def bceLoss():
    """Return a new ``nn.BCEWithLogitsLoss`` (sigmoid + binary cross-entropy on raw logits)."""
    criterion = nn.BCEWithLogitsLoss()
    return criterion
8 |
9 |
def crossEntropyLoss():
    """Return a new ``nn.CrossEntropyLoss`` (log-softmax + NLL over class logits)."""
    criterion = nn.CrossEntropyLoss()
    return criterion
def mseLoss():
    """Return a new ``nn.MSELoss`` (mean squared error)."""
    criterion = nn.MSELoss()
    return criterion
--------------------------------------------------------------------------------
/utils/patch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | from torchvision.transforms import transforms
4 | from PIL import Image
5 |
6 |
def compute(patch):
    """Return the texture score of ``patch``: the sum of absolute neighbour differences.

    Sums |p - q| over horizontally, vertically and both diagonally adjacent
    pixel pairs, accumulated over all channels. Smaller scores mean flatter
    patches; ``patch_img`` keeps the lowest-scoring crop.

    Args:
        patch: a PIL image or any array-like accepted by ``np.array``. Both
            (H, W) single-channel and (H, W, C) multi-channel data work.

    Returns:
        numpy integer: the total absolute difference (0 for a 1x1 patch).
    """
    # int64 prevents uint8 wrap-around when subtracting neighbouring pixels.
    pixels = np.array(patch).astype(np.int64)
    # The Ellipsis keeps any trailing channel axis, so grayscale (2-D) input
    # works as well as (H, W, C).
    diff_horizontal = np.abs(pixels[:, :-1, ...] - pixels[:, 1:, ...]).sum()
    diff_vertical = np.abs(pixels[:-1, :, ...] - pixels[1:, :, ...]).sum()
    diff_diagonal = np.abs(pixels[:-1, :-1, ...] - pixels[1:, 1:, ...]).sum()
    diff_anti_diagonal = np.abs(pixels[1:, :-1, ...] - pixels[:-1, 1:, ...]).sum()
    return diff_horizontal + diff_vertical + diff_diagonal + diff_anti_diagonal
18 |
19 |
def patch_img(img, patch_size, height):
    """Return the lowest-``compute`` (simplest) of several random square crops of ``img``.

    Args:
        img: PIL image (``img.size`` is read as (width, height)).
        patch_size: side length of each square crop.
        height: reference size. ``(height // patch_size) ** 2`` crops are
            sampled (at least one). An image whose shorter side is below
            ``patch_size`` is first resized to ``height x height``.

    Returns:
        The selected crop, as produced by ``transforms.RandomCrop``.
    """
    img_width, img_height = img.size
    if min(img_height, img_width) < patch_size:
        img = transforms.Resize((height, height))(img)
    # At least one candidate. Previously height < patch_size gave zero crops
    # and an IndexError.
    num_patch = max(1, (height // patch_size) * (height // patch_size))
    rp = transforms.RandomCrop(patch_size)
    # min() returns the first minimal crop, the same choice as the former
    # stable ascending sort followed by [0]. It avoids storing and sorting
    # every candidate, and crops are drawn in the same order.
    return min((rp(img) for _ in range(num_patch)), key=compute)
35 |
--------------------------------------------------------------------------------
/utils/tdataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torch
3 | from torchvision import transforms
4 | import os
5 | from utils.patch import patch_img
6 | from PIL import Image
7 | import numpy as np
8 | import cv2
9 | import random as rd
10 | from random import random, choice
11 | from scipy.ndimage.filters import gaussian_filter
12 | from io import BytesIO
# Index -> GenImage subset folder name. get_val_loader looks names up by the
# position of each entry in ``opt.choices``.
mp = dict(enumerate([
    'imagenet_ai_0508_adm',
    'imagenet_ai_0419_biggan',
    'imagenet_glide',
    'imagenet_midjourney',
    'imagenet_ai_0419_sdv4',
    'imagenet_ai_0424_sdv5',
    'imagenet_ai_0419_vqdm',
    'imagenet_ai_0424_wukong',
    'imagenet_DALLE2',
]))
17 |
18 |
def sample_continuous(s):
    """Sample a float from a one- or two-element spec.

    Args:
        s: ``[value]`` to return ``value`` unchanged, or ``[low, high]`` to
            draw uniformly via ``random() * (high - low) + low``.

    Raises:
        ValueError: if ``s`` has any other length.
    """
    n = len(s)
    if n == 1:
        return s[0]
    if n != 2:
        raise ValueError("Length of iterable s should be 1 or 2.")
    low, high = s[0], s[1]
    return random() * (high - low) + low
26 |
27 |
def sample_discrete(s):
    """Return the only element of ``s``, or a uniformly random pick when there are several."""
    return s[0] if len(s) == 1 else choice(s)
32 |
33 |
def sample_randint(s):
    """Sample an int from a one- or two-element spec.

    ``[value]`` returns ``value``. ``[low, high]`` returns
    ``random.randint(low, high)``, where both ends are inclusive.
    """
    if len(s) != 1:
        return rd.randint(s[0], s[1])
    return s[0]
38 |
39 |
def gaussian_blur_gray(img, sigma):
    """Return a Gaussian-blurred copy of ``img``; the input is not modified.

    A 3-D array is blurred one channel (last axis) at a time into a
    ``zeros_like`` buffer, which keeps the input dtype. Any other array is
    passed to ``gaussian_filter`` as a whole.
    """
    if img.ndim != 3:
        return gaussian_filter(img, sigma=sigma)
    blurred = np.zeros_like(img)
    for channel in range(img.shape[2]):
        blurred[:, :, channel] = gaussian_filter(img[:, :, channel], sigma=sigma)
    return blurred
48 |
49 |
def gaussian_blur(img, sigma):
    """Gaussian-blur ``img`` in place.

    Args:
        img: writable numpy array of shape (H, W) or (H, W, C). It is
            overwritten with the blurred result.
        sigma: standard deviation of the Gaussian kernel.

    Returns:
        None. The result is written into ``img``.
    """
    if img.ndim == 2:
        img[...] = gaussian_filter(img, sigma=sigma)
        return
    # Blur every channel rather than assuming exactly three (RGB).
    for c in range(img.shape[2]):
        gaussian_filter(img[:, :, c], output=img[:, :, c], sigma=sigma)
54 |
55 |
def cv2_jpg(img, compress_val):
    """JPEG-encode and decode an RGB image array through OpenCV.

    Args:
        img: array indexed as (H, W, channels) in RGB order. Channels are
            reversed because OpenCV works in BGR order.
        compress_val: JPEG quality passed as ``cv2.IMWRITE_JPEG_QUALITY``.

    Returns:
        The decoded image, converted back to RGB channel order.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed.
    """
    bgr = img[:, :, ::-1]  # RGB -> BGR for OpenCV
    # OpenCV expects plain ints in the params list; coerce numpy scalars and floats.
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(compress_val)]
    ok, encoded = cv2.imencode('.jpg', bgr, encode_param)
    if not ok:
        raise RuntimeError(f"cv2.imencode failed to JPEG-encode image at quality {compress_val}")
    decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    return decoded[:, :, ::-1]  # BGR -> RGB
62 |
63 |
def pil_jpg(img, compress_val):
    """JPEG-encode and decode an image array through Pillow, in memory.

    Args:
        img: numpy array accepted by ``Image.fromarray``.
        compress_val: JPEG ``quality`` passed to ``Image.save``.

    Returns:
        numpy array of the decoded JPEG.
    """
    with BytesIO() as buffer:
        Image.fromarray(img).save(buffer, format='jpeg', quality=compress_val)
        # Materialise the pixels while the in-memory buffer is still open.
        decoded = np.array(Image.open(buffer))
    return decoded
73 |
74 |
# Backend name ('cv2' or 'pil') -> JPEG round-trip function.
jpeg_dict = dict(cv2=cv2_jpg, pil=pil_jpg)
76 |
77 |
def jpeg_from_key(img, compress_val, key):
    """Apply the JPEG round-trip registered under ``key`` in ``jpeg_dict``.

    Raises:
        KeyError: if ``key`` is not a registered backend.
    """
    return jpeg_dict[key](img, compress_val)
81 |
82 |
def data_augment(img, opt):
    """Randomly blur and/or JPEG-compress an image.

    Args:
        img: image convertible with ``np.array`` (a PIL image in this pipeline).
        opt: options object. Reads ``blur_prob``, ``blur_sig``, ``jpg_prob``,
            ``jpg_method`` and ``jpg_qual``.

    Returns:
        A PIL image built from the (possibly) augmented array.
    """
    arr = np.array(img)
    # Each augmentation is applied independently with its own probability.
    if random() < opt.blur_prob:
        gaussian_blur(arr, sample_continuous(opt.blur_sig))  # in place
    if random() < opt.jpg_prob:
        backend = sample_discrete(opt.jpg_method)
        quality = sample_randint(opt.jpg_qual)
        arr = jpeg_from_key(arr, quality, backend)
    return Image.fromarray(arr)
96 |
97 |
def processing(img, opt):
    """Turn an input image into a normalized tensor.

    The pipeline is:
      1. ``data_augment`` when ``opt.aug`` is truthy;
      2. ``patch_img(img, opt.patch_size, opt.trainsize)`` when ``opt.isPatch``
         is truthy, otherwise a fixed 256x256 resize;
      3. ``ToTensor`` followed by normalization with the standard ImageNet
         per-channel mean and std.
    """
    steps = []
    if opt.aug:
        steps.append(transforms.Lambda(lambda im: data_augment(im, opt)))
    if opt.isPatch:
        steps.append(transforms.Lambda(
            lambda im: patch_img(im, opt.patch_size, opt.trainsize)))
    else:
        steps.append(transforms.Resize((256, 256)))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)(img)
123 |
124 |
class genImageTrainDataset(Dataset):
    """``train`` split of one GenImage subset.

    Lists ``<image_root>/<image_dir>/train/nature`` (label 1) followed by
    ``<image_root>/<image_dir>/train/ai`` (label 0). Items are
    ``(processed_image, label)``.

    Args:
        image_root: GenImage root directory.
        image_dir: subset folder name, e.g. ``'imagenet_ai_0419_sdv4'``.
        opt: options object forwarded to ``processing``.
    """

    def __init__(self, image_root, image_dir, opt):
        super().__init__()
        self.opt = opt
        self.root = os.path.join(image_root, image_dir, "train")
        self.nature_path = os.path.join(self.root, "nature")
        # os.listdir order is arbitrary; sorting makes index -> file reproducible.
        self.nature_list = [os.path.join(self.nature_path, f)
                            for f in sorted(os.listdir(self.nature_path))]
        self.nature_size = len(self.nature_list)
        self.ai_path = os.path.join(self.root, "ai")
        self.ai_list = [os.path.join(self.ai_path, f)
                        for f in sorted(os.listdir(self.ai_path))]
        self.ai_size = len(self.ai_list)
        self.images = self.nature_list + self.ai_list
        # Real (nature) images are labelled 1, AI-generated images 0.
        self.labels = torch.cat(
            (torch.ones(self.nature_size), torch.zeros(self.ai_size)))

    def rgb_loader(self, path):
        """Open ``path`` with PIL and return ``convert('RGB')`` of it; the file is closed on return."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``(processed_image, label)`` for ``index``.

        If the image cannot be loaded, the previous sample is used instead.
        Index 0 wraps to the last sample rather than retrying itself.
        """
        try:
            image = self.rgb_loader(self.images[index])
        except Exception:
            # Best effort for unreadable or corrupt files, without also
            # swallowing KeyboardInterrupt/SystemExit like a bare `except:`.
            index = (index - 1) % len(self.images)
            image = self.rgb_loader(self.images[index])
        label = self.labels[index]
        image = processing(image, self.opt)
        return image, label

    def __len__(self):
        return self.nature_size + self.ai_size
161 |
162 |
class genImageValDataset(Dataset):
    """One class of the ``val`` split of one GenImage subset.

    When ``is_real`` is truthy, lists ``<image_root>/<image_dir>/val/nature``
    with label 1. Otherwise lists ``<image_root>/<image_dir>/val/ai`` with
    label 0. Items are ``(processed_image, label)``.

    Args:
        image_root: GenImage root directory.
        image_dir: subset folder name.
        is_real: selects the ``nature`` (truthy) or ``ai`` folder.
        opt: options object forwarded to ``processing``.
    """

    def __init__(self, image_root, image_dir, is_real, opt):
        super().__init__()
        self.opt = opt
        self.root = os.path.join(image_root, image_dir, "val")
        self.img_path = os.path.join(self.root, 'nature' if is_real else 'ai')
        # os.listdir order is arbitrary; sorting makes index -> file reproducible.
        self.img_list = [os.path.join(self.img_path, f)
                         for f in sorted(os.listdir(self.img_path))]
        self.img_len = len(self.img_list)
        # Real (nature) images are labelled 1, AI-generated images 0.
        self.labels = torch.ones(self.img_len) if is_real else torch.zeros(self.img_len)

    def rgb_loader(self, path):
        """Open ``path`` with PIL and return ``convert('RGB')`` of it; the file is closed on return."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index):
        image = self.rgb_loader(self.img_list[index])
        label = self.labels[index]
        image = processing(image, self.opt)
        return image, label

    def __len__(self):
        return self.img_len
194 |
195 |
class genImageTestDataset(Dataset):
    """Both classes of the ``val`` split of one GenImage subset, for testing.

    Lists ``<image_root>/<image_dir>/val/nature`` (label 1) followed by
    ``<image_root>/<image_dir>/val/ai`` (label 0). Items are
    ``(processed_image, label, path)``.

    Args:
        image_root: GenImage root directory.
        image_dir: subset folder name.
        opt: options object forwarded to ``processing``.
    """

    def __init__(self, image_root, image_dir, opt):
        super().__init__()
        self.opt = opt
        self.root = os.path.join(image_root, image_dir, "val")
        self.nature_path = os.path.join(self.root, "nature")
        # os.listdir order is arbitrary; sorting makes index -> file reproducible.
        self.nature_list = [os.path.join(self.nature_path, f)
                            for f in sorted(os.listdir(self.nature_path))]
        self.nature_size = len(self.nature_list)
        self.ai_path = os.path.join(self.root, "ai")
        self.ai_list = [os.path.join(self.ai_path, f)
                        for f in sorted(os.listdir(self.ai_path))]
        self.ai_size = len(self.ai_list)
        self.images = self.nature_list + self.ai_list
        # Real (nature) images are labelled 1, AI-generated images 0.
        self.labels = torch.cat(
            (torch.ones(self.nature_size), torch.zeros(self.ai_size)))

    def rgb_loader(self, path):
        """Open ``path`` with PIL and return ``convert('RGB')`` of it; the file is closed on return."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``(processed_image, label, path)`` for ``index``.

        If the image cannot be loaded, the previous sample is used instead.
        Index 0 wraps to the last sample. The returned label and path always
        belong to the image that was actually loaded.
        """
        try:
            image = self.rgb_loader(self.images[index])
        except Exception:
            # Best effort for unreadable or corrupt files, without also
            # swallowing KeyboardInterrupt/SystemExit like a bare `except:`.
            index = (index - 1) % len(self.images)
            image = self.rgb_loader(self.images[index])
        image = processing(image, self.opt)
        return image, self.labels[index], self.images[index]

    def __len__(self):
        return self.nature_size + self.ai_size
232 |
233 |
def get_single_loader(opt, image_dir, is_real):
    """Build a non-shuffled validation loader for one class of one subset.

    Args:
        opt: options object. Reads ``image_root`` and ``val_batchsize``.
        image_dir: subset folder name.
        is_real: load the ``nature`` (truthy) or ``ai`` images.

    Returns:
        tuple: ``(DataLoader, number_of_images)``.
    """
    dataset = genImageValDataset(
        opt.image_root, image_dir=image_dir, is_real=is_real, opt=opt)
    loader = DataLoader(
        dataset,
        batch_size=opt.val_batchsize,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    return loader, len(dataset)
240 |
241 |
def get_val_loader(opt):
    """Build validation loaders for every subset whose ``opt.choices`` entry is 0 or 1.

    Position ``i`` in ``opt.choices`` refers to subset ``mp[i]``.

    Returns:
        list[dict]: one dict per selected subset, with keys ``'name'``,
        ``'val_ai_loader'``, ``'ai_size'``, ``'val_nature_loader'`` and
        ``'nature_size'``.
    """
    loader = []
    # `flag`, not `choice`, so the module-level random.choice is not shadowed.
    for i, flag in enumerate(opt.choices):
        if flag not in (0, 1):
            continue
        name = mp[i]
        print("val on:", name)
        datainfo = {'name': name}
        datainfo['val_ai_loader'], datainfo['ai_size'] = get_single_loader(
            opt, name, is_real=False)
        datainfo['val_nature_loader'], datainfo['nature_size'] = get_single_loader(
            opt, name, is_real=True)
        loader.append(datainfo)
    return loader
256 |
257 |
def get_loader(opt):
    """Build the shuffled training loader.

    Concatenates the ``train`` split of every subset ``mp[i]`` with
    ``i`` in 0..7 whose ``opt.choices[i]`` equals 1.

    Args:
        opt: options object. Reads ``choices``, ``image_root`` and ``batchsize``.

    Returns:
        DataLoader over the concatenated training datasets.

    Raises:
        ValueError: if no subset is selected for training.
    """
    choices = opt.choices
    image_root = opt.image_root

    datasets = []
    # Only the first eight subsets (mp[0]..mp[7]) are eligible for training.
    for i in range(8):
        if choices[i] == 1:
            datasets.append(genImageTrainDataset(image_root, mp[i], opt=opt))
            print("train on:", mp[i])

    if not datasets:
        # ConcatDataset would otherwise fail with a bare AssertionError.
        raise ValueError(
            "opt.choices selects no training subset (need at least one entry == 1)")

    train_dataset = torch.utils.data.ConcatDataset(datasets)
    train_loader = DataLoader(train_dataset, batch_size=opt.batchsize,
                              shuffle=True, num_workers=4, pin_memory=True)
    return train_loader
308 |
309 |
def get_test_loader(opt):
    """Build the test loader.

    Concatenates the ``val`` split of every subset ``mp[i]`` with ``i`` in
    0..7 whose ``opt.choices[i]`` equals 2. Items are
    ``(image, label, path)``.

    Args:
        opt: options object. Reads ``choices``, ``image_root`` and ``batchsize``.

    Returns:
        DataLoader over the concatenated test datasets.

    Raises:
        ValueError: if no subset is selected for testing.
    """
    choices = opt.choices
    image_root = opt.image_root

    datasets = []
    # Only the first eight subsets (mp[0]..mp[7]) are eligible for testing.
    for i in range(8):
        if choices[i] == 2:
            datasets.append(genImageTestDataset(image_root, mp[i], opt=opt))
            print("test on:", mp[i])

    if not datasets:
        # ConcatDataset would otherwise fail with a bare AssertionError.
        raise ValueError(
            "opt.choices selects no test subset (need at least one entry == 2)")

    test_dataset = torch.utils.data.ConcatDataset(datasets)
    # NOTE(review): shuffle=True at test time only affects output order; confirm intended.
    test_loader = DataLoader(test_dataset, batch_size=opt.batchsize,
                             shuffle=True, num_workers=4, pin_memory=True)
    return test_loader
359 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 |
5 |
def poly_lr(optimizer, init_lr, curr_iter, max_iter, power=0.9):
    """Set every param group's learning rate using polynomial decay.

    ``lr = init_lr * (1 - curr_iter / max_iter) ** power``, where the
    progress ``curr_iter / max_iter`` is clamped to ``[0, 1]``.

    Args:
        optimizer: object exposing ``param_groups``, each a dict with an ``'lr'`` entry.
        init_lr: learning rate at iteration 0.
        curr_iter: current iteration.
        max_iter: iteration at which the learning rate reaches 0.
        power: decay exponent.

    Returns:
        float: the learning rate that was set.
    """
    # Clamp progress: past max_iter the base would be negative, and a negative
    # float raised to a fractional power is a complex number in Python.
    progress = min(max(float(curr_iter) / max_iter, 0.0), 1.0)
    lr = init_lr * (1 - progress) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
12 |
13 |
def clip_gradient(optimizer, grad_clip):
    """Clamp every existing gradient element-wise to ``[-grad_clip, grad_clip]``.

    Parameters whose ``.grad`` is ``None`` are skipped.

    :param optimizer: object exposing ``param_groups``, each a dict with a ``'params'`` list
    :param grad_clip: clipping threshold
    :return: None; gradients are modified in place
    """
    grads = (
        param.grad
        for group in optimizer.param_groups
        for param in group['params']
        if param.grad is not None
    )
    for grad in grads:
        # .data: clamp the stored gradient in place without recording an autograd op.
        grad.data.clamp_(-grad_clip, grad_clip)
25 |
26 |
def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also disables cuDNN benchmarking and requests deterministic cuDNN
    kernels, so runs are reproducible.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
34 |
--------------------------------------------------------------------------------