├── README.md
├── baselines.py
├── checkpoints
│   ├── finetuned_q16
│   │   └── prompts.pt
│   ├── multi-headed
│   │   ├── disturbing.pt
│   │   ├── hateful.pt
│   │   ├── political.pt
│   │   ├── sexual.pt
│   │   └── violent.pt
│   └── q16
│       └── prompts.p
├── config.py
├── evaluate.py
├── inference.py
├── requirements.txt
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # unsafe-diffusion
2 | 
3 | This repository provides the data and code for the paper *Unsafe Diffusion: On the Generation of Unsafe Images and Hateful Memes From Text-To-Image Models*, accepted at ACM CCS 2023.
4 | 
5 | Paper: https://arxiv.org/pdf/2305.13873.pdf
6 | 
7 | ## Unsafe Image Generation
8 | 
9 | 
10 | ### 1. Collecting Prompts
11 | 
12 | We use three harmful prompt datasets and one harmless prompt dataset. Request the prompt datasets here: https://zenodo.org/record/8255664
13 | 
14 | - 4chan prompts (harmful)
15 | - Lexica prompts (harmful)
16 | - Template prompts (harmful)
17 | - COCO prompts (harmless)
18 | 
19 | ### 2. Generating Images
20 | 
21 | We use four open-source Text-to-Image models:
22 | 
23 | - Stable Diffusion: https://github.com/CompVis/stable-diffusion
24 | - Latent Diffusion: https://github.com/CompVis/latent-diffusion
25 | - DALLE-2 demo: https://github.com/lucidrains/DALLE2-pytorch
26 | - DALLE-mini: https://github.com/borisdayma/dalle-mini
27 | 
28 | ### 3. Unsafe Image Classification
29 | 
30 | We labeled 800 generated images. Request the image dataset here: https://zenodo.org/record/8255664
31 | 
32 | **Prerequisite**
33 | 
34 | ```pip install -r requirements.txt```
35 | 
36 | **Train the Multi-headed Safety Classifier**
37 | 
38 | ```
39 | python train.py \
40 | --images_dir ./data/images \
41 | --labels_dir ./data/labels.xlsx \
42 | --output_dir ./checkpoints/multi-headed
43 | ```
44 | 
45 | **Evaluate the Classifier and Other Baselines**
46 | 
47 | ```
48 | python evaluate.py \
49 | --images_dir ./data/images \
50 | --labels_dir ./data/labels.xlsx \
51 | --checkpoints_dir ./checkpoints
52 | ```
53 | 
54 | **Directly Use the Classifier to Detect Unsafe Images**
55 | 
56 | ```
57 | python inference.py \
58 | --images_dir ./data/images \
59 | --output_dir ./results
60 | ```
61 | 
62 | 
63 | ## Hateful Meme Generation
64 | 
65 | We employ three image editing techniques on top of Stable Diffusion:
66 | 
67 | - DreamBooth: https://github.com/XavierXiao/Dreambooth-Stable-Diffusion
68 | - Textual Inversion: https://github.com/rinongal/textual_inversion
69 | - SDEdit: https://github.com/CompVis/stable-diffusion
70 | 
71 | ## Reference
72 | 
73 | If you find this helpful, please cite the following work:
74 | ```
75 | @inproceedings{QSHBZZ23,
76 |   author = {Yiting Qu and Xinyue Shen and Xinlei He and Michael Backes and Savvas Zannettou and Yang Zhang},
77 |   title = {{Unsafe Diffusion: On the Generation of Unsafe Images and Hateful Memes From Text-To-Image Models}},
78 |   booktitle = {{ACM SIGSAC Conference on Computer and Communications Security (CCS)}},
79 |   publisher = {ACM},
80 |   year = {2023}
81 | }
82 | ```
83 | 
--------------------------------------------------------------------------------
/baselines.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import PIL
5 | from PIL import Image
6 | import argparse, os, sys, glob
7 | import json
8 | from pathlib import Path
9 | import torch.nn as nn
10 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
11 | from
transformers import AutoFeatureExtractor 12 | from sklearn import metrics 13 | import random 14 | import open_clip 15 | import logging 16 | import pickle 17 | import config 18 | 19 | device = torch.device("cuda") 20 | 21 | #========================================== Stable Diffusion built-in Classifier ===================================== 22 | 23 | def load_img(path): 24 | image = Image.open(path).convert("RGB") 25 | w, h = image.size 26 | print(f"loaded input image of size ({w}, {h}) from {path}") 27 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 28 | image = image.resize((512, 512), resample=PIL.Image.LANCZOS) 29 | image = np.array(image).astype(np.float32) / 255.0 30 | image = image[None].transpose(0, 3, 1, 2) 31 | image = torch.from_numpy(image) 32 | return 2.*image - 1. 33 | 34 | def numpy_to_pil(images): 35 | """ 36 | Convert a numpy image or a batch of images to a PIL image. 37 | """ 38 | if images.ndim == 3: 39 | images = images[None, ...] 40 | # images = (images * 255).round().astype("uint8") 41 | pil_images = [Image.fromarray(image) for image in images] 42 | return pil_images 43 | 44 | def cosine_distance(image_embeds, text_embeds): 45 | normalized_image_embeds = nn.functional.normalize(image_embeds) 46 | normalized_text_embeds = nn.functional.normalize(text_embeds) 47 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) 48 | 49 | 50 | def safety_filter_check(loader): 51 | 52 | safety_model_id = "CompVis/stable-diffusion-safety-checker" 53 | safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) 54 | safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id).to(device) 55 | safety_checker.eval() 56 | 57 | def check_safety(x_image, device): # [1, W, H, C] 58 | safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") # pixel_values: [1, 3, 224, 224] 59 | safety_checker_input = safety_checker_input.to(device) 60 | x_checked_image, has_nsfw_concept = safety_checker(clip_input=safety_checker_input.pixel_values, images=x_image) 61 | assert x_checked_image.shape[0] == len(has_nsfw_concept) 62 | return has_nsfw_concept 63 | 64 | res = [] 65 | for batch in loader: 66 | imgs, _ = batch 67 | for img in imgs: 68 | img = Image.open(img).convert("RGB") 69 | img_array = np.expand_dims(np.array(img), axis=0) 70 | has_nsfw_concept = check_safety(img_array, device)[0] 71 | res.append(int(has_nsfw_concept)) 72 | 73 | return res 74 | 75 | #========================================== Q16 Model ===================================== 76 | 77 | class SimClassifier(torch.nn.Module): 78 | def __init__(self, embeddings, model_name, pretrained): 79 | super(SimClassifier, self).__init__() 80 | self.clip_model, self.preprocess, _ = open_clip.create_model_and_transforms(model_name, pretrained) 81 | self.clip_model.to(torch.float32) 82 | 83 | self.prompts = torch.nn.Parameter(embeddings) 84 | 85 | def freeze(self): 86 | self.clip_model = self.clip_model.eval() 87 | for param in self.clip_model.parameters(): 88 | param.requires_grad = False 89 | 90 | def forward(self, x): 91 | text_feat = self.prompts / self.prompts.norm(dim=-1, keepdim=True) 92 | image_feat = self.clip_model.encode_image(x) 93 | # Pick the top 5 most similar labels for the image 94 | image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True) 95 | 96 | similarity = (100.0 * image_feat @ text_feat.T) 97 | # values, indices = similarity[0].topk(5) 98 | return similarity.squeeze() 99 | 100 | def 
initialize_prompts(model, text_prompts, device): 101 | text = model.preprocess(text_prompts).to(device) 102 | return model.clip_model.encode_text(text) 103 | 104 | def load_prompts(file_path, device): 105 | if file_path.endswith("p"): 106 | res = torch.FloatTensor(pickle.load(open(file_path, 'rb'))).to(device) 107 | elif file_path.endswith("pt"): 108 | res = torch.load(open(file_path, 'rb')).to(device).to(torch.float32) 109 | return res 110 | 111 | 112 | def Q16_check(loader, checkpoint): 113 | 114 | model_name, pretrained = config.model_name, config.pretrained 115 | soft_prompts = load_prompts(checkpoint, device) 116 | classifier = SimClassifier(soft_prompts, model_name, pretrained) 117 | classifier.freeze() 118 | classifier.to(device) 119 | 120 | res = [] 121 | for batch in loader: 122 | imgs, _ = batch 123 | images = [classifier.preprocess(PIL.Image.open(img_path)) for img_path in imgs] 124 | images = torch.stack(images).to(device) # [b_s, 3, 224, 224] 125 | y = classifier(images) 126 | y = torch.argmax(y, dim=1).detach().cpu().numpy() 127 | res.extend(y.tolist()) 128 | 129 | return res 130 | -------------------------------------------------------------------------------- /checkpoints/finetuned_q16/prompts.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/finetuned_q16/prompts.pt -------------------------------------------------------------------------------- /checkpoints/multi-headed/disturbing.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/multi-headed/disturbing.pt -------------------------------------------------------------------------------- /checkpoints/multi-headed/hateful.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/multi-headed/hateful.pt -------------------------------------------------------------------------------- /checkpoints/multi-headed/political.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/multi-headed/political.pt -------------------------------------------------------------------------------- /checkpoints/multi-headed/sexual.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/multi-headed/sexual.pt -------------------------------------------------------------------------------- /checkpoints/multi-headed/violent.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/multi-headed/violent.pt -------------------------------------------------------------------------------- /checkpoints/q16/prompts.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YitingQu/unsafe-diffusion/9ad261669b9dbdbf1a6ece8ff48b2cb6fdfec586/checkpoints/q16/prompts.p -------------------------------------------------------------------------------- /config.py: 
-------------------------------------------------------------------------------- 1 | model_name = "ViT-L-14" 2 | pretrained = "openai" 3 | 4 | learning_rate = 2e-4 5 | batch_size = 64 6 | epoch = 31 -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import PIL 5 | from PIL import Image 6 | import argparse, os, sys, glob 7 | import json 8 | from pathlib import Path 9 | from sklearn import metrics 10 | import random 11 | import open_clip 12 | from baselines import safety_filter_check, Q16_check 13 | import config, tqdm 14 | from train import BinaryAnnotatedDataset, MHSafetyClassifier 15 | 16 | SEED = 2022 17 | device = torch.device("cuda") 18 | torch.manual_seed(SEED) 19 | random.seed(SEED) 20 | np.random.seed(SEED) 21 | 22 | encode_labels = {"normal":0, "sexual":1, "violent":2, "disturbing":3, "hateful":4, "political": 5} 23 | unsafe_contents = list(encode_labels.keys())[1:] 24 | 25 | 26 | def multiheaded_check(loader, checkpoints): 27 | model_name, pretrained = config.model_name, config.pretrained 28 | model = MHSafetyClassifier(device, model_name, pretrained) 29 | model.freeze() 30 | res = {} 31 | with torch.no_grad(): 32 | for head in unsafe_contents: 33 | model.projection_head.load_state_dict(torch.load(f"{checkpoints}/{head}.pt")) 34 | model.projection_head.eval() 35 | 36 | res[head] = [] 37 | for batch in loader: 38 | imgs, _ = batch 39 | images = [model.preprocess(PIL.Image.open(img_path)) for img_path in imgs] 40 | images = torch.stack(images).to(device) # [b_s, 3, 224, 224] 41 | logits = model(images).squeeze() 42 | preds = (logits.detach().cpu()>0.5).to(dtype=torch.int64) 43 | res[head].extend(preds.tolist()) 44 | 45 | return res 46 | 47 | def eval(opt, detector, split="test"): 48 | 49 | q16_checkpoint = f"{opt.checkpoints_dir}/q16/prompts.p" 50 | q16_checkpoint_finetuned = f"{opt.checkpoints_dir}/finetuned_q16/prompts.pt" 51 | mh_checkpoints = f"{opt.checkpoints_dir}/multi-headed" 52 | 53 | dataset = BinaryAnnotatedDataset(images_dir=opt.images_dir, labels_dir=opt.labels_dir, split=split) 54 | loader = torch.utils.data.DataLoader(dataset, batch_size=50, drop_last=False, shuffle=False) 55 | 56 | if detector == "safety_checker": 57 | res = safety_filter_check(loader=loader) 58 | elif detector == "q16": 59 | res = Q16_check(loader=loader, checkpoint=q16_checkpoint) 60 | elif detector == "finetuned_q16": 61 | res = Q16_check(loader=loader, checkpoint=q16_checkpoint_finetuned) 62 | elif detector == "multi-headed": 63 | res = multiheaded_check(loader=loader, checkpoints=mh_checkpoints) 64 | 65 | ground_truth = [dataset.__getitem__(idx)[1] for idx in range(len(dataset))] 66 | if detector == "multi-headed": 67 | _preds = [] 68 | for head in unsafe_contents: 69 | _preds.append(res[head]) 70 | _preds = np.array(_preds) 71 | preds = np.sum(_preds, axis=0) 72 | preds = np.int16(preds>0) 73 | else: 74 | preds = res 75 | 76 | 77 | accuracy = metrics.accuracy_score(ground_truth, preds) 78 | precision = metrics.precision_score(ground_truth, preds) 79 | recall = metrics.recall_score(ground_truth, preds) 80 | f1_score = metrics.f1_score(ground_truth, preds) 81 | print("accuracy, precision, recall, f1_score") 82 | print(f"{detector}, {accuracy:.2f}, {precision:.2f}, {recall:.2f}, {f1_score:.2f}") 83 | 84 | 85 | if __name__=="__main__": 86 | 87 | parser = argparse.ArgumentParser() 88 | 89 | 
parser.add_argument(
90 |         "--images_dir",
91 |         type=str,
92 |         nargs="?",
93 |         default=None,
94 |         help="images folder"
95 |     )
96 |     parser.add_argument(
97 |         "--labels_dir",
98 |         type=str,
99 |         nargs="?",
100 |         default=None,
101 |         help="path to the annotation spreadsheet (labels.xlsx)"
102 |     )
103 |     parser.add_argument(
104 |         "--checkpoints_dir",
105 |         type=str,
106 |         nargs="?",
107 |         help="directory containing the classifier checkpoints",
108 |         default=None
109 |     )
110 | 
111 |     opt = parser.parse_args()
112 | 
113 |     for detector in ["safety_checker", "q16", "finetuned_q16", "multi-headed"]:
114 |         eval(opt, detector)
115 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import PIL
5 | from PIL import Image
6 | import argparse, os, sys, glob
7 | import json
8 | from pathlib import Path
9 | from sklearn import metrics
10 | import random
11 | import open_clip
12 | from baselines import safety_filter_check, Q16_check
13 | import config, tqdm
14 | from train import BinaryAnnotatedDataset, MHSafetyClassifier
15 | 
16 | device = torch.device("cuda")
17 | torch.manual_seed(2022)
18 | random.seed(2022)
19 | np.random.seed(2022)
20 | 
21 | encode_labels = {"normal":0, "sexual":1, "violent":2, "disturbing":3, "hateful":4, "political": 5}
22 | unsafe_contents = list(encode_labels.keys())[1:]
23 | 
24 | 
25 | class ImageDataset(torch.utils.data.Dataset):
26 |     def __init__(self, images_dir):
27 |         self.images_dir = images_dir
28 |         self.image_locs = os.listdir(images_dir)
29 |     def __getitem__(self, idx):
30 |         return os.path.join(self.images_dir, self.image_locs[idx])
31 | 
32 |     def __len__(self):
33 |         return len(self.image_locs)
34 | 
35 | def multiheaded_check(loader, checkpoints):
36 |     model_name, pretrained = config.model_name, config.pretrained
37 |     model = MHSafetyClassifier(device, model_name, pretrained)
38 |     model.freeze()
39 |     res = {}
40 |     with torch.no_grad():
41 |         for head in unsafe_contents:
42 |             model.projection_head.load_state_dict(torch.load(f"{checkpoints}/{head}.pt"))
43 |             model.projection_head.eval()
44 | 
45 |             res[head] = []
46 |             for batch in loader:
47 |                 imgs = batch
48 |                 images = [model.preprocess(PIL.Image.open(img_path)) for img_path in imgs]
49 |                 images = torch.stack(images).to(device) # [b_s, 3, 224, 224]
50 |                 logits = model(images).squeeze()
51 |                 preds = (logits.detach().cpu()>0.5).to(dtype=torch.int64)
52 |                 res[head].extend(preds.tolist())
53 | 
54 |     return res
55 | 
56 | def main(opt):
57 | 
58 |     mh_checkpoints = "./checkpoints/multi-headed"
59 | 
60 |     output_dir = opt.output_dir
61 |     Path(output_dir).mkdir(exist_ok=True, parents=True)
62 | 
63 |     dataset = ImageDataset(images_dir=opt.images_dir)
64 |     loader = torch.utils.data.DataLoader(dataset, batch_size=50, drop_last=False, shuffle=False)
65 | 
66 |     res = multiheaded_check(loader=loader, checkpoints=mh_checkpoints)
67 | 
68 |     # convert to binary label > safe/unsafe
69 |     _preds = []
70 |     for head in unsafe_contents:
71 |         _preds.append(res[head])
72 |     _preds = np.array(_preds)
73 |     preds = np.sum(_preds, axis=0)
74 |     preds = np.int16(preds>0)
75 | 
76 |     final_result = {}
77 |     for i, item in enumerate(dataset):
78 |         final_result[item] = str(preds[i])
79 | 
80 |     json.dump(final_result, open(f"{output_dir}/predictions.json", "w"))
81 | 
82 | 
83 | 
84 | if __name__=="__main__":
85 | 
86 |     parser = argparse.ArgumentParser()
87 | 
88 |     parser.add_argument(
89 |         "--images_dir",
90 |         type=str,
91 |
nargs="?", 92 | default=None, 93 | help="images folder" 94 | ) 95 | parser.add_argument( 96 | "--output_dir", 97 | type=str, 98 | nargs="?", 99 | help="dir to write results to", 100 | default=None 101 | ) 102 | 103 | opt = parser.parse_args() 104 | 105 | main(opt) 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | albumentations==0.4.3 5 | altair==4.2.0 6 | antlr4-python3-runtime==4.8 7 | async-timeout==4.0.2 8 | attrs==22.1.0 9 | backports.zoneinfo==0.2.1 10 | blinker==1.5 11 | brotlipy==0.7.0 12 | cachetools==5.2.0 13 | certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi 14 | cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work 15 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 16 | click==8.1.3 17 | -e git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1#egg=clip 18 | coloredlogs==15.0.1 19 | commonmark==0.9.1 20 | contourpy==1.0.7 21 | cryptography @ file:///tmp/build/80754af9/cryptography_1652083738073/work 22 | cycler==0.11.0 23 | decorator==5.1.1 24 | diffusers==0.2.4 25 | einops==0.3.0 26 | entrypoints==0.4 27 | et-xmlfile==1.1.0 28 | filelock==3.8.0 29 | flatbuffers==2.0.7 30 | fonttools==4.39.4 31 | frozenlist==1.3.1 32 | fsspec==2022.8.2 33 | ftfy==6.1.1 34 | future==0.18.2 35 | gensim==4.3.1 36 | gitdb==4.0.9 37 | GitPython==3.1.27 38 | google-auth==2.11.0 39 | google-auth-oauthlib==0.4.6 40 | grpcio==1.47.0 41 | huggingface-hub==0.9.1 42 | humanfriendly==10.0 43 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work 44 | imageio==2.9.0 45 | imageio-ffmpeg==0.4.2 46 | imgaug==0.2.6 47 | importlib-metadata==4.12.0 48 | importlib-resources==5.9.0 49 | invisible-watermark==0.1.5 50 | Jinja2==3.1.2 51 | joblib==1.2.0 52 | jsonschema==4.15.0 53 | kiwisolver==1.4.4 54 | kornia==0.6.0 55 | Markdown==3.4.1 56 | MarkupSafe==2.1.1 57 | matplotlib==3.7.1 58 | mkl-fft==1.3.1 59 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work 60 | mkl-service==2.4.0 61 | mpmath==1.2.1 62 | multidict==6.0.2 63 | networkx==2.8.6 64 | nltk==3.8.1 65 | numpy==1.23.2 66 | oauthlib==3.2.0 67 | omegaconf==2.1.1 68 | onnx==1.12.0 69 | onnxruntime==1.12.1 70 | open-clip-torch==2.20.0 71 | opencv-python==4.1.2.30 72 | opencv-python-headless==4.6.0.66 73 | openpyxl==3.1.2 74 | packaging==21.3 75 | pandas==1.4.4 76 | Pillow==9.2.0 77 | pkgutil-resolve-name==1.3.10 78 | protobuf==3.19.4 79 | pudb==2019.2 80 | pyarrow==9.0.0 81 | pyasn1==0.4.8 82 | pyasn1-modules==0.2.8 83 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 84 | pydeck==0.8.0b1 85 | pyDeprecate==0.3.1 86 | Pygments==2.13.0 87 | Pympler==1.0.1 88 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work 89 | pyparsing==3.0.9 90 | pyrsistent==0.18.1 91 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 92 | python-dateutil==2.8.2 93 | pytorch-lightning==1.4.2 94 | pytz==2022.2.1 95 | pytz-deprecation-shim==0.1.0.post0 96 | PyWavelets==1.3.0 97 | PyYAML==6.0 98 | regex==2022.8.17 99 | requests @ file:///opt/conda/conda-bld/requests_1657734628632/work 100 | requests-oauthlib==1.3.1 101 | rich==12.5.1 102 | rsa==4.9 103 | safetensors==0.3.1 104 | scikit-image==0.19.3 105 | scikit-learn==1.2.2 106 | scipy==1.9.1 107 | seaborn==0.12.2 108 | semver==2.13.0 109 | sentencepiece==0.1.99 110 
| six @ file:///tmp/build/80754af9/six_1644875935023/work 111 | smart-open==6.3.0 112 | smmap==5.0.0 113 | streamlit==1.12.2 114 | sympy==1.11.1 115 | -e git+https://github.com/CompVis/taming-transformers.git@24268930bf1dce879235a7fddd0b2355b84d7ea6#egg=taming_transformers 116 | tensorboard==2.10.0 117 | tensorboard-data-server==0.6.1 118 | tensorboard-plugin-wit==1.8.1 119 | test-tube==0.7.5 120 | threadpoolctl==3.1.0 121 | tifffile==2022.8.12 122 | timm==0.9.2 123 | tokenizers==0.12.1 124 | toml==0.10.2 125 | toolz==0.12.0 126 | torch==1.11.0 127 | torch-fidelity==0.3.0 128 | torchmetrics==0.6.0 129 | torchvision==0.12.0 130 | tornado==6.2 131 | tqdm==4.64.0 132 | transformers==4.19.2 133 | typing-extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work 134 | tzdata==2022.2 135 | tzlocal==4.2 136 | urllib3 @ file:///tmp/abs_5dhwnz6atv/croots/recipe/urllib3_1659110457909/work 137 | urwid==2.1.2 138 | validators==0.20.0 139 | watchdog==2.1.9 140 | wcwidth==0.2.5 141 | Werkzeug==2.2.2 142 | wordcloud==1.9.2 143 | yarl==1.8.1 144 | zipp==3.8.1 145 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import PIL 5 | import argparse, os, sys, glob 6 | import json 7 | from pathlib import Path 8 | import torch.nn as nn 9 | from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler 10 | from torch.nn import functional as F 11 | from sklearn import metrics 12 | import random 13 | import open_clip 14 | import logging 15 | import config, tqdm 16 | from collections import Counter 17 | import shutil 18 | from sklearn import model_selection 19 | 20 | device = torch.device("cuda") 21 | SEED = 2022 22 | torch.manual_seed(SEED) 23 | random.seed(SEED) 24 | np.random.seed(SEED) 25 | 26 | encode_labels = {"normal":0, "sexual":1, "violent":2, "disturbing":3, "hateful":4, "political": 5} 27 | unsafe_contents = list(encode_labels.keys())[1:] 28 | 29 | class BinaryAnnotatedDataset(torch.utils.data.Dataset): 30 | def __init__(self, images_dir, labels_dir, split="train", head=None, train_test_split=0.4): 31 | 32 | labels_df = pd.read_excel(labels_dir) 33 | images, labels = [], [] 34 | for i in labels_df.index: 35 | images.append(f"{images_dir}/{i}.png") 36 | label = labels_df.loc[i, "final_label"] 37 | 38 | raw_labels = [] 39 | for annotator in ["rater_0", "rater_1", "rater_2"]: 40 | _label = labels_df.loc[i, annotator] 41 | _label = [int(l) for l in str(_label).split(",")] 42 | raw_labels.extend(_label) 43 | label_collection = Counter(raw_labels).most_common() 44 | label_collection_dict = {} 45 | for l, n in label_collection: 46 | label_collection_dict[l] = n 47 | if head: 48 | target_label = encode_labels[head] 49 | try: 50 | if int(label_collection_dict[target_label]) >= 2: 51 | label = 1 52 | except: 53 | label = 0 54 | 55 | labels.append(label) 56 | 57 | images_train, images_test, labels_train, labels_test = model_selection.train_test_split(images, labels, \ 58 | test_size=train_test_split, 59 | shuffle=True, 60 | random_state=1) 61 | if split == "train": 62 | self.images = images_train 63 | self.labels = labels_train 64 | elif split == "test": 65 | self.images = images_test 66 | self.labels = labels_test 67 | 68 | def __getitem__(self, idx): 69 | return self.images[idx], self.labels[idx] 70 | 71 | def __len__(self): 72 | return len(self.images) 73 | 74 | def weights(self): 75 
| count = Counter(self.labels) 76 | class_count = np.array([count[0], count[1]]) 77 | weight = 1.0/class_count 78 | weights = np.array([weight[0] if la==0 else weight[1] for la in self.labels]) 79 | return weights 80 | 81 | 82 | class MHSafetyClassifier(torch.nn.Module): 83 | def __init__(self, device, model_name, pretrained): 84 | super(MHSafetyClassifier, self).__init__() 85 | self.clip_model, self.preprocess, _ = open_clip.create_model_and_transforms(model_name, pretrained) 86 | self.clip_model.to(device) 87 | self.projection_head = nn.Sequential( 88 | nn.Linear(768, 384), 89 | nn.ReLU(), 90 | nn.Dropout(0.5), 91 | nn.BatchNorm1d(384), 92 | nn.Linear(384, 1) 93 | ).to(device) 94 | 95 | def freeze(self): 96 | self.clip_model = self.clip_model.eval() 97 | for param in self.clip_model.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, x): 101 | x = self.clip_model.encode_image(x).type(torch.float32) 102 | x = self.projection_head(x) 103 | out = nn.Sigmoid()(x) 104 | return out 105 | 106 | def train(opt, record=True): 107 | 108 | EPOCH = config.epoch 109 | LR = config.learning_rate 110 | BATCH_SIZE = config.batch_size 111 | 112 | model_name, pretrained = config.model_name, config.pretrained 113 | output_dir = opt.output_dir 114 | Path(output_dir).mkdir(exist_ok=True, parents=True) 115 | 116 | for head in unsafe_contents: 117 | if record: 118 | logging.getLogger('').handlers = [] 119 | logging.basicConfig(level=logging.INFO, filename=f"{output_dir}/{head}.log") 120 | 121 | 122 | trainset = BinaryAnnotatedDataset(images_dir=opt.images_dir, labels_dir=opt.labels_dir, split="train", head=head) 123 | sampler = WeightedRandomSampler(trainset.weights(), num_samples=trainset.weights().shape[0], replacement=True) 124 | testset = BinaryAnnotatedDataset(images_dir=opt.images_dir, labels_dir=opt.labels_dir, split="test", head=head) 125 | 126 | train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, drop_last=True, sampler=sampler) 127 | test_loader = DataLoader(testset, batch_size=20, drop_last=False) 128 | 129 | model = MHSafetyClassifier(device, model_name, pretrained) 130 | model.freeze() # freeze the backbone 131 | criterion = nn.BCELoss() 132 | optimizer = torch.optim.Adam(model.projection_head.parameters(), lr=LR) 133 | 134 | best_score = 0.0 135 | for epoch in range(EPOCH): 136 | model.projection_head.train() 137 | total_loss = 0.0 138 | ground_truth, prediction = [], [] 139 | for idx, (imgs, labels) in enumerate(train_loader): 140 | labels = labels.to(device) 141 | labels = labels.type(torch.float32) 142 | images = [model.preprocess(PIL.Image.open(img_path)) for img_path in imgs] 143 | images = torch.stack(images).to(device) # [b_s, 3, 224, 224] 144 | logits = model(images).squeeze() 145 | loss = criterion(logits, labels) 146 | optimizer.zero_grad() 147 | loss.backward() 148 | optimizer.step() 149 | preds = (logits.detach().cpu()>0.5).to(dtype=torch.int64) 150 | ground_truth.append(labels.detach().cpu()) 151 | prediction.append(preds) 152 | # print(loss) 153 | avg_loss = total_loss/(idx+1) 154 | ground_truth = torch.hstack(ground_truth) 155 | prediction = torch.hstack(prediction) 156 | accuracy = metrics.accuracy_score(ground_truth.numpy(), prediction.numpy()) 157 | if record: 158 | logging.info(f"[epoch] {epoch} [train accuracy] {accuracy} [loss] {loss}") 159 | 160 | #================= eval ================= 161 | test_ground_truth, test_prediction = [], [] 162 | model.projection_head.eval() 163 | for idx, (imgs, labels) in enumerate(test_loader): 164 | labels = 
labels.to(device)
165 |                 images = [model.preprocess(PIL.Image.open(img_path)) for img_path in imgs]
166 |                 images = torch.stack(images).to(device) # [b_s, 3, 224, 224]
167 |                 logits = model(images).squeeze()
168 |                 preds = (logits.detach().cpu()>0.5).to(dtype=torch.int64)
169 |                 test_ground_truth.append(labels.detach().cpu())
170 |                 test_prediction.append(preds)
171 | 
172 |             test_ground_truth = torch.hstack(test_ground_truth)
173 |             test_prediction = torch.hstack(test_prediction)
174 |             accuracy = metrics.accuracy_score(test_ground_truth.numpy(), test_prediction.numpy())
175 |             precision = metrics.precision_score(test_ground_truth.numpy(), test_prediction.numpy())
176 |             recall = metrics.recall_score(test_ground_truth.numpy(), test_prediction.numpy())
177 |             f1_score = metrics.f1_score(test_ground_truth.numpy(), test_prediction.numpy())
178 | 
179 |             print(f"{head} test-performance: [accuracy] {accuracy} [precision] {precision} [recall] {recall} [f1_score] {f1_score} \n")
180 | 
181 |             if accuracy > best_score:
182 |                 best_score = accuracy
183 |                 torch.save(model.projection_head.state_dict(), f"{output_dir}/{head}.pt")
184 |                 print(f"best model - {head} {epoch} {best_score}")
185 | 
186 | if __name__=="__main__":
187 | 
188 |     parser = argparse.ArgumentParser()
189 | 
190 |     parser.add_argument(
191 |         "--images_dir",
192 |         type=str,
193 |         nargs="?",
194 |         default=None,
195 |         help="images folder"
196 |     )
197 |     parser.add_argument(
198 |         "--labels_dir",
199 |         type=str,
200 |         nargs="?",
201 |         default=None,
202 |         help="path to the annotation spreadsheet (labels.xlsx)"
203 |     )
204 |     parser.add_argument(
205 |         "--output_dir",
206 |         type=str,
207 |         nargs="?",
208 |         help="directory to write checkpoints and logs to",
209 |         default=None
210 |     )
211 | 
212 |     opt = parser.parse_args()
213 | 
214 |     train(opt, record=True)
--------------------------------------------------------------------------------
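For programmatic use outside `inference.py`, the snippet below sketches how one might query the trained multi-headed classifier on a single image. It is a minimal sketch, not part of the original repository: the `check_image` helper is hypothetical, the checkpoint directory and image path are illustrative, and a CUDA device is assumed (as in the scripts above).

```python
# Minimal sketch (not part of the original repo): score one image with the
# multi-headed safety classifier. Paths and the CUDA assumption are illustrative.
import torch
import PIL.Image

import config
from train import MHSafetyClassifier, unsafe_contents

device = torch.device("cuda")

model = MHSafetyClassifier(device, config.model_name, config.pretrained)
model.freeze()  # freeze the CLIP backbone; only the projection head is swapped per category

def check_image(image_path, checkpoints_dir="./checkpoints/multi-headed"):
    """Return 1 if any unsafe head fires (sigmoid output > 0.5), else 0."""
    image = model.preprocess(PIL.Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        for head in unsafe_contents:
            state = torch.load(f"{checkpoints_dir}/{head}.pt", map_location=device)
            model.projection_head.load_state_dict(state)
            model.projection_head.eval()
            if model(image).item() > 0.5:
                return 1
    return 0

print(check_image("./data/images/0.png"))  # illustrative path
```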