├── README.md ├── install.py └── scripts └── censor.py /README.md: -------------------------------------------------------------------------------- 1 | A NSFW checker for [Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). Replaces non-worksafe images with black squares. Install it from UI. 2 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | if not launch.is_installed("diffusers"): 4 | launch.run_pip(f"install diffusers", "diffusers") 5 | -------------------------------------------------------------------------------- /scripts/censor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 3 | from transformers import AutoFeatureExtractor 4 | from PIL import Image 5 | 6 | from modules import scripts, shared 7 | 8 | safety_model_id = "CompVis/stable-diffusion-safety-checker" 9 | safety_feature_extractor = None 10 | safety_checker = None 11 | 12 | 13 | def numpy_to_pil(images): 14 | """ 15 | Convert a numpy image or a batch of images to a PIL image. 16 | """ 17 | if images.ndim == 3: 18 | images = images[None, ...] 19 | images = (images * 255).round().astype("uint8") 20 | pil_images = [Image.fromarray(image) for image in images] 21 | 22 | return pil_images 23 | 24 | 25 | # check and replace nsfw content 26 | def check_safety(x_image): 27 | global safety_feature_extractor, safety_checker 28 | 29 | if safety_feature_extractor is None: 30 | safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) 31 | safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) 32 | 33 | safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") 34 | x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) 35 | 36 | return x_checked_image, has_nsfw_concept 37 | 38 | 39 | def censor_batch(x): 40 | x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy() 41 | x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy) 42 | x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) 43 | 44 | return x 45 | 46 | 47 | class NsfwCheckScript(scripts.Script): 48 | def title(self): 49 | return "NSFW check" 50 | 51 | def show(self, is_img2img): 52 | return scripts.AlwaysVisible 53 | 54 | def postprocess_batch(self, p, *args, **kwargs): 55 | images = kwargs['images'] 56 | images[:] = censor_batch(images)[:] 57 | --------------------------------------------------------------------------------