├── README.md
├── laion-webdataset-pt.ipynb
├── laionsafety.py
└── laion-webdataset-tf.ipynb


/README.md:
--------------------------------------------------------------------------------
# LAION-SAFETY
An open toolbox for NSFW & toxicity detection

# Overview
We present an NSFW image-text-pair classification ensemble, which consists of an image classifier (based on EfficientNet V2, B2 260x260, https://github.com/google/automl/tree/master/efficientnetv2) combined with Detoxify (https://github.com/unitaryai/detoxify), an existing language model for toxicity detection.

The image classifier was trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134), "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).

To evaluate the performance of the image classifier with & without the additional information from Detoxify, we created a manually inspected test set of 4900 samples, each consisting of an image & its caption.

![image](https://cdn.discordapp.com/attachments/893170386030694460/908071613520560160/unknown.png)

To use our 5-class image classifier as a binary SFW - NSFW classifier, we consider images from the classes "Drawing" & "Neutral" as SFW and "Hentai", "Porn" & "Sexy" as NSFW.

--> Our image classifier correctly predicts 96.45% of the true NSFW images as NSFW and incorrectly discards 7.96% of the SFW images as NSFW.
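
The binary SFW/NSFW mapping and the two error rates reported in this overview can be sketched as follows. This is a minimal illustration, not the evaluation code we used: the class order is an assumption taken from the index usage in `laionsafety.py`, the helper names `is_nsfw` and `error_rates` are hypothetical, and the labels are toy values rather than the 4900-sample test set.

```python
# Assumed class order (index -> label), matching the indices used in laionsafety.py.
CLASS_NAMES = ["drawing", "hentai", "neutral", "porn", "sexy"]
NSFW_CLASSES = {"hentai", "porn", "sexy"}


def is_nsfw(class_index):
    """Collapse a 5-class prediction into the binary SFW/NSFW decision."""
    return CLASS_NAMES[class_index] in NSFW_CLASSES


def error_rates(true_indices, predicted_indices):
    """Return (false_negative_rate, false_positive_rate) for the binary task.

    False negatives are NSFW samples predicted as SFW, relative to all NSFW
    samples; false positives are SFW samples predicted as NSFW, relative to
    all SFW samples. Both groups must be non-empty.
    """
    pairs = [(is_nsfw(t), is_nsfw(p)) for t, p in zip(true_indices, predicted_indices)]
    n_nsfw = sum(1 for truth, _ in pairs if truth)
    n_sfw = len(pairs) - n_nsfw
    false_negatives = sum(1 for truth, pred in pairs if truth and not pred)
    false_positives = sum(1 for truth, pred in pairs if not truth and pred)
    return false_negatives / n_nsfw, false_positives / n_sfw


# Toy labels for illustration only.
true_labels = [3, 3, 1, 4, 2, 2, 0, 2]
pred_labels = [3, 2, 1, 4, 2, 3, 0, 2]
fn_rate, fp_rate = error_rates(true_labels, pred_labels)
print(f"False negatives: {fn_rate:.2%}")
print(f"False positives: {fp_rate:.2%}")
```

Note that a confusion between two NSFW classes (for example "Porn" predicted as "Sexy") does not count as an error in this binary view.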


False negatives: 3.55%

False positives: 7.96%


We compare our model with the NSFW classifier by the GitHub user GantMan (https://github.com/GantMan/nsfw_model, Inception V3, Keras 299x299), to our knowledge the best openly available NSFW classifier at the time of writing:

![image](https://cdn.discordapp.com/attachments/893170386030694460/905489671654613102/unknown.png)


False negatives: 5.90%

False positives: 7.52%

--> Our image classifier produces about 2 percentage points fewer false negatives, at the cost of classifying about 0.5 percentage points more SFW pictures as NSFW.
Because reducing the percentage of false negatives is more important in most contexts, the slightly increased percentage of false positives should be acceptable in most use cases.


To leverage the information from the image captions, we add the sum of Detoxify's "toxicity" & "sexual_explicit" scores to the image classifier's softmax scores for the NSFW classes ("Hentai", "Porn" & "Sexy") before determining the category with the highest score.

This ensemble achieves the following performance:

![image](https://cdn.discordapp.com/attachments/893170386030694460/908072103465599026/unknown.png)


False negatives: 2.22%

False positives: 5.33%

--> This ensemble produces about 1.3 percentage points fewer false negatives & 2.6 percentage points fewer false positives than our image classifier alone.

# Preparations

You can download our EfficientNet V2 based NSFW image classifier here:
https://drive.google.com/file/d/1NkDsWtjyMrak5dnw8JYJBwovifIB89uS/view?usp=sharing

It needs to be in the same directory as the inference script.
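
A small pre-flight check can catch a misplaced model file before a long inference run starts. The sketch below is only an illustration: `locate_model` and `load_classifier` are hypothetical helper names, and the file name is the one `laionsafety.py` passes to `load_model`; whether the download arrives under the same name is an assumption, so rename the file if it differs.

```python
from pathlib import Path

# File name that laionsafety.py loads; assumed to match the downloaded file.
MODEL_FILENAME = "nsfweffnetv2-b02-3epochs.h5"


def locate_model(script_dir="."):
    """Return the path of the classifier file next to the inference script."""
    model_path = Path(script_dir) / MODEL_FILENAME
    if not model_path.is_file():
        raise FileNotFoundError(
            f"{MODEL_FILENAME} not found in {Path(script_dir).resolve()}; "
            "download it and place it next to the inference script."
        )
    return model_path


def load_classifier(script_dir="."):
    """Load the Keras model with the same call that laionsafety.py uses."""
    # Imported lazily so that locate_model() also works without TensorFlow installed.
    import tensorflow as tf
    import tensorflow_hub as hub

    return tf.keras.models.load_model(
        str(locate_model(script_dir)),
        custom_objects={"KerasLayer": hub.KerasLayer},
    )
```

Calling `locate_model()` at the top of a script fails fast with a readable message instead of a loader error several minutes into a run.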

Install Detoxify with:
pip install detoxify


# Inference



# Training



# Disclaimer
Even though this is obvious, we explicitly state here that the predictions made by our image classifier & its ensemble with Detoxify are not 100% correct & that everyone who applies them takes full responsibility for this application.

--------------------------------------------------------------------------------
/laion-webdataset-pt.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fb566193-5405-4dcf-9cef-ae30c5919a4e",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "* https://github.com/webdataset/webdataset-tensorflow/blob/main/resnet-multi.py\n",
    "* https://github.com/LAION-AI/LAION-SAFETY/blob/main/laionsafety.py\n",
    "\n",
    "## Machine used\n",
    "\n",
    "00071-gpu machine on Spell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f0fdde1b-12c7-4ae4-a5e4-f30a180c4152",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pprint\n",
    "import webdataset as wds\n",
    "from webdataset import multi\n",
    "import typer\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "444a2022-4060-4dbb-97a8-e3f49efd0223",
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH = \"../dataset/{00000..00046}.tar\"\n",
    "BATCH_SIZE = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "51dd32f9-09b7-444c-bc99-4cfe8e19589b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_dataset(item):  # Skip samples without a caption or image, e.g. from C@H, which (rarely) has no caption available.\n",
    "    if \"txt\" not in item:\n",
    "        return False\n",
    "    if \"jpg\" not in item:\n",
    "        return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b1deea1a-6a7e-4942-b62e-47821002fa93",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = wds.WebDataset(PATH, handler=wds.ignore_and_continue).select(filter_dataset).decode('rgb').to_tuple('jpg', 'txt')\n",
    "dataloader = wds.WebLoader(dataset, shuffle=False, num_workers=os.cpu_count(), batch_size=BATCH_SIZE, prefetch_factor=4*BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6dc038ec-ac80-41cf-9c4f-da2a82a98f6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([512, 256, 256, 3])\n",
      "512\n"
     ]
    }
   ],
   "source": [
    "for image_batch, caption_batch in dataloader:\n",
    "    print(image_batch.shape)\n",
    "    print(len(caption_batch))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c78827b6-2450-431c-afa0-53be81c07cf0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.22 s, sys: 498 ms, total: 1.72 s\n",
      "Wall time: 5.58 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Benchmarking for 10 batches.\n",
    "for i, (batch, label) in enumerate(dataloader):\n",
    "    if i == 10:\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-7.m87",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

--------------------------------------------------------------------------------
/laionsafety.py:
--------------------------------------------------------------------------------
import os
import time
from multiprocessing import Manager, Process

import imageio  # only needed by the commented-out imwrite calls below
import numpy as np
import tensorflow as tf
import webdataset as wds
from detoxify import Detoxify

image_size = 260  # input resolution of the image classifier

batchsize = 1024  # batch size for inference; lower it if you get OOM errors

datadir = "./laion400m-dat-release/"  # directory where the tar files are located
SHARDS = "{00000..00002}.tar"  # brace pattern of the tar files

targetdir1 = "./drawings/"
targetdir2 = "./hentai/"
targetdir3 = "./neutral/"
targetdir4 = "./porn/"
targetdir5 = "./sexy/"

# Create the output directories. exist_ok avoids aborting (and skipping the
# remaining directories) when one of them already exists.
for targetdir in (targetdir1, targetdir2, targetdir3, targetdir4, targetdir5):
    os.makedirs(targetdir, exist_ok=True)

# Class order of the classifier output (index -> label), as used further below.
CLASS_NAMES = ["drawings", "hentai", "neutral", "porn", "sexy"]


def get_class_string_from_index(index):
    """Return the class label that belongs to a classifier output index."""
    return CLASS_NAMES[index]


def filter_dataset(item):
    # Skip samples without a caption or image, e.g. from C@H, which (rarely)
    # has no caption available.
    if 'txt' not in item:
        return False
    if 'jpg' not in item:
        return False
    return True


# Pack image inference into its own process, to make sure all GPU memory is
# freed afterwards for the Detoxify inference.
def image_classifier(caption_list, prediction_list, datadir):
    import tensorflow as tf
    import tensorflow_hub as hub

    ds = (
        wds.WebDataset(datadir + SHARDS, handler=wds.ignore_and_continue)
        .select(filter_dataset)
        .decode('rgb')
        .to_tuple('jpg', 'txt')
    )

    # pin_memory=True is another option that can be passed here.
    dl = wds.WebLoader(ds, shuffle=False, num_workers=16, batch_size=batchsize, prefetch_factor=4 * batchsize)

    model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5', custom_objects={"KerasLayer": hub.KerasLayer})
    os.system("nvidia-smi")
    # Show the model architecture
    # model.summary()

    c = 0

    print("starting loader")
    for im_arr, txt in dl:
        start = time.time()
        c += 1
        im_arr = tf.image.resize(im_arr, [image_size, image_size], antialias=True)
        prediction_scores = model.predict(im_arr)
        prediction_list.append(prediction_scores)
        # Captions are cut off after 200 characters, to avoid OOM errors.
        captions = [e[:200] for e in txt]
        caption_list.append(captions)
        print(c)
        print("image prediction time")
        print(time.time() - start)
    del model
    tf.keras.backend.clear_session()


start = time.time()

n_drawings = 0
n_hentai = 0
n_neutral = 0
n_porn = 0
n_sexy = 0
manager = Manager()
prediction_list = manager.list()
caption_list = manager.list()
p = []
p.append(Process(target=image_classifier, args=(caption_list, prediction_list, datadir)))
p[0].start()
p[0].join()


model_txt = Detoxify('multilingual', device='cuda')
os.system("nvidia-smi")

for i in range(len(caption_list)):
    # One Detoxify call per batch of captions.
    text_res = model_txt.predict(caption_list[i])

    for j in range(len(caption_list[i])):
        dist = np.array(tf.nn.softmax(prediction_list[i][j]))
        # Add the caption scores to the NSFW classes: 1 (hentai), 3 (porn), 4 (sexy).
        caption_score = text_res["sexual_explicit"][j] + text_res["toxicity"][j]
        dist[1] = dist[1] + caption_score
        dist[3] = dist[3] + caption_score
        dist[4] = dist[4] + caption_score

        predicted_index = np.argmax(dist)
        if predicted_index == 0:
            # imageio.imwrite(targetdir1+str(n_drawings+100000000)+".jpg", im_arr[j])
            n_drawings += 1
        if predicted_index == 1:
            # imageio.imwrite(targetdir2+str(n_hentai+100000000)+".jpg", im_arr[j])
            n_hentai += 1
        if predicted_index == 2:
            # imageio.imwrite(targetdir3+str(n_neutral+100000000)+".jpg", im_arr[j])
            n_neutral += 1
        if predicted_index == 3:
            # imageio.imwrite(targetdir4+str(n_porn+100000000)+".jpg", im_arr[j])
            n_porn += 1
        if predicted_index == 4:
            # imageio.imwrite(targetdir5+str(n_sexy+100000000)+".jpg", im_arr[j])
            n_sexy += 1
    print(i)

print("n_drawings: " + str(n_drawings))
print("n_hentai: " + str(n_hentai))
print("n_neutral: " + str(n_neutral))
print("n_porn: " + str(n_porn))
print("n_sexy: " + str(n_sexy))
print(time.time() - start)

--------------------------------------------------------------------------------
/laion-webdataset-tf.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fb566193-5405-4dcf-9cef-ae30c5919a4e",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "* https://github.com/webdataset/webdataset-tensorflow/blob/main/resnet-multi.py\n",
    "* https://github.com/LAION-AI/LAION-SAFETY/blob/main/laionsafety.py\n",
    "\n",
    "## Machine used\n",
    "\n",
    "GCP n1-highmem (16 vCPUs, 104 GB RAM) with Debian 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f0fdde1b-12c7-4ae4-a5e4-f30a180c4152",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pprint\n",
    "import webdataset as wds\n",
    "from webdataset import multi\n",
    "import typer\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "444a2022-4060-4dbb-97a8-e3f49efd0223",
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://3080.rom1504.fr/cah/laion400m_porn_data/{00000..00046}.tar\"\n",
    "url = f\"pipe:curl -L -s {url} || true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id":
"51dd32f9-09b7-444c-bc99-4cfe8e19589b", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available.\n", 54 | " if \"txt\" not in item:\n", 55 | " return False\n", 56 | " if \"jpg\" not in item:\n", 57 | " return False\n", 58 | " return True" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "id": "b1deea1a-6a7e-4942-b62e-47821002fa93", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "class ImagenetData:\n", 69 | " \"\"\"This class is a convenient placeholder for the dataset-related information.\n", 70 | " You could also just define these iterator etc. as global functions.\"\"\"\n", 71 | "\n", 72 | " def __init__(self, url=url):\n", 73 | " self.url = url\n", 74 | " self.dataset = (\n", 75 | " wds.WebDataset(self.url, shardshuffle=True, handler=wds.ignore_and_continue)\n", 76 | " .select(filter_dataset)\n", 77 | " .decode(\"rgb\")\n", 78 | " .to_tuple(\"jpg\", \"txt\")\n", 79 | " )\n", 80 | " self.loader = multi.MultiLoader(self.dataset, workers=os.cpu_count())\n", 81 | "\n", 82 | " def __iter__(self):\n", 83 | " for img, hot in self.loader:\n", 84 | " yield img.astype(\"float32\"), np.array(hot).astype(str)\n", 85 | "\n", 86 | " def output_shapes(self):\n", 87 | " return ((256, 256, 3), ())\n", 88 | "\n", 89 | " def output_types(self):\n", 90 | " return (tf.float32, tf.string)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 6, 96 | "id": "1a201204-c319-422f-8d6c-a658c3bcd350", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def resize_images(image_batch, caption_batch):\n", 101 | " return tf.image.resize(image_batch, (260, 260), antialias=True), caption_batch" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 7, 107 | "id": "4011b920-604e-4e23-b163-7ee8a3e2dbc6", 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stderr", 112 | "output_type": "stream", 113 
| "text": [ 114 | "2022-01-08 06:14:41.848510: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "options = tf.data.Options()\n", 120 | "options.experimental_distribute.auto_shard_policy = (\n", 121 | " tf.data.experimental.AutoShardPolicy.DATA\n", 122 | ")\n", 123 | "options.experimental_optimization.noop_elimination = True\n", 124 | "options.experimental_optimization.apply_default_optimizations = True\n", 125 | "options.experimental_optimization.filter_fusion = True\n", 126 | "\n", 127 | "df = ImagenetData()\n", 128 | "tdf = tf.data.Dataset.from_generator(\n", 129 | " generator=df.__iter__,\n", 130 | " output_types=df.output_types(),\n", 131 | " output_shapes=df.output_shapes(),\n", 132 | ")\n", 133 | "tdf = tdf.with_options(options)\n", 134 | "tdf = tdf.batch(512).map(resize_images, num_parallel_calls=tf.data.AUTOTUNE)\n", 135 | "tdf = tdf.prefetch(tf.data.AUTOTUNE)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "id": "6dc038ec-ac80-41cf-9c4f-da2a82a98f6c", 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "(512, 260, 260, 3)\n", 149 | "(512,)\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "for image_batch, caption_batch in tdf.take(1):\n", 155 | " print(image_batch.shape)\n", 156 | " print(caption_batch.shape)\n", 157 | " break" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 9, 163 | "id": "c78827b6-2450-431c-afa0-53be81c07cf0", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "killing \n", 171 | "killing \n", 172 | "killing \n", 173 | "killing \n", 174 | "killing \n", 175 | "killing \n", 176 | "killing \n", 177 | "killing \n", 178 | "killing \n", 179 | "killing 
\n", 180 | "killing \n", 181 | "killing \n", 182 | "killing \n", 183 | "killing \n", 184 | "killing \n", 185 | "killing \n", 186 | "closing \n", 187 | ".\n", 188 | "CPU times: user 25.1 s, sys: 10.7 s, total: 35.8 s\n", 189 | "Wall time: 13.6 s\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "%%time\n", 195 | "\n", 196 | "# Benchmaring for 10 batches.\n", 197 | "for i, (batch, label) in enumerate(tdf.take(10)):\n", 198 | " if i % 40 == 0:\n", 199 | " print(\".\", end=\"\")\n", 200 | "print()" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "environment": { 206 | "kernel": "python3", 207 | "name": "tf2-gpu.2-7.m87", 208 | "type": "gcloud", 209 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87" 210 | }, 211 | "kernelspec": { 212 | "display_name": "Python 3 (ipykernel)", 213 | "language": "python", 214 | "name": "python3" 215 | }, 216 | "language_info": { 217 | "codemirror_mode": { 218 | "name": "ipython", 219 | "version": 3 220 | }, 221 | "file_extension": ".py", 222 | "mimetype": "text/x-python", 223 | "name": "python", 224 | "nbconvert_exporter": "python", 225 | "pygments_lexer": "ipython3", 226 | "version": "3.8.2" 227 | } 228 | }, 229 | "nbformat": 4, 230 | "nbformat_minor": 5 231 | } 232 | --------------------------------------------------------------------------------