├── assets
│   └── arial.ttf
├── README.md
└── region_division.py

/assets/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangzjn/GPT-4V-AD/HEAD/assets/arial.ttf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GPT-4V-AD
Code for "Exploring Grounding Potential of VQA-oriented GPT-4V for Zero-shot Anomaly Detection"

# Generate Region Division
```bash
python region_division.py --device cuda --dataset_name mvtec
```
--------------------------------------------------------------------------------
/region_division.py:
--------------------------------------------------------------------------------
import glob
import json
import copy
import os
import pandas as pd
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch

from skimage.segmentation import slic, find_boundaries
from scipy.ndimage import binary_dilation
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import warnings
warnings.filterwarnings("ignore")


class GPT4V(object):

    def __init__(self, cfg):
        self.cfg = cfg

        if 'sam' in self.cfg.region_division_methods:
            self.sam = sam_model_registry['vit_h'](checkpoint='pretrain/sam_vit_h_4b8939.pth')
            self.sam.to(device=self.cfg.device)
            self.mask_generator = SamAutomaticMaskGenerator(self.sam)

    def region_division(self):
        image_files = []
        if self.cfg.dataset_name in ['mvtec']:
            root = self.cfg.dataset_name
            image_files = glob.glob(f'{root}/*/test/*/???.png')
            # image_files = [image_file.replace(f'{root}/', '') for image_file in image_files]
        elif self.cfg.dataset_name in ['visa']:
            root = self.cfg.dataset_name
            CLSNAMES = [
                'pcb1', 'pcb2', 'pcb3', 'pcb4',
                'macaroni1', 'macaroni2', 'capsules', 'candle',
                'cashew', 'chewinggum', 'fryum', 'pipe_fryum',
            ]
            csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)
            columns = csv_data.columns  # [object, split, label, image, mask]
            test_data = csv_data[csv_data[columns[1]] == 'test']
            for cls_name in CLSNAMES:
                cls_data = test_data[test_data[columns[0]] == cls_name]
                cls_data.index = list(range(len(cls_data)))
                for idx in range(cls_data.shape[0]):
                    data = cls_data.loc[idx]
                    image_files.append(data[columns[3]])  # image path column
            image_files = [f'{root}/{image_file}' for image_file in image_files]
        if len(image_files) == 0:
            return -1

        image_files.sort()
        for idx1, image_file in enumerate(image_files):
            self.img_size = self.cfg.img_size
            self.div_num = self.cfg.div_num
            self.div_size = self.cfg.img_size // self.cfg.div_num
            self.edge_pixel = self.cfg.edge_pixel
            img = cv2.imread(image_file)
            img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            H, W, _ = img.shape

            if 'grid' in self.cfg.region_division_methods:
                # uniform grid: div_num x div_num non-overlapping cells of div_size pixels
                masks = []
                for i in range(self.div_num):
                    for j in range(self.div_num):
                        mask = np.zeros((self.img_size, self.img_size), dtype=bool)
                        x1, x2 = j * self.div_size, (j + 1) * self.div_size
                        y1, y2 = i * self.div_size, (i + 1) * self.div_size
                        mask[y1:y2, x1:x2] = True
                        masks.append(mask)
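                # sovle_masks draws white region boundaries and a numeric ID per region,
                # then saves the visualizations, the raw masks, and the GPT-4V prompts
                # next to the source image.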
                self.sovle_masks(img, image_file, masks, 'grid')
                print(f'{self.cfg.dataset_name} --> {idx1 + 1}/{len(image_files)} {image_file} for grid')
            if 'superpixel' in self.cfg.region_division_methods:
                # SLIC superpixels as irregular regions
                regions = slic(img, n_segments=60, compactness=20)
                masks = []
                for label in range(regions.max() + 1):
                    mask = (regions == label)
                    masks.append(mask)
                self.sovle_masks(img, image_file, masks, 'superpixel')
                print(f'{self.cfg.dataset_name} --> {idx1 + 1}/{len(image_files)} {image_file} for superpixel')
            if 'sam' in self.cfg.region_division_methods:
                # class-agnostic masks from the Segment Anything automatic mask generator
                masks = self.mask_generator.generate(img)
                masks = [mask['segmentation'] for mask in masks]
                self.sovle_masks(img, image_file, masks, 'sam')
                print(f'{self.cfg.dataset_name} --> {idx1 + 1}/{len(image_files)} {image_file} for sam')

    def sovle_masks(self, img, image_file, masks, method):
        mask_edge = np.zeros_like(img, dtype=np.uint8)
        mask_edge_number = np.zeros_like(img, dtype=np.uint8)
        img_edge = copy.deepcopy(img)
        img_edge_number = copy.deepcopy(img)

        # drop regions that are too small or too large to be informative
        masks = [mask for mask in masks if 600 < mask.sum() < 120000]
        for idx, mask in enumerate(masks):
            y_idx, x_idx = np.where(mask)

            # pixel of the region closest to its centroid, used to place the numeric label
            center_y = y_idx.mean()
            center_x = x_idx.mean()
            distances_squared = (y_idx - center_y) ** 2 + (x_idx - center_x) ** 2
            min_index = np.argmin(distances_squared)
            xc, yc = x_idx[min_index], y_idx[min_index]

            # region boundary (padding keeps borders that touch the image edge)
            mask1 = np.pad(mask, pad_width=1, mode='constant', constant_values=0)
            boundaries = find_boundaries(mask1, mode='inner')
            boundaries = boundaries[1:-1, 1:-1]
            if self.edge_pixel > 1:
                boundaries = binary_dilation(boundaries, iterations=self.edge_pixel - 1)
            mask_edge[boundaries == True] = 255

            text = str(idx + 1)
            img_pil = Image.fromarray(mask_edge_number)
            draw = ImageDraw.Draw(img_pil)
            font = ImageFont.truetype('assets/arial.ttf', 10)
            text_width, text_height = draw.textsize(text, font)  # note: textsize was removed in Pillow 10, requires an older Pillow

            x1_bg, y1_bg = xc - text_width // 2, yc - text_height // 2
            x2_bg, y2_bg = x1_bg + text_width, y1_bg + text_height
            draw.rectangle([(x1_bg, y1_bg + 2), (x2_bg, y2_bg)], fill=(128, 128, 128))
            draw.text((x1_bg, y1_bg), text, font=font, fill=(255, 255, 255))
            mask_edge_number = np.array(img_pil)
            mask_edge_number = np.maximum(mask_edge_number, mask_edge)

        img_edge[mask_edge > 100] = mask_edge[mask_edge > 100]
        img_edge_number[mask_edge_number > 100] = mask_edge_number[mask_edge_number > 100]

        suffix = os.path.splitext(image_file)[-1]
        cv2.imwrite(image_file.replace(suffix, f'_{method}_mask_edge.png'), cv2.cvtColor(mask_edge, cv2.COLOR_BGR2RGB))
        cv2.imwrite(image_file.replace(suffix, f'_{method}_mask_edge_number.png'), cv2.cvtColor(mask_edge_number, cv2.COLOR_BGR2RGB))
        cv2.imwrite(image_file.replace(suffix, f'_{method}_img_edge.png'), cv2.cvtColor(img_edge, cv2.COLOR_BGR2RGB))
        cv2.imwrite(image_file.replace(suffix, f'_{method}_img_edge_number.png'), cv2.cvtColor(img_edge_number, cv2.COLOR_BGR2RGB))
        torch.save(dict(masks=masks), image_file.replace(suffix, f'_{method}_masks.pth'))
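        # Build the VQA prompts: derive the class name from the dataset-relative path
        # (a trailing digit such as 'pcb1' -> 'pcb' is stripped) and write a prompt with
        # and without the class name, plus an empty *_out.txt placeholder for the answer.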
        image_file_tmp = image_file
        if 'mvtec' in image_file_tmp:  # MVTec
            image_file_tmp = image_file_tmp.replace(image_file_tmp.split('/')[0], '')
        if 'visa' in image_file_tmp:  # VisA
            image_file_tmp = image_file_tmp.replace(image_file_tmp.split('/')[0], '')
        if image_file_tmp.startswith('/'):
            image_file_tmp = image_file_tmp[1:]
        cls_name = image_file_tmp.split('/')[0]
        number_list = [str(n) for n in list(range(10))]
        if cls_name[-1] in number_list:
            cls_name = cls_name[:-1]
        prompt_cls = f"This is an image of {cls_name}."
        prompt_describe = f"The image has different region divisions, each distinguished by white edges and each with a unique numerical identifier within the region, starting from 1. Each region may exhibit anomalies of unknown types, and if any region exhibits an anomaly, the normal image is considered anomalous. Anomaly scores range from 0 to 1, with higher values indicating a higher probability of an anomaly. Please output the image anomaly score, as well as the anomaly scores for the regions with anomalies. Provide the answer in the following format: \"image anomaly score: 0.9; region 1: 0.9; region 3: 0.7.\". Ignore the region that does not contain anomalies."
        f = open(image_file.replace(suffix, f'_prompt_wo_cls.txt'), 'w')
        f.write(f'{prompt_describe}')
        f.close()

        f = open(image_file.replace(suffix, f'_prompt.txt'), 'w')
        f.write(f'{prompt_cls} {prompt_describe}')
        f.close()

        f = open(image_file.replace(suffix, f'_{method}_out.txt'), 'w')
        f.write(f'')
        f.close()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')

    # generate region divisions with labeled numbers
    parser.add_argument('--dataset_name', type=str, default='mvtec')
    # parser.add_argument('--dataset_name', type=str, default='visa')
    parser.add_argument('--region_division_methods', type=str, nargs='+', default=['superpixel'])
    # parser.add_argument('--region_division_methods', type=str, nargs='+', default=['grid', 'superpixel', 'sam'])
    parser.add_argument('--img_size', type=int, default=768)
    parser.add_argument('--div_num', type=int, default=16)
    parser.add_argument('--edge_pixel', type=int, default=1)

    cfg = parser.parse_args()
    runner = GPT4V(cfg)
    runner.region_division()
--------------------------------------------------------------------------------
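The prompts written by the script ask GPT-4V to reply in the fixed format "image anomaly score: 0.9; region 1: 0.9; region 3: 0.7.", while only empty `*_out.txt` files are created for those replies. A minimal sketch of how such a reply could be parsed is given below; `parse_answer` is a hypothetical helper and not part of the repository.

```python
import re


def parse_answer(text):
    """Hypothetical helper (not in the repository): parse a reply that follows
    the format requested in prompt_describe."""
    image_score = None
    m = re.search(r'image anomaly score:\s*([0-9]*\.?[0-9]+)', text, flags=re.IGNORECASE)
    if m:
        image_score = float(m.group(1))
    # per-region scores, e.g. "region 3: 0.7"
    region_scores = {int(rid): float(score)
                     for rid, score in re.findall(r'region\s*(\d+):\s*([0-9]*\.?[0-9]+)', text, flags=re.IGNORECASE)}
    return image_score, region_scores


print(parse_answer("image anomaly score: 0.9; region 1: 0.9; region 3: 0.7."))
# (0.9, {1: 0.9, 3: 0.7})
```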