├── .gitattributes ├── .gitignore ├── README.md ├── characters ├── nsfw.txt └── sfw.txt ├── clipcrop.py ├── configs └── upscaler.yaml ├── dbimutils.py ├── file_manager.py ├── install.py ├── interrogators ├── blip_interrogator.py ├── booru_interrogator.py ├── data │ ├── artists.txt │ ├── flavors.txt │ ├── mediums.txt │ └── movements.txt ├── idefics2_interrogator.py ├── interrogator.py ├── llava2_interrogator.py ├── moondream_interrogator.py └── wolf_interrogator.py ├── javascript └── smart_process.js ├── model_download.py ├── mplug_owl2 ├── __init__.py ├── constants.py ├── conversation.py ├── evaluate │ ├── EVALUATION.md │ ├── __init__.py │ ├── evaluate_caption.py │ ├── evaluate_mmbench.py │ ├── evaluate_mme.py │ ├── evaluate_mmmu.py │ ├── evaluate_vqa.py │ ├── mmbench_converter.py │ ├── vqa.py │ └── vqa_eval.py ├── local_serve │ ├── __init__.py │ ├── examples │ │ ├── Rebecca_(1939_poster)_Small.jpeg │ │ └── extreme_ironing.jpg │ ├── local_web_server.py │ └── model_worker.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── configuration_mplug_owl2.py │ ├── configuration_qwen.py │ ├── convert_mplug_owl2_weight_to_hf.py │ ├── modeling_attn_mask_utils.py │ ├── modeling_llama2.py │ ├── modeling_mplug_owl2.py │ ├── modeling_qwen.py │ ├── multiway.py │ ├── utils.py │ └── visual_encoder.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── Rebecca_(1939_poster)_Small.jpeg │ │ └── extreme_ironing.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ └── register_workers.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── mplug_owl2_trainer.py │ ├── train.py │ └── train_mem.py └── utils.py ├── process_params.py ├── processors.py ├── requirements.txt ├── scripts └── process_main.py ├── smartprocess.py ├── style.css ├── super_resolution.py └── upscalers └── spandrel ├── spandrel_srformer_model.py └── spandrel_upscaler_base.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion WebUI Smart Pre-Processing Extension 2 | 3 | ## What is this?? 4 | 5 | As the name would imply, this is an extension for 6 | the [Stable-Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) by @Automatic1111 7 | 8 | ## What does this do? 9 | 10 | It does a few things, actually. 11 | 12 | For starters, it utilizes a combination of BLIP/CLIP and YOLOv5 to provide "smart cropping" for images. 
The primary 13 | subject of each image is identified, the center of that subject is determined, and then the application tries its best 14 | to crop the image so as to keep as much of the subject as possible within the dimensions specified. 15 | 16 | Second, it allows storing the determined image caption directly to the image filename, versus having to create a txt 17 | file alongside every image. You can still create a txt file, use existing captions, or not do any captioning at all. 18 | 19 | Third, I've provided face restoration and upscaling options for input images. You can select from GFPGAN and CodeFormer 20 | for face restoration, and any of the provided upscalers from the "Extras" tab to refine/smooth/add detail to your final 21 | output images. 22 | 23 | Last, but not least, it offers a rudimentary way to swap the "class" of a captioned image with the specific keyword in 24 | the image. So, if you're trying to train a subject called "xyz" and "xyz" is a dog, you can easily swap "dog" (and "a 25 | dog") with "xyz" in your captions. Neato! 26 | 27 | ## Smart Cropping 28 | 29 | As I said above, smart cropping utilizes a combination of YOLOv5 object recognition and BLIP/CLIP (and DeepDanbooru) 30 | captioning to automatically determine the most prominent subject in a photo, and automatically crop the subject as 31 | completely as possible. You can also specify a specific subject (dog/cat/woman/house) for the software to find, and skip 32 | the YOLOv5 detection entirely. 33 | 34 | 35 | 36 | If a subject is not found, the image will be downscaled and cropped from the center. 37 | 38 | ## Smart Captioning 39 | 40 | Captioning uses the same features configured in your user preferences, with the additional options to save to a txt file or append to the 41 | image file name. 42 | 43 | Additionally, you can swap the generic "class" of the image with a specific subject keyword. This feature may not be 44 | perfect in all cases, but it should still go a long way in speeding up the captioning process. 45 | 46 | You can also specify a maximum caption length, which will split the caption by spaces and append words until the maximum 47 | length is reached. 48 | 49 | ## Post Processing 50 | 51 | It's basically a simplified version of the "Extras" tab. The idea is that you can do facial restoration and/or use a 52 | model like SwinIR or LDSR to smooth or add details to an image. If an image is "actually" upscaled beyond the target 53 | crop size, it will be downscaled again back to the original size. 54 | 55 | -------------------------------------------------------------------------------- /characters/nsfw.txt: -------------------------------------------------------------------------------- 1 | Human: 2 | Human: Give a short one sentence caption for this image with NO punctuation. DO NOT USE ANY PUNCTUATION OR COMMAS. DO NOT USE COMMAS!!! 3 | AI: -------------------------------------------------------------------------------- /characters/sfw.txt: -------------------------------------------------------------------------------- 1 | Revised and SFW prompt: 2 | 3 | Act like a Stable Diffusion (SD) prompt generator; your focus is to convey warm and playful scenarios in vivid detail and include lush descriptions. The quality of your prompts must be only the best: creative and varied. Your subjects must be adults only. When you see "young" in a prompt, assume the subject is in their early 20s. The SD prompts must be formatted in a way the AI can understand.
4 | 5 | Stable Diffusion prompts are made of components which are comprised of keywords separated by commas, keywords can be single words or multi-word keywords, and they have a specific order. A typical format for the components looks like this: [Adjectives], [Type], [Framing], [Shot], [subject], [Descriptor], [Descriptor], [Descriptor], [Expression], [Pose], [Action], [Environment], [Details], [Lighting], [Medium], [Aesthetics], [Visual], [Clothing]. 6 | 7 | Here are some keywords commonly used for each of the components, but always mix them with new ones that are coherent to each component: 8 | 9 | Adjectives: Masterpiece, award-winning, Most beautiful, Epic, Majestic, best quality, fantastic, highly detailed, ultra-detailed, 8K, realistic 10 | Type: photo, selfie, 8k photo, professional photo 11 | Framing: Dutch angle, Wide Angle, low angle, high angle, perspective, on eye level, POV, close-up 12 | Shot: mid shot, full shot, portrait, from behind, long shot, cowboy shot, from above, from below 13 | Subject: woman, man 14 | Descriptor: elegant, well-dressed, casual attire, formal attire, stylish, modern, vintage, sporty, artistic, natural, sophisticated, trendy, chic 15 | Expression: satisfied, happy, excited, surprised, mouth open, eyes closed, sleeping, smiling, frowning, crying 16 | Pose: seated, standing, laying down, hands raised, walking, jumping, dancing, running 17 | Action: laughing, chatting, cooking, reading, painting, dancing, sleeping, crying, sitting, standing, lying down 18 | Environment: park, cafe, library, office, beach, mall, rooftop, forest, kitchen, bedroom, outer space, mountains, cityscape 19 | Details: Cloudless sky, glittering night, sparkling rain, shining lights, obscure darkness, red lights, neon lighting 20 | Lighting: light, dim light, two-tone lighting, dynamic lighting, rim light, studio light, diffuse lighting, dark lighting, soft lighting, fantasy lighting, natural light 21 | Medium: photograph, painting, illustration, sketch, render 22 | Aesthetics: vintage, retro, modern, rustic, minimalist, bohemian, industrial, romantic, fantasy, sci-fi, noir, gothic, ethereal 23 | Visual: contrast, cyan hue, Kodachrome, Fujifilm Superia, warm colours, saturation, vibrance, filters coolness, chromatic aberration, cinematic, RAW color 24 | Clothing: casual, formal, vintage, modern, sporty, chic, elegant, streetwear, boho, classic, punk, gothic, trendy 25 | 26 | Use the components to build coherent prompts. Use these keywords but also create your own. Generate variations of the keywords that are coherent to each component and fit the instruction. Emphasize the subject, ensure cohesiveness, and provide a concise description for each prompt. Only reply with full single prompts separated by a line break, do not add a numbered list, quotes, or a section breakdown. 27 | 28 | For example: 29 | 30 | [a photo of a beautiful woman standing outside at night, masterpiece, professional photo, Dutch angle, full shot, woman, elegant, stylish, happy, standing, laughing, park, glittering night, light, photograph, vintage, contrast, casual] 31 | 32 | Human: 33 | Human: Caption this image. 
34 | AI: -------------------------------------------------------------------------------- /clipcrop.py: -------------------------------------------------------------------------------- 1 | # Original project: https://github.com/Vishnunkumar/clipcrop/blob/main/clipcrop/clipcrop.py 2 | import os.path 3 | import sys 4 | 5 | import cv2 6 | import numpy 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from clip import clip 11 | 12 | import modules.paths 13 | from modules import shared, modelloader 14 | 15 | 16 | def clip_boxes(boxes, shape): 17 | # Clip boxes (xyxy) to image shape (height, width) 18 | if isinstance(boxes, torch.Tensor): # faster individually 19 | boxes[:, 0].clamp_(0, shape[1]) # x1 20 | boxes[:, 1].clamp_(0, shape[0]) # y1 21 | boxes[:, 2].clamp_(0, shape[1]) # x2 22 | boxes[:, 3].clamp_(0, shape[0]) # y2 23 | else: # np.array (faster grouped) 24 | boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 25 | boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 26 | 27 | 28 | def find_position(parent: Image, child: Image): 29 | w = child.width 30 | h = child.height 31 | res = cv2.matchTemplate(np.array(parent), np.array(child), cv2.TM_CCOEFF_NORMED) 32 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) 33 | # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take minimum 34 | top_left = max_loc 35 | center_x = top_left[0] + (w / 2) 36 | center_y = top_left[1] + (h / 2) 37 | return center_x, center_y 38 | 39 | 40 | class CropClip: 41 | def __init__(self): 42 | # Model 43 | model_name = 'yolov5m6v7.pt' 44 | model_url = 'https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m6.pt' 45 | model_dir = os.path.join(modules.paths.models_path, "yolo") 46 | model_path = modelloader.load_models(model_dir, model_url, None, '.pt', model_name) 47 | was_safe_unpickle = shared.cmd_opts.disable_safe_unpickle 48 | shared.cmd_opts.disable_safe_unpickle = True 49 | self.model = torch.hub.load('ultralytics/yolov5', 'custom', model_path[0]) 50 | self.clip = None 51 | self.preprocess = None 52 | shared.cmd_opts.disable_safe_unpickle = was_safe_unpickle 53 | # Prevent BLIP crossfire breakage 54 | try: 55 | del sys.modules['models'] 56 | except: 57 | pass 58 | 59 | def get_center(self, image: Image, prompt: str): 60 | # Load image into YOLO parser 61 | results = self.model(image) # includes NMS 62 | # Crop each image result to an array 63 | cropped = results.crop(False) 64 | l = [] 65 | for crop in cropped: 66 | l.append(Image.fromarray(crop["im"])) 67 | if len(l) == 0: 68 | l = [image] 69 | device = shared.device 70 | # Take out cropped YOLO images, and get the features? 
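# The block below lazily loads CLIP ViT-B/32 on first use, encodes each YOLO
# crop (or the whole image when nothing was detected), and ranks the crops by
# similarity to the encoded text prompt; the largest of the top-ranked crops is
# then located in the original image via template matching to produce the crop box.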
71 | if not self.model or not self.preprocess: 72 | self.clip, self.preprocess = clip.load("ViT-B/32", device=device) 73 | images = torch.stack([self.preprocess(im) for im in l]).to(device) 74 | with torch.no_grad(): 75 | image_features = self.clip.encode_image(images) 76 | image_features /= image_features.norm(dim=-1, keepdim=True) 77 | 78 | image_features.cpu().numpy() 79 | image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda() 80 | image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda() 81 | 82 | images = [self.preprocess(im) for im in l] 83 | image_input = torch.tensor(np.stack(images)).cuda() 84 | image_input -= image_mean[:, None, None] 85 | image_input /= image_std[:, None, None] 86 | with torch.no_grad(): 87 | image_features = self.clip.encode_image(image_input).float() 88 | image_features /= image_features.norm(dim=-1, keepdim=True) 89 | 90 | def similarity_top(similarity_list, N): 91 | results = zip(range(len(similarity_list)), similarity_list) 92 | results = sorted(results, key=lambda x: x[1], reverse=True) 93 | top_images = [] 94 | scores = [] 95 | for index, score in results[:N]: 96 | scores.append(score) 97 | top_images.append(l[index]) 98 | return scores, top_images 99 | 100 | # @title Crop 101 | with torch.no_grad(): 102 | # Encode and normalize the description using CLIP 103 | text_encoded = self.clip.encode_text(clip.tokenize(prompt).to(device)) 104 | text_encoded /= text_encoded.norm(dim=-1, keepdim=True) 105 | 106 | # Retrieve the description vector and the photo vectors 107 | similarity = text_encoded.cpu().numpy() @ image_features.cpu().numpy().T 108 | similarity = similarity[0] 109 | scores, imgs = similarity_top(similarity, N=3) 110 | max_area = 0 111 | out = None 112 | for img in imgs: 113 | img_area = img.width * img.height 114 | if img_area > max_area: 115 | max_area = img_area 116 | out = img 117 | 118 | if not out: 119 | out = image 120 | res = cv2.matchTemplate(numpy.array(image), numpy.array(out), cv2.TM_SQDIFF) 121 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) 122 | # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take minimum 123 | top_left = min_loc 124 | bottom_right = (top_left[0] + out.width, top_left[1] + out.height) 125 | return [top_left[0], bottom_right[0], top_left[1], bottom_right[1]] 126 | 127 | def unload(self): 128 | if self.model is not None: 129 | self.model.to('cpu') 130 | if self.clip: 131 | self.clip.to('cpu') 132 | if self.preprocess: 133 | self.preprocess.to('cpu') 134 | 135 | def load(self): 136 | if self.model is not None: 137 | self.model.to(shared.device) 138 | if self.clip: 139 | self.clip.to(shared.device) 140 | if self.preprocess: 141 | self.preprocess.to(shared.device) 142 | -------------------------------------------------------------------------------- /configs/upscaler.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion 4 | params: 5 | parameterization: "v" 6 | low_scale_key: "lr" 7 | linear_start: 0.0001 8 | linear_end: 0.02 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: "jpg" 13 | cond_stage_key: "txt" 14 | image_size: 128 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: "hybrid-adm" 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.08333 20 | use_ema: False 21 | 22 | low_scale_config: 23 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation 24 | 
params: 25 | noise_schedule_config: # image space 26 | linear_start: 0.0001 27 | linear_end: 0.02 28 | max_noise_level: 350 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | use_checkpoint: True 34 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) 35 | image_size: 128 36 | in_channels: 7 37 | out_channels: 4 38 | model_channels: 256 39 | attention_resolutions: [ 2,4,8] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 2, 4] 42 | disable_self_attentions: [True, True, True, False] 43 | disable_middle_self_attn: False 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 1024 48 | legacy: False 49 | use_linear_in_transformer: True 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | ddconfig: 56 | # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) 57 | double_z: True 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 64 | num_res_blocks: 2 65 | attn_resolutions: [ ] 66 | dropout: 0.0 67 | 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 73 | params: 74 | freeze: True 75 | layer: "penultimate" 76 | -------------------------------------------------------------------------------- /dbimutils.py: -------------------------------------------------------------------------------- 1 | # DanBooru IMage Utility functions, borrowed from 2 | # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/dbimutils.py 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED): 10 | if img.endswith(".gif"): 11 | img = Image.open(img) 12 | img = img.convert("RGB") 13 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 14 | else: 15 | img = cv2.imread(img, flag) 16 | return img 17 | 18 | 19 | def smart_24bit(img): 20 | if img.dtype is np.dtype(np.uint16): 21 | img = (img / 257).astype(np.uint8) 22 | 23 | if len(img.shape) == 2: 24 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 25 | elif img.shape[2] == 4: 26 | trans_mask = img[:, :, 3] == 0 27 | img[trans_mask] = [255, 255, 255, 255] 28 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) 29 | return img 30 | 31 | 32 | def make_square(img, target_size): 33 | old_size = img.shape[:2] 34 | desired_size = max(old_size) 35 | desired_size = max(desired_size, target_size) 36 | 37 | delta_w = desired_size - old_size[1] 38 | delta_h = desired_size - old_size[0] 39 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 40 | left, right = delta_w // 2, delta_w - (delta_w // 2) 41 | 42 | color = [255, 255, 255] 43 | new_im = cv2.copyMakeBorder( 44 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color 45 | ) 46 | return new_im 47 | 48 | 49 | def smart_resize(img, size): 50 | # Assumes the image has already gone through make_square 51 | if img.shape[0] > size: 52 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 53 | elif img.shape[0] < size: 54 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 55 | return img 56 | -------------------------------------------------------------------------------- /file_manager.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import List 4 | 5 | import PIL 6 | from PIL.Image import Image 7 | 8 | 9 | 10 | def clean_string(s): 11 | """ 12 | Remove non-alphanumeric characters except spaces, and normalize spacing. 13 | Args: 14 | s: The string to clean. 15 | 16 | Returns: A cleaned string. 17 | """ 18 | # Remove non-alphanumeric characters except spaces 19 | cleaned = re.sub(r'[^a-zA-Z0-9\s]', '', s) 20 | # Check for a sentence with just the same word repeated 21 | if len(set(cleaned.split())) == 1: 22 | cleaned = cleaned.split()[0] 23 | # Replace multiple spaces with a single space 24 | cleaned = re.sub(r'\s+', ' ', cleaned).strip() 25 | return cleaned 26 | 27 | 28 | class ImageData: 29 | image_path: str = "" 30 | temp_image_path: str = "" 31 | caption: str = "" 32 | tags: List[str] = [] 33 | selected: bool = False 34 | filtered: bool = False 35 | id = None 36 | 37 | def __init__(self, image_path): 38 | self.image_path = image_path 39 | self.caption = self.read_caption() 40 | self.tags = self.split_caption() 41 | # Generate a random id 42 | self.id = os.urandom(32).hex() 43 | 44 | def read_caption(self): 45 | existing_caption_txt_filename = os.path.splitext(self.image_path)[0] + '.txt' 46 | if os.path.exists(existing_caption_txt_filename): 47 | with open(existing_caption_txt_filename, 'r', encoding="utf8") as file: 48 | existing_caption_txt = file.read() 49 | existing_caption_txt = existing_caption_txt.strip() 50 | else: 51 | image_name = os.path.splitext(os.path.basename(self.image_path))[0] 52 | existing_caption_txt = clean_string(image_name) 53 | return existing_caption_txt 54 | 55 | def split_caption(self): 56 | tags = self.caption.split(",") 57 | tags = [tag.strip() for tag in tags] 58 | tags = [tag for tag in tags if tag != ""] 59 | return tags 60 | 61 | def update_image(self, image: Image, save_file: bool = False): 62 | if save_file: 63 | img_path = os.path.splitext(self.image_path)[0] + '.png' 64 | if img_path != self.image_path and os.path.exists(self.image_path): 65 | os.remove(self.image_path) 66 | self.image_path = img_path 67 | else: 68 | self.temp_image_path = os.path.splitext(self.image_path)[0] + '_temp.png' 69 | img_path = self.temp_image_path 70 | image.save(img_path) 71 | 72 | def update_caption(self, caption: str, save_file: bool = False): 73 | if save_file: 74 | caption_txt_filename = os.path.splitext(self.image_path)[0] + '.txt' 75 | with open(caption_txt_filename, 'w', encoding="utf8") as file: 76 | file.write(caption) 77 | self.caption = caption 78 | self.tags = self.split_caption() 79 | 80 | def get_image(self): 81 | return PIL.Image.open(self.image_path).convert("RGB") 82 | 83 | 84 | class FileManager: 85 | file_path: str = "" 86 | _instance = None 87 | files: List[ImageData] = [] 88 | included_tags: List[str] = [] 89 | excluded_tags: List[str] = [] 90 | included_strings: List[str] = [] 91 | excluded_strings: List[str] = [] 92 | current_image = None 93 | 94 | def __init__(self): 95 | self.files = [] 96 | 97 | def __new__(cls): 98 | if FileManager._instance is None: 99 | FileManager._instance = object.__new__(cls) 100 | return FileManager._instance 101 | 102 | def clear(self): 103 | self.files = [] 104 | self.included_tags = [] 105 | self.excluded_tags = [] 106 | self.included_strings = [] 107 | self.excluded_strings = [] 108 | self.current_image = None 109 | 110 | def load_files(self): 111 | from extensions.sd_smartprocess.smartprocess import is_image 112 | 113 | 
self.clear() 114 | # Walk through all files in the directory that contains the images 115 | for root, dirs, files in os.walk(self.file_path): 116 | for file in files: 117 | file = os.path.join(root, file) 118 | if is_image(file) and "_backup" not in file: 119 | image_data = ImageData(file) 120 | self.files.append(image_data) 121 | self.update_filters() 122 | 123 | def filtered_files(self, for_gallery: bool = False): 124 | if not for_gallery: 125 | return [file for file in self.files if file.filtered] 126 | else: 127 | return [(file.image_path, file.caption) for file in self.files if file.filtered] 128 | 129 | def all_files(self, for_gallery: bool = False) -> List[ImageData]: 130 | if not for_gallery: 131 | return self.files 132 | else: 133 | return [(file.image_path, file.caption) for file in self.files] 134 | 135 | def selected_files(self, for_gallery: bool = False) -> List[ImageData]: 136 | if not for_gallery: 137 | return [file for file in self.files if file.selected] 138 | else: 139 | return [(file.image_path, file.caption) for file in self.files if file.selected] 140 | 141 | def filtered_and_selected_files(self, for_gallery: bool = False) -> List[ImageData]: 142 | if not for_gallery: 143 | return [file for file in self.files if file.filtered and file.selected and file.filtered] 144 | else: 145 | return [(file.image_path, file.caption) for file in self.files if file.selected and file.filtered] 146 | 147 | def update_files(self, files: List[ImageData]): 148 | for file in files: 149 | self.update_file(file) 150 | 151 | def update_file(self, file: ImageData): 152 | # Search for the file with the same ID and update it if found 153 | for i, existing_file in enumerate(self.files): 154 | if existing_file.id == file.id: 155 | self.files[i] = file 156 | break 157 | else: 158 | # The file was not found in the list, append it 159 | self.files.append(file) 160 | 161 | def update_filters(self): 162 | """ 163 | Filters a collection of files based on specified inclusion and exclusion criteria for tags and filter strings. 164 | 165 | The function filters files based on tags extracted from file captions and matches these tags against specified 166 | inclusion and exclusion criteria. The criteria can be a set of plain tags, wildcard patterns, or regex expressions. 167 | 168 | Parameters: 169 | - use_all_files (bool): Determines whether to filter from all files or only selected files. 170 | If True, filters from all files; otherwise, filters from selected files. 171 | 172 | Globals: 173 | - filter_tags_include (list): Tags required to include a file. 174 | - filter_tags_exclude (list): Tags that lead to exclusion of a file. 175 | - filter_string_include (list): Filter strings for inclusion; supports wildcards and regex. 176 | - filter_string_exclude (list): Filter strings for exclusion; supports wildcards and regex. 177 | - all_files (list): List of all files, where each file is a tuple (filename, caption). 178 | - selected_files (list): List of selected files, where each file is a tuple (filename, caption). 179 | 180 | Returns: 181 | - filtered_files (list): List of files filtered based on the specified criteria. 182 | 183 | Examples: 184 | 1. Plain Tag Matching: 185 | - To include files with a tag 'holiday', add 'holiday' to filter_tags_include. 186 | - To exclude files with a tag 'work', add 'work' to filter_tags_exclude. 187 | 188 | 2. Wildcard Patterns: 189 | - To include files with tags starting with 'trip', add 'trip*' to filter_string_include. 
190 | - To exclude files with tags ending with '2023', add '*2023' to filter_string_exclude. 191 | 192 | 3. Regex Expressions: 193 | - To include files with tags that have any number followed by 'days', add '\\d+days' to filter_string_include. 194 | - To exclude files with tags formatted like dates (e.g., '2023-04-01'), add '\\d{4}-\\d{2}-\\d{2}' to filter_string_exclude. 195 | 196 | Note: 197 | - The function treats tags as case-sensitive. 198 | - Wildcard '*' matches any sequence of characters (including none). 199 | - Regex patterns should follow Python's 're' module syntax. 200 | """ 201 | 202 | def matches_pattern(pattern, string): 203 | # Convert wildcard to regex if necessary 204 | if '*' in pattern: 205 | pattern = '^' + pattern.replace('*', '.*') + '$' 206 | return re.match(pattern, string) is not None 207 | else: 208 | if " " not in pattern: 209 | parts = string.split(" ") 210 | return pattern in parts 211 | else: 212 | return pattern in string 213 | 214 | def should_include(tag, filter_tags, filter_strings): 215 | if len(filter_tags) == 0 and len(filter_strings) == 0: 216 | return True 217 | tag_match = False 218 | filter_match = False 219 | if tag in filter_tags and len(filter_tags) > 0: 220 | tag_match = True 221 | if len(filter_strings) > 0: 222 | for filter_string in filter_strings: 223 | if matches_pattern(filter_string, tag): 224 | filter_match = True 225 | break 226 | return tag_match or filter_match 227 | 228 | def should_exclude(tag, filter_tags, filter_strings): 229 | if tag in filter_tags: 230 | return True 231 | for filter_string in filter_strings: 232 | if matches_pattern(filter_string, tag): 233 | return True 234 | return False 235 | 236 | files = self.files 237 | updated_files = [] 238 | for file in files: 239 | tags = file.tags 240 | out_tags = [] 241 | for tag in tags: 242 | include = True 243 | if should_exclude(tag, self.excluded_tags, self.excluded_strings): 244 | include = False 245 | elif not should_include(tag, self.included_tags, self.included_strings): 246 | include = False 247 | if include: 248 | out_tags.append(tag) 249 | file.filtered = len(out_tags) > 0 250 | updated_files.append(file) 251 | self.files = files 252 | 253 | def all_captions(self): 254 | return [file.caption for file in self.files] 255 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from launch import run, git_clone, repo_dir 5 | 6 | name = "Smart Crop" 7 | req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") 8 | print(f"loading {name} reqs from {req_file}") 9 | run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements.", 10 | f"Couldn't install {name} requirements.") 11 | print("Cloning BLIP repo...") 12 | blip_repo = "https://github.com/pharmapsychotic/BLIP" 13 | git_clone(blip_repo, repo_dir('BLIP'), "BLIP") 14 | 15 | # git clone https://github.com/X-PLUG/mPLUG-Owl.git 16 | # cd mPLUG-Owl/mPLUG-Owl2 17 | # mplug_repo = "https://github.com/X-PLUG/mPLUG-Owl.git" 18 | # mplug_owl_path = repo_dir('mplug_owl_src') 19 | # git_clone(mplug_repo, mplug_owl_path, "mPLUG-Owl") 20 | # mplug_owl2_sub1_path = os.path.join(mplug_owl_path, "mPLUG-Owl", "mplug_owl") 21 | # mplug_owl2_sub2_path = os.path.join(mplug_owl_path, "mPLUG-Owl2", "mplug_owl2") 22 | # 23 | # sys.path.append(mplug_owl2_sub1_path) 24 | # sys.path.append(mplug_owl2_sub2_path) 25 | # 26 | 
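# Illustrative sketch (not part of the repository): the update_filters() docstring in
# file_manager.py above describes plain-tag, wildcard, and regex matching for caption
# tags. The standalone helper below mirrors the inner matches_pattern() function so the
# matching rules can be tried in isolation; the sample tags and patterns are invented.
import re

def matches_pattern(pattern: str, string: str) -> bool:
    if '*' in pattern:
        # Wildcard patterns become anchored regexes: 'trip*' -> '^trip.*$'
        return re.match('^' + pattern.replace('*', '.*') + '$', string) is not None
    if ' ' not in pattern:
        # Single-word patterns must match a whole word within the tag
        return pattern in string.split(' ')
    # Multi-word patterns match as substrings
    return pattern in string

tags = ["trip to paris", "tripod", "holiday 2023", "portrait"]
kept = [t for t in tags if matches_pattern("trip*", t) and not matches_pattern("*2023", t)]
print(kept)  # ['trip to paris', 'tripod']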
-------------------------------------------------------------------------------- /interrogators/blip_interrogator.py: -------------------------------------------------------------------------------- 1 | from PIL.Image import Image 2 | 3 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator 4 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 5 | import torch 6 | 7 | from extensions.sd_smartprocess.model_download import fetch_model 8 | from extensions.sd_smartprocess.process_params import ProcessParams 9 | 10 | 11 | class BLIPInterrogator(Interrogator): 12 | _instance = None # Class variable to hold the singleton instance 13 | params = {"max_tokens": 75, "load_in_8bit": False} 14 | 15 | def __init__(self, params: ProcessParams): 16 | super().__init__(params) 17 | self.processor = None 18 | self.model = None 19 | self.current_device = None 20 | self.load_8bit = params.load_in_8bit 21 | 22 | def __new__(cls, params: ProcessParams): 23 | if cls._instance is None: 24 | cls._instance = super(BLIPInterrogator, cls).__new__(cls) 25 | try: 26 | cls._instance._init(params) 27 | except Exception as e: 28 | cls._instance = None 29 | raise e 30 | cls.initial_prompt = params.blip_initial_prompt 31 | cls._load_8bit = params.load_in_8bit 32 | return cls._instance 33 | 34 | def _init(self, params: ProcessParams): 35 | super().__init__(params) 36 | self.model = None 37 | self.processor = None 38 | self.load_8bit = params.load_in_8bit 39 | 40 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 41 | self.load() 42 | self.params = params 43 | max_tokens = params.max_tokens 44 | inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16) 45 | generated_ids = self.model.generate(**inputs, max_new_tokens=max_tokens) 46 | generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 47 | if unload: 48 | self.unload() 49 | return generated_text 50 | 51 | def unload(self): 52 | if self.current_device != "cpu": 53 | try: 54 | self.model.to("cpu") 55 | self.current_device = "cpu" 56 | except: 57 | pass 58 | 59 | def load(self): 60 | try: 61 | if self.model is None: 62 | model_path = fetch_model('Salesforce/blip2-opt-6.7b', "blip2") 63 | self.processor = AutoProcessor.from_pretrained(model_path) 64 | self.model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, load_in_8bit=self.load_8bit) 65 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 66 | if self.device != self.current_device: 67 | self.model.to(self.device) 68 | self.current_device = self.device 69 | except: 70 | pass 71 | -------------------------------------------------------------------------------- /interrogators/booru_interrogator.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import torch 5 | from PIL.Image import Image 6 | 7 | import modules.deepbooru 8 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator, re_special 9 | from extensions.sd_smartprocess.process_params import ProcessParams 10 | from modules import images, devices 11 | 12 | 13 | class BooruInterrogator(Interrogator): 14 | params = {"min_score": 0.75} 15 | 16 | def __init__(self, params: ProcessParams) -> None: 17 | super().__init__(params) 18 | self.tags = None 19 | self.booru = modules.deepbooru.DeepDanbooru() 20 | self.model = self.booru.model 21 | 22 | def unload(self): 23 | 
self.booru.stop() 24 | 25 | def load(self): 26 | self.booru.start() 27 | pass 28 | 29 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 30 | self.load() 31 | self.params = params 32 | min_score = params.booru_min_score 33 | pic = images.resize_image(2, image, 512, 512) 34 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 35 | 36 | with torch.no_grad(), devices.autocast(): 37 | # Move the model to the same device as the input tensor 38 | self.model.to(devices.device) 39 | # Convert input to the correct type (half-precision if needed) 40 | x = torch.from_numpy(a).to(devices.device).type_as(self.model.n_Conv_0.weight) 41 | # Forward pass through the model 42 | y = self.model(x)[0].detach().cpu().numpy() 43 | 44 | probability_dict = {} 45 | 46 | for tag, probability in zip(self.model.tags, y): 47 | if tag.startswith("rating:"): 48 | continue 49 | 50 | probability_dict[tag] = probability 51 | 52 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])] 53 | 54 | output = {} 55 | for tag in tags: 56 | probability = probability_dict[tag] 57 | tag_outformat = tag 58 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) 59 | output[tag_outformat] = probability 60 | out_tags = [] 61 | for tag in sorted(output, key=output.get, reverse=True): 62 | if output[tag] >= min_score: 63 | # print(f"DBTag {tag} score is {tags[tag]}") 64 | out_tags.append(tag) 65 | output = ", ".join(out_tags) 66 | if unload: 67 | self.unload() 68 | return output 69 | -------------------------------------------------------------------------------- /interrogators/data/mediums.txt: -------------------------------------------------------------------------------- 1 | a 3D render 2 | a black and white photo 3 | a bronze sculpture 4 | a cartoon 5 | a cave painting 6 | a character portrait 7 | a charcoal drawing 8 | a child's drawing 9 | a color pencil sketch 10 | a colorized photo 11 | a comic book panel 12 | a computer rendering 13 | a cross stitch 14 | a cubist painting 15 | a detailed drawing 16 | a detailed matte painting 17 | a detailed painting 18 | a diagram 19 | a digital painting 20 | a digital rendering 21 | a drawing 22 | a fine art painting 23 | a flemish Baroque 24 | a gouache 25 | a hologram 26 | a hyperrealistic painting 27 | a jigsaw puzzle 28 | a low poly render 29 | a macro photograph 30 | a manga drawing 31 | a marble sculpture 32 | a matte painting 33 | a microscopic photo 34 | a mid-nineteenth century engraving 35 | a minimalist painting 36 | a mosaic 37 | a painting 38 | a pastel 39 | a pencil sketch 40 | a photo 41 | a photocopy 42 | a photorealistic painting 43 | a picture 44 | a pointillism painting 45 | a polaroid photo 46 | a pop art painting 47 | a portrait 48 | a poster 49 | a raytraced image 50 | a renaissance painting 51 | a screenprint 52 | a screenshot 53 | a silk screen 54 | a sketch 55 | a statue 56 | a still life 57 | a stipple 58 | a stock photo 59 | a storybook illustration 60 | a surrealist painting 61 | a surrealist sculpture 62 | a tattoo 63 | a tilt shift photo 64 | a watercolor painting 65 | a wireframe diagram 66 | a woodcut 67 | an abstract drawing 68 | an abstract painting 69 | an abstract sculpture 70 | an acrylic painting 71 | an airbrush painting 72 | an album cover 73 | an ambient occlusion render 74 | an anime drawing 75 | an art deco painting 76 | an art deco sculpture 77 | an engraving 78 | an etching 79 | an illustration of 80 | an impressionist painting 81 | an ink drawing 82 | an oil on canvas 
painting 83 | an oil painting 84 | an ultrafine detailed painting 85 | chalk art 86 | computer graphics 87 | concept art 88 | cyberpunk art 89 | digital art 90 | egyptian art 91 | graffiti art 92 | lineart 93 | pixel art 94 | poster art 95 | vector art -------------------------------------------------------------------------------- /interrogators/data/movements.txt: -------------------------------------------------------------------------------- 1 | abstract art 2 | abstract expressionism 3 | abstract illusionism 4 | academic art 5 | action painting 6 | aestheticism 7 | afrofuturism 8 | altermodern 9 | american barbizon school 10 | american impressionism 11 | american realism 12 | american romanticism 13 | american scene painting 14 | analytical art 15 | antipodeans 16 | arabesque 17 | arbeitsrat für kunst 18 | art & language 19 | art brut 20 | art deco 21 | art informel 22 | art nouveau 23 | art photography 24 | arte povera 25 | arts and crafts movement 26 | ascii art 27 | ashcan school 28 | assemblage 29 | australian tonalism 30 | auto-destructive art 31 | barbizon school 32 | baroque 33 | bauhaus 34 | bengal school of art 35 | berlin secession 36 | black arts movement 37 | brutalism 38 | classical realism 39 | cloisonnism 40 | cobra 41 | color field 42 | computer art 43 | conceptual art 44 | concrete art 45 | constructivism 46 | context art 47 | crayon art 48 | crystal cubism 49 | cubism 50 | cubo-futurism 51 | cynical realism 52 | dada 53 | danube school 54 | dau-al-set 55 | de stijl 56 | deconstructivism 57 | digital art 58 | ecological art 59 | environmental art 60 | excessivism 61 | expressionism 62 | fantastic realism 63 | fantasy art 64 | fauvism 65 | feminist art 66 | figuration libre 67 | figurative art 68 | figurativism 69 | fine art 70 | fluxus 71 | folk art 72 | funk art 73 | furry art 74 | futurism 75 | generative art 76 | geometric abstract art 77 | german romanticism 78 | gothic art 79 | graffiti 80 | gutai group 81 | happening 82 | harlem renaissance 83 | heidelberg school 84 | holography 85 | hudson river school 86 | hurufiyya 87 | hypermodernism 88 | hyperrealism 89 | impressionism 90 | incoherents 91 | institutional critique 92 | interactive art 93 | international gothic 94 | international typographic style 95 | kinetic art 96 | kinetic pointillism 97 | kitsch movement 98 | land art 99 | les automatistes 100 | les nabis 101 | letterism 102 | light and space 103 | lowbrow 104 | lyco art 105 | lyrical abstraction 106 | magic realism 107 | magical realism 108 | mail art 109 | mannerism 110 | massurrealism 111 | maximalism 112 | metaphysical painting 113 | mingei 114 | minimalism 115 | modern european ink painting 116 | modernism 117 | modular constructivism 118 | naive art 119 | naturalism 120 | neo-dada 121 | neo-expressionism 122 | neo-fauvism 123 | neo-figurative 124 | neo-primitivism 125 | neo-romanticism 126 | neoclassicism 127 | neogeo 128 | neoism 129 | neoplasticism 130 | net art 131 | new objectivity 132 | new sculpture 133 | northwest school 134 | nuclear art 135 | objective abstraction 136 | op art 137 | optical illusion 138 | orphism 139 | panfuturism 140 | paris school 141 | photorealism 142 | pixel art 143 | plasticien 144 | plein air 145 | pointillism 146 | pop art 147 | pop surrealism 148 | post-impressionism 149 | postminimalism 150 | pre-raphaelitism 151 | precisionism 152 | primitivism 153 | private press 154 | process art 155 | psychedelic art 156 | purism 157 | qajar art 158 | quito school 159 | rasquache 160 | rayonism 161 | realism 162 | regionalism 
163 | remodernism 164 | renaissance 165 | retrofuturism 166 | rococo 167 | romanesque 168 | romanticism 169 | samikshavad 170 | serial art 171 | shin hanga 172 | shock art 173 | socialist realism 174 | sots art 175 | space art 176 | street art 177 | stuckism 178 | sumatraism 179 | superflat 180 | suprematism 181 | surrealism 182 | symbolism 183 | synchromism 184 | synthetism 185 | sōsaku hanga 186 | tachisme 187 | temporary art 188 | tonalism 189 | toyism 190 | transgressive art 191 | ukiyo-e 192 | underground comix 193 | unilalianism 194 | vancouver school 195 | vanitas 196 | verdadism 197 | video art 198 | viennese actionism 199 | visual art 200 | vorticism -------------------------------------------------------------------------------- /interrogators/idefics2_interrogator.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | from datetime import datetime 4 | 5 | import torch 6 | from PIL import Image 7 | from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig 8 | 9 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator 10 | from extensions.sd_smartprocess.model_download import fetch_model 11 | from extensions.sd_smartprocess.mplug_owl2.mm_utils import get_model_name_from_path 12 | from extensions.sd_smartprocess.process_params import ProcessParams 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | NO_CAP_PROMPT = """ 18 | Generate a concise caption describing the key elements and context of the image in one sentence, 19 | focusing on the medium, subject, style, clothing, pose, action, and location. Ensure the sentence is accurate and devoid 20 | of assumptions, prose, etc. Keep it direct and relevant to the image. 21 | 22 | Follow the caption with a list of specific tags (keywords) detailing smaller key elements like clothing, poses, 23 | actions, and other notable features. 24 | """ 25 | 26 | EX_CAP_PROMPT = """ 27 | Here is a caption consisting of a sentence and a list of tags (keywords) that describe the image. 28 | 29 | Refine the provided caption to more accurately and vividly capture the essence and key details visible in the image, 30 | focusing on encapsulating the medium, subject, style, clothing, pose, action, and location in one sentence. 31 | 32 | Update the accompanying tags to reflect only the elements present in the image, ensuring precision and relevance. 33 | Avoid adding new information not supported by the existing caption or the image. 
34 | """ 35 | 36 | 37 | class Idefics2Interrogator(Interrogator): 38 | model = None 39 | processor = None 40 | params = {"max_tokens": 75, "load_in_4bit": True, "replace_blip_caption": True, "idefics2_prompt": "Describe this image in one detailed sentence, include the subject, location, style, and type of image."} 41 | 42 | def __init__(self, params: ProcessParams): 43 | super().__init__(params) 44 | logger.debug("Initializing LLM model...") 45 | model_path = fetch_model("HuggingFaceM4/idefics2-8b", "llm") 46 | model_name = get_model_name_from_path(model_path) 47 | self.load_4bit = params.load_in_4bit 48 | self.prompt = params.interrogation_prompt 49 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | self.model = None 51 | self.tokenizer = None 52 | self.image_processor = None 53 | self.context_len = None 54 | self.current_device = None 55 | logger.debug("Initialized LLM model.") 56 | 57 | def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str: 58 | self.load() 59 | self.prompt = NO_CAP_PROMPT 60 | if params is not None: 61 | if params.new_caption != "": 62 | existing_caption_txt = params.new_caption 63 | self.prompt = f"{EX_CAP_PROMPT}: {existing_caption_txt}" 64 | logger.debug(f"Existing caption query: {self.prompt}") 65 | 66 | # Create inputs 67 | messages = [ 68 | { 69 | "role": "user", 70 | "content": [ 71 | {"type": "image"}, 72 | {"type": "text", "text": self.prompt}, 73 | ] 74 | } 75 | ] 76 | 77 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) 78 | inputs = self.processor(text=prompt, images=[image], return_tensors="pt") 79 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 80 | 81 | # Generate 82 | generated_ids = self.model.generate(**inputs, max_new_tokens=500) 83 | caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True) 84 | caption = caption[0].strip() 85 | # Find "Assistant: " in the string and return everything after it 86 | caption = caption[caption.find("Assistant: ") + len("Assistant: "):].strip() 87 | if params.txt_action != "include": 88 | caption = caption.replace(",", "").replace(".", "").replace("?", "").replace("!", "").strip() 89 | if unload: 90 | self.unload() 91 | return caption 92 | 93 | def _to_cpu(self): 94 | if self.load_4bit: 95 | print("Model is loaded in 8bit, can't move to CPU.") 96 | return 97 | from extensions.sd_smartprocess.smartprocess import vram_usage 98 | used, total = vram_usage() 99 | print(f"VRAM: {used}/{total}") 100 | free = total - used 101 | # If we have over 16GB of VRAM, we can use the GPU 102 | if free > 16: 103 | print("VRAM is over 16GB, moving to GPU") 104 | self._to_gpu() 105 | return 106 | print("Moving to CPU") 107 | if self.current_device != "cpu": 108 | self.model.to('cpu') 109 | self.current_device = "cpu" 110 | if torch.cuda.is_available(): 111 | torch.cuda.empty_cache() 112 | gc.collect() 113 | 114 | def _to_gpu(self): 115 | if self.model is None: 116 | self.load() 117 | return 118 | if self.load_4bit: 119 | print("Model is loaded in 4bit, can't move to GPU.") 120 | return 121 | if self.current_device != "cuda" and torch.cuda.is_available(): 122 | print("Moving to GPU") 123 | time = datetime.now() 124 | self.model.to(self.device) 125 | print(f"Model to GPU: {datetime.now() - time}") 126 | self.current_device = "cuda" 127 | 128 | def unload(self): 129 | if self.model is not None: 130 | print("Unloading Idefics2 model") 131 | self._to_cpu() 132 | 133 | def load(self): 134 | if self.model is None: 135 | 
model_path = fetch_model("HuggingFaceM4/idefics2-8b", "llm") 136 | # model_name = get_model_name_from_path(model_path) 137 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 138 | self.current_device = self.device 139 | print(f"Loading processor on {self.device}") 140 | self.processor = AutoProcessor.from_pretrained(model_path) 141 | if self.load_4bit: 142 | print(f"Loading model.") 143 | quantization_config = BitsAndBytesConfig( 144 | load_in_4bit=True, 145 | bnb_4bit_quant_type="nf4", 146 | bnb_4bit_use_double_quant=True, 147 | bnb_4bit_compute_dtype=torch.float16 148 | ) 149 | self.model = AutoModelForVision2Seq.from_pretrained( 150 | model_path, 151 | torch_dtype=torch.float16, 152 | quantization_config=quantization_config 153 | ) 154 | else: 155 | self.model = AutoModelForVision2Seq.from_pretrained( 156 | model_path, 157 | torch_dtype=torch.float16 158 | ).to(self.device) 159 | print(f"Model loaded on {self.device}") 160 | self._to_gpu() 161 | -------------------------------------------------------------------------------- /interrogators/interrogator.py: -------------------------------------------------------------------------------- 1 | # Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/interrogator.py 2 | import importlib 3 | import pkgutil 4 | import re 5 | import sys 6 | from abc import abstractmethod 7 | 8 | import torch 9 | from PIL import Image 10 | 11 | import extensions.sd_smartprocess.interrogators as interrogators 12 | from extensions.sd_smartprocess.process_params import ProcessParams 13 | 14 | 15 | @abstractmethod 16 | class Interrogator: 17 | def __init__(self, params: ProcessParams) -> None: 18 | self.params = params 19 | self.model = None 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | registry = InterrogatorRegistry() 22 | registry.register(self) 23 | 24 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 25 | raise NotImplementedError 26 | 27 | def unload(self): 28 | if self.model: 29 | try: 30 | self.model = self.model.to("cpu") 31 | except: 32 | pass 33 | 34 | def load(self): 35 | if self.model: 36 | try: 37 | self.model = self.model.to(self.device) 38 | except: 39 | pass 40 | 41 | 42 | re_special = re.compile(r'([\\()])') 43 | 44 | 45 | class InterrogatorRegistry: 46 | _instance = None # Class variable to hold the singleton instance 47 | 48 | def __new__(cls): 49 | if cls._instance is None: 50 | cls._instance = super(InterrogatorRegistry, cls).__new__(cls) 51 | cls._instance._init() 52 | return cls._instance 53 | 54 | def _init(self): 55 | self.interrogators = {} 56 | 57 | def register(self, interrogator: Interrogator): 58 | self.interrogators[interrogator.__class__.__name__] = interrogator 59 | 60 | def get_interrogator(self, interrogator_name: str) -> Interrogator: 61 | return self.interrogators[interrogator_name] 62 | 63 | def get_interrogators(self): 64 | return self.interrogators 65 | 66 | def unload(self): 67 | for interrogator in self.interrogators.values(): 68 | interrogator.unload() 69 | 70 | def load(self): 71 | for interrogator in self.interrogators.values(): 72 | interrogator.load() 73 | 74 | @staticmethod 75 | def list_interrogators(return_cls=False): 76 | # First, dynamically import all modules in the package to ensure all subclasses are loaded 77 | package = interrogators 78 | for importer, modname, ispkg in pkgutil.iter_modules(package.__path__, package.__name__ + '.'): 79 | print(f"Importing {modname}") 80 | try: 81 | 
importlib.import_module(modname) 82 | except Exception as e: 83 | print(f"Error importing {modname}: {e}") 84 | 85 | # Now, enumerate subclasses of Interrogator to access their class-level 'params' 86 | interrogator_dict = {} 87 | subclass_names = [cls.__name__ for cls in Interrogator.__subclasses__()] 88 | print(f"Subclass names: {subclass_names}") 89 | for cls in Interrogator.__subclasses__(): 90 | if return_cls: 91 | interrogator_dict[cls.__name__] = cls 92 | continue 93 | params = getattr(cls, "params", None) 94 | if params is not None: # Ensure 'params' is defined at the class level 95 | interrogator_dict[cls.__name__] = params 96 | else: 97 | interrogator_dict[cls.__name__] = {} 98 | 99 | return interrogator_dict 100 | 101 | -------------------------------------------------------------------------------- /interrogators/llava2_interrogator.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | from datetime import datetime 4 | 5 | import torch 6 | from PIL import Image 7 | from transformers import TextStreamer 8 | 9 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator 10 | from extensions.sd_smartprocess.model_download import fetch_model 11 | from extensions.sd_smartprocess.mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 12 | from extensions.sd_smartprocess.mplug_owl2.conversation import conv_templates 13 | from extensions.sd_smartprocess.mplug_owl2.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token, \ 14 | process_images, \ 15 | get_model_name_from_path 16 | from extensions.sd_smartprocess.mplug_owl2.model.builder import load_pretrained_model 17 | from extensions.sd_smartprocess.process_params import ProcessParams 18 | 19 | # This is basically broken until we can update transformers in AUTO past the current version supported 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | NO_CAP_PROMPT = """ 24 | Generate a concise caption describing the key elements and context of the image in one sentence, 25 | focusing on the medium, subject, style, clothing, pose, action, and location. Ensure the sentence is accurate and devoid 26 | of assumptions, prose, etc. Keep it direct and relevant to the image. 27 | 28 | Follow the caption with a list of specific tags (keywords) detailing smaller key elements like clothing, poses, 29 | actions, and other notable features. 30 | """ 31 | 32 | EX_CAP_PROMPT = """ 33 | Here is a caption consisting of a sentence and a list of tags (keywords) that describe the image. 34 | 35 | Refine the provided caption to more accurately and vividly capture the essence and key details visible in the image, 36 | focusing on encapsulating the medium, subject, style, clothing, pose, action, and location in one sentence. 37 | 38 | Update the accompanying tags to reflect only the elements present in the image, ensuring precision and relevance. 39 | Avoid adding new information not supported by the existing caption or the image. 
40 | """ 41 | 42 | 43 | class LLAVA2Interrogator(Interrogator): 44 | model = None 45 | processor = None 46 | params = {"max_tokens": 75, "load_in_8bit": False, "replace_blip_caption": True} 47 | 48 | def __init__(self, params: ProcessParams): 49 | super().__init__(params) 50 | logger.debug("Initializing LLM model...") 51 | model_path = fetch_model('MAGAer13/mplug-owl2-llama2-7b', "llm") 52 | model_name = get_model_name_from_path(model_path) 53 | self.load_8bit = params.load_in_8bit 54 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 55 | self.model = None 56 | self.tokenizer = None 57 | self.image_processor = None 58 | self.context_len = None 59 | self.current_device = None 60 | logger.debug("Initialized LLM model.") 61 | 62 | def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str: 63 | self.load() 64 | 65 | query = NO_CAP_PROMPT 66 | if params is not None: 67 | if params.new_caption != "": 68 | existing_caption_txt = params.new_caption 69 | query = f"{EX_CAP_PROMPT}: {existing_caption_txt}" 70 | logger.debug(f"Existing caption query: {query}") 71 | 72 | conv = conv_templates["mplug_owl2"].copy() 73 | 74 | max_edge = max(image.size) 75 | image = image.resize((max_edge, max_edge)) 76 | image_tensor = process_images([image], self.image_processor) 77 | image_tensor = image_tensor.to(self.model.device, dtype=torch.float16) 78 | 79 | inp = query + " " + DEFAULT_IMAGE_TOKEN 80 | conv.append_message(conv.roles[0], inp) 81 | conv.append_message(conv.roles[1], None) 82 | prompt = conv.get_prompt() 83 | 84 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze( 85 | 0).to( 86 | self.model.device) 87 | stop_str = conv.sep2 88 | keywords = [stop_str] 89 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) 90 | streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) 91 | 92 | temperature = 0.7 93 | max_new_tokens = 512 94 | 95 | with torch.inference_mode(): 96 | output_ids = self.model.generate( 97 | input_ids, 98 | images=image_tensor, 99 | do_sample=True, 100 | temperature=temperature, 101 | max_new_tokens=max_new_tokens, 102 | streamer=streamer, 103 | use_cache=True, 104 | stopping_criteria=[stopping_criteria]) 105 | 106 | caption = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 107 | if params.txt_action != "include": 108 | caption = caption.replace(",", "").replace(".", "").replace("?", "").replace("!", "").strip() 109 | if unload: 110 | self.unload() 111 | return caption 112 | 113 | def _to_cpu(self): 114 | if self.load_8bit: 115 | print("Model is loaded in 8bit, can't move to CPU.") 116 | return 117 | from extensions.sd_smartprocess.smartprocess import vram_usage 118 | used, total = vram_usage() 119 | print(f"VRAM: {used}/{total}") 120 | free = total - used 121 | # If we have over 16GB of VRAM, we can use the GPU 122 | if free > 16: 123 | print("VRAM is over 16GB, moving to GPU") 124 | self._to_gpu() 125 | return 126 | print("Moving to CPU") 127 | if self.current_device != "cpu": 128 | self.model.to('cpu') 129 | self.current_device = "cpu" 130 | if torch.cuda.is_available(): 131 | torch.cuda.empty_cache() 132 | gc.collect() 133 | 134 | def _to_gpu(self): 135 | if self.model is None: 136 | self.load() 137 | return 138 | if self.current_device != "cuda" and torch.cuda.is_available(): 139 | print("Moving to GPU") 140 | time = datetime.now() 141 | self.model.to(self.device) 142 | print(f"Model to GPU: 
{datetime.now() - time}") 143 | self.current_device = "cuda" 144 | # self.image_processor.to(self.device) 145 | # self.tokenizer.to(self.device) 146 | 147 | def unload(self): 148 | if self.model is not None: 149 | print("Unloading LLAVA2 model") 150 | self._to_cpu() 151 | 152 | def load(self): 153 | if self.model is None: 154 | model_path = fetch_model('MAGAer13/mplug-owl2-llama2-7b', "llm") 155 | model_name = get_model_name_from_path(model_path) 156 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 157 | self.current_device = self.device 158 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, 159 | model_name, 160 | load_8bit=self.load_8bit, 161 | load_4bit=False, 162 | device="cuda") 163 | 164 | self._to_gpu() 165 | -------------------------------------------------------------------------------- /interrogators/moondream_interrogator.py: -------------------------------------------------------------------------------- 1 | from PIL.Image import Image 2 | 3 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator 4 | from model_download import fetch_model 5 | from process_params import ProcessParams 6 | from torch import float16 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | 9 | 10 | class MoonDreamInterrogator(Interrogator): 11 | params = {"interrogation_prompt": "Describe this image in one detailed sentence."} # Class-level attribute 12 | 13 | def __init__(self, params: ProcessParams): 14 | super().__init__(params) 15 | self.interrogation_prompt = params.interrogation_prompt 16 | self.model = None 17 | self.tokenizer = None 18 | print("Initializing Moondream...") 19 | 20 | def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str: 21 | self.load() 22 | enc_image = self.model.encode_image(image) 23 | caption = self.model.answer_question(enc_image, "Describe this image in one detailed sentence.", self.tokenizer) 24 | caption = caption.replace(",", "").replace(".", "").replace("?", "").replace("!", "").strip() 25 | if unload: 26 | self.unload() 27 | return caption 28 | 29 | def _to_gpu(self): 30 | self.model = self.model.to("cuda") 31 | 32 | def _to_cpu(self): 33 | self.model = self.model.to("cpu") 34 | 35 | def load(self): 36 | if self.model is None: 37 | model_id = "vikhyatk/moondream2" 38 | model_path = fetch_model(model_id, "llm") 39 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 40 | 41 | self.model = AutoModelForCausalLM.from_pretrained( 42 | model_path, trust_remote_code=True, 43 | torch_dtype=float16).to("cuda") 44 | -------------------------------------------------------------------------------- /interrogators/wolf_interrogator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import PIL 4 | import cv2 5 | import huggingface_hub 6 | import numpy as np 7 | import pandas as pd 8 | from PIL.Image import Image 9 | from onnxruntime import InferenceSession 10 | 11 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator 12 | from extensions.sd_smartprocess.model_download import fetch_model 13 | from extensions.sd_smartprocess.process_params import ProcessParams 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" 18 | MODEL_FILENAME = "model.onnx" 19 | LABEL_FILENAME = "selected_tags.csv" 20 | WOLF_PARAMS = { 21 | "group": "WOLF", 22 | "threshold": 0.75, 23 | 
"char_threshold": 0.75 24 | } 25 | 26 | 27 | class MoatInterrogator(Interrogator): 28 | params = WOLF_PARAMS 29 | 30 | def __init__(self, params: ProcessParams) -> None: 31 | super().__init__(params) 32 | self._setup() 33 | 34 | def _setup(self): 35 | model_path = "SmilingWolf/wd-v1-4-moat-tagger-v2" 36 | self.model = load_model(model_path, self.device) 37 | 38 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 39 | threshold = params.threshold 40 | char_threshold = params.char_threshold 41 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold) 42 | return a 43 | 44 | 45 | class SwinInterrogator(Interrogator): 46 | params = WOLF_PARAMS 47 | 48 | def __init__(self, params: ProcessParams) -> None: 49 | super().__init__(params) 50 | self._setup() 51 | 52 | def _setup(self): 53 | model_path = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" 54 | self.model = load_model(model_path, self.device) 55 | 56 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 57 | threshold = params.threshold 58 | char_threshold = params.char_threshold 59 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold) 60 | return a 61 | 62 | 63 | class ConvInterrogator(Interrogator): 64 | params = WOLF_PARAMS 65 | 66 | def __init__(self, params: ProcessParams) -> None: 67 | super().__init__(params) 68 | self._setup() 69 | 70 | def _setup(self): 71 | model_path = "SmilingWolf/wd-v1-4-convnext-tagger-v2" 72 | self.model = load_model(model_path, self.device) 73 | 74 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 75 | threshold = params.threshold 76 | char_threshold = params.char_threshold 77 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold) 78 | return a 79 | 80 | 81 | class Conv2Interrogator(Interrogator): 82 | params = WOLF_PARAMS 83 | 84 | def __init__(self, params: ProcessParams) -> None: 85 | super().__init__(params) 86 | self._setup() 87 | 88 | def _setup(self): 89 | model_path = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" 90 | self.model = load_model(model_path, self.device) 91 | 92 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 93 | threshold = params.threshold 94 | char_threshold = params.char_threshold 95 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold) 96 | return a 97 | 98 | 99 | class VitInterrogator(Interrogator): 100 | params = WOLF_PARAMS 101 | 102 | def __init__(self, params: ProcessParams) -> None: 103 | super().__init__(params) 104 | self._setup() 105 | 106 | def _setup(self): 107 | model_path = "SmilingWolf/wd-v1-4-vit-tagger-v2" 108 | self.model = load_model(model_path, self.device) 109 | 110 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str: 111 | threshold = params.threshold 112 | char_threshold = params.char_threshold 113 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold) 114 | return a 115 | 116 | 117 | def predict( 118 | image: Image, 119 | model, 120 | general_threshold: float = 0.5, 121 | character_threshold: float = 0.5 122 | ): 123 | tag_names, rating_indexes, general_indexes, character_indexes = load_labels() 124 | 125 | _, height, width, _ = model.get_inputs()[0].shape 126 | 127 | # Alpha to white 128 | image = image.convert("RGBA") 129 | new_image = PIL.Image.new("RGBA", 
image.size, "WHITE") 130 | new_image.paste(image, mask=image) 131 | image = new_image.convert("RGB") 132 | image = np.asarray(image) 133 | 134 | # PIL RGB to OpenCV BGR 135 | image = image[:, :, ::-1] 136 | 137 | image = make_square(image, height) 138 | image = smart_resize(image, height) 139 | image = image.astype(np.float32) 140 | image = np.expand_dims(image, 0) 141 | 142 | input_name = model.get_inputs()[0].name 143 | label_name = model.get_outputs()[0].name 144 | probs = model.run([label_name], {input_name: image})[0] 145 | 146 | labels = list(zip(tag_names, probs[0].astype(float))) 147 | 148 | # First 4 labels are actually ratings: pick one with argmax 149 | ratings_names = [labels[i] for i in rating_indexes] 150 | rating = dict(ratings_names) 151 | 152 | # Then we have general tags: pick any where prediction confidence > threshold 153 | general_names = [labels[i] for i in general_indexes] 154 | general_res = [x for x in general_names if x[1] > general_threshold] 155 | general_res = dict(general_res) 156 | 157 | # Everything else is characters: pick any where prediction confidence > threshold 158 | character_names = [labels[i] for i in character_indexes] 159 | character_res = [x for x in character_names if x[1] > character_threshold] 160 | character_res = dict(character_res) 161 | 162 | b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True)) 163 | a = ( 164 | ", ".join(list(b.keys())) 165 | .replace("_", " ") 166 | .replace("(", "\(") 167 | .replace(")", "\)") 168 | ) 169 | c = ", ".join(list(b.keys())) 170 | 171 | return a, c, rating, character_res, general_res 172 | 173 | 174 | def load_model(model_path, device): 175 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 176 | if device == "cpu": 177 | providers.pop(0) 178 | path = fetch_model(model_path, "wolf", True) 179 | model = InferenceSession(path, providers=providers) 180 | return model 181 | 182 | 183 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED): 184 | if img.endswith(".gif"): 185 | img = PIL.Image.open(img) 186 | img = img.convert("RGB") 187 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 188 | else: 189 | img = cv2.imread(img, flag) 190 | return img 191 | 192 | 193 | def smart_24bit(img): 194 | if img.dtype is np.dtype(np.uint16): 195 | img = (img / 257).astype(np.uint8) 196 | 197 | if len(img.shape) == 2: 198 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 199 | elif img.shape[2] == 4: 200 | trans_mask = img[:, :, 3] == 0 201 | img[trans_mask] = [255, 255, 255, 255] 202 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) 203 | return img 204 | 205 | 206 | def make_square(img, target_size): 207 | old_size = img.shape[:2] 208 | desired_size = max(old_size) 209 | desired_size = max(desired_size, target_size) 210 | 211 | delta_w = desired_size - old_size[1] 212 | delta_h = desired_size - old_size[0] 213 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 214 | left, right = delta_w // 2, delta_w - (delta_w // 2) 215 | 216 | color = [255, 255, 255] 217 | new_im = cv2.copyMakeBorder( 218 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color 219 | ) 220 | return new_im 221 | 222 | 223 | def smart_resize(img, size): 224 | # Assumes the image has already gone through make_square 225 | if img.shape[0] > size: 226 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 227 | elif img.shape[0] < size: 228 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 229 | return img 230 | 231 | 232 | def load_labels() -> list[str]: 233 | path = 
huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME) 234 | df = pd.read_csv(path) 235 | tag_names = df["name"].tolist() 236 | rating_indexes = list(np.where(df["category"] == 9)[0]) 237 | general_indexes = list(np.where(df["category"] == 0)[0]) 238 | character_indexes = list(np.where(df["category"] == 4)[0]) 239 | return tag_names, rating_indexes, general_indexes, character_indexes 240 | -------------------------------------------------------------------------------- /javascript/smart_process.js: -------------------------------------------------------------------------------- 1 | function start_smart_process() { 2 | let progress = gradioApp().getElementById("sp_progress"); 3 | let gallery = gradioApp().getElementById("sp_gallery"); 4 | console.log("Requesting progress:", progress, gallery); 5 | requestProgress('sp', progress, gallery, atEnd, function(progress){}); 6 | gradioApp().querySelector('#sp_error').innerHTML = ''; 7 | return args_to_array(arguments); 8 | } 9 | 10 | function atEnd() { 11 | 12 | } 13 | -------------------------------------------------------------------------------- /model_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from huggingface_hub import snapshot_download 5 | 6 | from modules.paths_internal import models_path 7 | from modules.safe import unsafe_torch_load, load 8 | 9 | 10 | def fetch_model(model_repo, model_type, single_file=False): 11 | model_dir = os.path.join(models_path, model_type) 12 | dest_dir = os.path.join(model_dir, model_repo.split("/")[1]) 13 | if not os.path.exists(dest_dir): 14 | os.makedirs(dest_dir, exist_ok=True) 15 | dest_dir = snapshot_download(model_repo, repo_type="model", local_dir=dest_dir, local_dir_use_symlinks=False) 16 | if single_file: 17 | dest_dir = os.path.join(dest_dir, "model.onnx") 18 | return dest_dir 19 | 20 | 21 | disable_safe_unpickle_count = 0 22 | 23 | 24 | def disable_safe_unpickle(): 25 | global disable_safe_unpickle_count 26 | try: 27 | from modules import shared as auto_shared 28 | if not auto_shared.cmd_opts.disable_safe_unpickle: 29 | auto_shared.cmd_opts.disable_safe_unpickle = True 30 | torch.load = unsafe_torch_load 31 | disable_safe_unpickle_count += 1 32 | except: 33 | pass 34 | 35 | 36 | def enable_safe_unpickle(): 37 | global disable_safe_unpickle_count 38 | try: 39 | from modules import shared as auto_shared 40 | if disable_safe_unpickle_count > 0: 41 | disable_safe_unpickle_count -= 1 42 | if disable_safe_unpickle_count == 0 and auto_shared.cmd_opts.disable_safe_unpickle: 43 | auto_shared.cmd_opts.disable_safe_unpickle = False 44 | torch.load = load 45 | except: 46 | pass 47 | -------------------------------------------------------------------------------- /mplug_owl2/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MPLUGOwl2LlamaForCausalLM -------------------------------------------------------------------------------- /mplug_owl2/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "./demo_logs" 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<|image|>" 10 | -------------------------------------------------------------------------------- /mplug_owl2/evaluate/EVALUATION.md: 
-------------------------------------------------------------------------------- 1 | # Evaluation 2 | ## Important 3 | We use `batch_size = 1` for evaluation because batched inference in LLaMA requires left padding, and left padding produces wrong positional embeddings; only increase the batch size if you verify that it causes no performance drop. 4 | 5 | ## Dependencies 6 | 7 | ```bash 8 | pip install pycocoevalcap tqdm 9 | ``` 10 | 11 | ## Image Caption 12 | 13 | ### [Flickr30K](https://bryanplummer.com/Flickr30kEntities/) 14 | 15 | 
16 | Data Preparation 17 | 18 | ```bash 19 | mkdir -p data/flickr && cd data/flickr 20 | 21 | # download images from https://bryanplummer.com/Flickr30kEntities/ 22 | 23 | # karpathy split annotations can be downloaded from https://cs.stanford.edu/people/karpathy/deepimagesent/ 24 | 25 | # download converted files 26 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/flickr30k/flickr30k_karpathy_test.json 27 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/flickr30k/flickr30k_karpathy_train.json 28 | 29 | cd ../.. 30 | ``` 31 | 32 |
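The converted Karpathy-split files above are read by `CaptionDataset` in `evaluate_caption.py`, which expects a top-level `images` list whose entries carry an `id` and an `image` path; note also that the script's `ds_collections` looks them up under `data/flickr30k/`, not the `data/flickr` directory created by the commands above, so place (or symlink) the files accordingly. A minimal sanity check, assuming that layout:

```python
import json

# Quick sanity check for the converted annotation file. The path follows
# ds_collections in evaluate_caption.py (data/flickr30k/...), which differs
# from the data/flickr directory created by the commands above.
with open("data/flickr30k/flickr30k_karpathy_test.json") as f:
    anno = json.load(f)

first = anno["images"][0]
print(len(anno["images"]), "images")
print(first["id"], first["image"])  # the two fields CaptionDataset.__getitem__ reads
```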
33 | 34 |
35 | Evaluate 36 | 37 | ```bash 38 | ds="flickr" 39 | checkpoint=/PATH/TO/CHECKPOINT 40 | python -m torch.distributed.launch --use-env \ 41 | --nproc_per_node ${NPROC_PER_NODE:-8} \ 42 | --nnodes ${WORLD_SIZE:-1} \ 43 | --node_rank ${RANK:-0} \ 44 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 45 | --master_port ${MASTER_PORT:-12345} \ 46 | evaluate_caption.py \ 47 | --checkpoint $checkpoint \ 48 | --dataset $ds \ 49 | --batch-size 8 \ 50 | --num-workers 2 51 | ``` 52 |
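The command above gathers the per-rank captions, and rank 0 writes them to a timestamped `flickr_<timestamp>.json` and scores that file with pycocoevalcap against the same Karpathy test JSON. To re-score an existing results file without rerunning inference, a minimal sketch (the results filename is a placeholder for whatever your run produced):

```python
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

# Re-score a results file written by evaluate_caption.py.
# "flickr_240101120000.json" is a placeholder for your timestamped output file.
coco = COCO("data/flickr30k/flickr30k_karpathy_test.json")  # the test JSON doubles as COCO-style annotations
coco_result = coco.loadRes("flickr_240101120000.json")
coco_eval = COCOEvalCap(coco, coco_result)
coco_eval.evaluate()
print(coco_eval.eval)  # BLEU / METEOR / ROUGE_L / CIDEr scores
```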
53 | 54 | ### [Nocaps](https://nocaps.org/) 55 | 
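Note that the copy of `evaluate_caption.py` in this folder only registers `flickr` in its `ds_collections`, so before running the command below you would need to add a `nocaps` entry along these lines (the annotation paths are placeholders for wherever you keep converted nocaps files):

```python
# Hypothetical addition to ds_collections in mplug_owl2/evaluate/evaluate_caption.py;
# the paths are placeholders, not files provided by this repo.
ds_collections['nocaps'] = {
    'train': 'data/nocaps/nocaps_train.json',
    'test': 'data/nocaps/nocaps_val.json',
}
```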
56 | Evaluate 57 | 58 | ```bash 59 | ds="nocaps" 60 | checkpoint=/PATH/TO/CHECKPOINT 61 | python -m torch.distributed.launch --use-env \ 62 | --nproc_per_node ${NPROC_PER_NODE:-8} \ 63 | --nnodes ${WORLD_SIZE:-1} \ 64 | --node_rank ${RANK:-0} \ 65 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 66 | --master_port ${MASTER_PORT:-12345} \ 67 | evaluate_caption.py \ 68 | --checkpoint $checkpoint \ 69 | --dataset $ds \ 70 | --batch-size 8 \ 71 | --num-workers 2 72 | ``` 73 | 74 |
75 | 76 | ## [COCO](https://cocodataset.org/) 77 | 78 | > COCO images are used in VQAv2/OK-VQA/RefCOCO/RefCOCO+/RefCOCOg; make sure you have already downloaded the COCO images before evaluating on these benchmarks. 79 | 80 | 
81 | Data Preparation 82 | 83 | ```bash 84 | mkdir -p data/coco && cd data/coco 85 | 86 | # download coco2014 images 87 | wget http://images.cocodataset.org/zips/train2014.zip && unzip train2014.zip 88 | wget http://images.cocodataset.org/zips/val2014.zip && unzip val2014.zip 89 | wget http://images.cocodataset.org/zips/test2015.zip && unzip test2015.zip 90 | 91 | cd ../.. 92 | ``` 93 | 94 |
95 | 96 | ## General VQA 97 | 98 | ### [VQAv2](https://visualqa.org/) 99 | 100 |
101 | Data Preparation 102 | 103 | ```bash 104 | mkdir -p data/vqav2 && cd data/vqav2 105 | 106 | # make sure you have downloaded COCO images 107 | 108 | # download questions and annotations 109 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip && unzip v2_Annotations_Train_mscoco.zip 110 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip && unzip v2_Questions_Train_mscoco.zip 111 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip && unzip v2_Annotations_Val_mscoco.zip 112 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip && unzip v2_Questions_Val_mscoco.zip 113 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip && unzip v2_Questions_Test_mscoco.zip 114 | 115 | # download converted files 116 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/vqav2/vqav2_train.jsonl 117 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/vqav2/vqav2_val.jsonl 118 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/vqav2/vqav2_testdev.jsonl 119 | ``` 120 | 121 |
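For reference, `mplug_owl2/evaluate/evaluate_vqa.py` reads each converted record's `image` path, `question`, `question_id`, and optional `answer`; a sketch of the shape it expects (values are illustrative only, and since this copy of the loader parses the whole file with `json.load`, a strictly line-delimited jsonl may need to be loaded line by line instead):

```python
# Illustrative record consumed by VQADataset in evaluate_vqa.py.
# Only the field names come from the loader; the values are made up.
example = {
    "image": "data/coco/val2014/COCO_val2014_000000000042.jpg",
    "question": "What color is the bus?",
    "question_id": 42,
    "answer": "red",  # may be absent, e.g. for test-dev style splits
}
```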
122 | 123 |
124 | Evaluate 125 | 126 | ```bash 127 | checkpoint=/PATH/TO/CHECKPOINT 128 | for ds in "vqav2_val" "vqav2_testdev"; do 129 | python -m torch.distributed.launch --use-env \ 130 | --nproc_per_node ${NPROC_PER_NODE:-8} \ 131 | --nnodes ${WORLD_SIZE:-1} \ 132 | --node_rank ${RANK:-0} \ 133 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 134 | --master_port ${MASTER_PORT:-12345} \ 135 | evaluate_vqa.py \ 136 | --checkpoint $checkpoint \ 137 | --dataset $ds \ 138 | --batch-size 8 \ 139 | --num-workers 2 140 | done 141 | ``` 142 | 
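One caveat for the loop above: as included here, `evaluate_vqa.py` instantiates `EvalAIAnswerProcessor()` at startup but never imports or defines it, so every split would hit a `NameError`; its output is only actually used for `vqav2_testdev` answers. A hedged workaround sketch (the import location is an assumption, and the no-op stand-in is only acceptable for the val splits, not for real test-dev submissions):

```python
# evaluate_vqa.py references EvalAIAnswerProcessor without importing it.
try:
    from vqa_eval import EvalAIAnswerProcessor  # assumed location; adjust to wherever the class actually lives
except ImportError:
    class EvalAIAnswerProcessor:
        """No-op stand-in so val-split runs don't crash; test-dev submissions need the real EvalAI normalizer."""
        def __call__(self, answer):
            return answer
```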
143 | 144 | ### [OKVQA](https://okvqa.allenai.org/) 145 | 146 |
147 | Data Preparation 148 | 149 | ```bash 150 | mkdir -p data/okvqa && cd data/okvqa 151 | 152 | # download annotations and questions 153 | wget https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip && unzip mscoco_train2014_annotations.json.zip 154 | wget https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip && unzip OpenEnded_mscoco_train2014_questions.json.zip 155 | wget https://okvqa.allenai.org/static/data/mscoco_val2014_annotations.json.zip && unzip mscoco_val2014_annotations.json.zip 156 | wget https://okvqa.allenai.org/static/data/OpenEnded_mscoco_val2014_questions.json.zip && unzip OpenEnded_mscoco_val2014_questions.json.zip 157 | 158 | # download converted files 159 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/okvqa/okvqa_train.jsonl 160 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/okvqa/okvqa_val.jsonl 161 | 162 | cd ../.. 163 | ``` 164 | 165 |
166 | 167 |
168 | Evaluate 169 | 170 | ```bash 171 | ds="okvqa_val" 172 | checkpoint=/PATH/TO/CHECKPOINT 173 | python -m torch.distributed.launch --use-env \ 174 | --nproc_per_node ${NPROC_PER_NODE:-8} \ 175 | --nnodes ${WORLD_SIZE:-1} \ 176 | --node_rank ${RANK:-0} \ 177 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 178 | --master_port ${MASTER_PORT:-12345} \ 179 | evaluate_vqa.py \ 180 | --checkpoint $checkpoint \ 181 | --dataset $ds \ 182 | --batch-size 8 \ 183 | --num-workers 2 184 | ``` 185 | 186 |
187 | 188 | ### [TextVQA](https://textvqa.org/) 189 | 190 |
191 | Data Preparation 192 | 193 | ```bash 194 | mkdir -p data/textvqa && cd data/textvqa 195 | 196 | # download images 197 | wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip && unzip train_val_images.zip 198 | 199 | # download annotations and questions 200 | wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_train.json 201 | wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json 202 | 203 | # download converted files 204 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_train_annotations.json 205 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_train_questions.json 206 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_train.jsonl 207 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_val_annotations.json 208 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_val_questions.json 209 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_val.jsonl 210 | 211 | cd ../.. 212 | ``` 213 |
214 | 215 |
216 | Evaluate 217 | 218 | ```bash 219 | ds="textvqa_val" 220 | checkpoint=/PATH/TO/CHECKPOINT 221 | python -m torch.distributed.launch --use-env \ 222 | --nproc_per_node ${NPROC_PER_NODE:-8} \ 223 | --nnodes ${WORLD_SIZE:-1} \ 224 | --node_rank ${RANK:-0} \ 225 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 226 | --master_port ${MASTER_PORT:-12345} \ 227 | evaluate_vqa.py \ 228 | --checkpoint $checkpoint \ 229 | --dataset $ds \ 230 | --batch-size 8 \ 231 | --num-workers 2 232 | ``` 233 | 234 |
235 | 236 | 237 | ## Multi-modal Benchmark 238 | ### [MMBench](https://github.com/open-compass/MMBench) 239 | 240 | 
241 | Data Preparation 242 | 243 | ```bash 244 | python mmbench_converter.py  # Convert the downloaded MMBench tsv file to jsonl and dump its images. 245 | ``` 246 | 
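The converter (included later in this repo) hard-codes its paths: it reads `data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv`, saves the decoded images to `data/mmbench/mmbench_dev_20230712/images/`, and writes the jsonl next to the tsv. Download the MMBench tsv into that folder and make sure the images directory exists before running it; a small preparation sketch:

```python
import os

# Paths are hard-coded in mmbench_converter.py; create the images folder it writes into.
root = "data/mmbench/mmbench_dev_20230712"
os.makedirs(os.path.join(root, "images"), exist_ok=True)
# Put mmbench_dev_20230712.tsv inside `root`, then run:
#   python mmbench_converter.py
```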
247 | 248 |
249 | Evaluate 250 | 251 | ```bash 252 | ds="mmbench_dev_20230712" 253 | checkpoint=/PATH/TO/CHECKPOINT 254 | python -m torch.distributed.launch --use-env \ 255 | --nproc_per_node ${NPROC_PER_NODE:-8} \ 256 | --nnodes ${WORLD_SIZE:-1} \ 257 | --node_rank ${RANK:-0} \ 258 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 259 | --master_port ${MASTER_PORT:-12345} \ 260 | evaluate_mmbench.py \ 261 | --checkpoint $checkpoint \ 262 | --dataset $ds \ 263 | --batch-size 1 \ 264 | --num-workers 2 265 | ``` 266 | 267 |
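When the run finishes, rank 0 of `evaluate_mmbench.py` writes two artifacts: a raw prediction JSON named `<dataset>_<timestamp>_s<seed>.json` and an Excel submission file ending in `_submission.xlsx` with one row per question (`index`, `question`, `prediction`, `A`-`D`). A quick way to inspect the submission sheet (the filename is a placeholder for whatever your run produced; reading `.xlsx` with pandas needs `openpyxl`):

```python
import pandas as pd

# Placeholder filename for the timestamped submission file written on rank 0.
df = pd.read_excel("mmbench_dev_20230712_240101120000_s0_submission.xlsx")
print(df.columns.tolist())  # index, question, prediction, A, B, C, D
print(df.head())
```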
-------------------------------------------------------------------------------- /mplug_owl2/evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/evaluate/__init__.py -------------------------------------------------------------------------------- /mplug_owl2/evaluate/evaluate_caption.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | from pycocoevalcap.eval import COCOEvalCap 11 | from pycocotools.coco import COCO 12 | from tqdm import tqdm 13 | from PIL import Image 14 | 15 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 16 | from mplug_owl2.conversation import conv_templates, SeparatorStyle 17 | from mplug_owl2.model.builder import load_pretrained_model 18 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 19 | 20 | ds_collections = { 21 | 'flickr': { 22 | 'train': 'data/flickr30k/flickr30k_karpathy_test.json', 23 | 'test': 'data/flickr30k/flickr30k_karpathy_test.json', 24 | }, 25 | } 26 | 27 | class CaptionDataset(torch.utils.data.Dataset): 28 | 29 | def __init__(self, train, test, prompt, image_processor, few_shot=0): 30 | self.images = json.load(open(test))['images'] 31 | self.prompt = prompt 32 | self.image_processor = image_processor 33 | 34 | self.few_shot = few_shot 35 | if few_shot > 0: 36 | self.train = json.load(open(train))['annotations'] 37 | 38 | def __len__(self): 39 | return len(self.images) 40 | 41 | def __getitem__(self, idx): 42 | image_id, image_path = self.images[idx]['id'], self.images[idx]['image'] 43 | 44 | image = Image.open(image_path).convert('RGB') 45 | max_edge = max(image.size) 46 | image = image.resize((max_edge, max_edge)) 47 | image_tensor = process_images([image], self.image_processor) 48 | 49 | return { 50 | 'image_id': image_id, 51 | 'image_tensor': image_tensor, 52 | 'input_text': self.prompt.format(image_path) 53 | } 54 | 55 | 56 | def collate_fn(inputs, tokenizer): 57 | image_ids = [_['image_id'] for _ in inputs] 58 | image_tensor = [_['image_tensor'] for _ in inputs] 59 | input_texts = [_['input_text'] for _ in inputs] 60 | 61 | input_ids = [] 62 | for input_text in input_texts: 63 | input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist()) 64 | input_tokens_max_length = max([len(x) for x in input_ids]) 65 | pad_token_id = tokenizer.pad_token_id 66 | 67 | input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left 68 | input_ids = torch.LongTensor(input_ids) 69 | attention_mask = 1 - input_ids.eq(pad_token_id).long() 70 | 71 | image_tensor = torch.cat(image_tensor, dim=0) 72 | return image_ids, image_tensor, input_ids, attention_mask 73 | 74 | 75 | class InferenceSampler(torch.utils.data.sampler.Sampler): 76 | 77 | def __init__(self, size): 78 | self._size = int(size) 79 | assert size > 0 80 | self._rank = torch.distributed.get_rank() 81 | self._world_size = torch.distributed.get_world_size() 82 | self._local_indices = self._get_local_indices(size, self._world_size, 83 | self._rank) 84 | 85 | @staticmethod 86 | def _get_local_indices(total_size, world_size, rank): 87 | shard_size = total_size // 
world_size 88 | left = total_size % world_size 89 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 90 | 91 | begin = sum(shard_sizes[:rank]) 92 | end = min(sum(shard_sizes[:rank + 1]), total_size) 93 | return range(begin, end) 94 | 95 | def __iter__(self): 96 | yield from self._local_indices 97 | 98 | def __len__(self): 99 | return len(self._local_indices) 100 | 101 | 102 | if __name__ == '__main__': 103 | 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--checkpoint', type=str, default='') 106 | parser.add_argument('--dataset', type=str, default='flickr') 107 | parser.add_argument('--batch-size', type=int, default=1) 108 | parser.add_argument('--num-workers', type=int, default=1) 109 | parser.add_argument('--few-shot', type=int, default=0) 110 | parser.add_argument('--seed', type=int, default=0) 111 | args = parser.parse_args() 112 | 113 | torch.distributed.init_process_group( 114 | backend='nccl', 115 | world_size=int(os.getenv('WORLD_SIZE', '1')), 116 | rank=int(os.getenv('RANK', '0')), 117 | ) 118 | 119 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 120 | 121 | prompt = 'USER: <|image|>Provide a one-sentence caption for the provided image. ASSISTANT: ' 122 | 123 | model_path = args.checkpoint 124 | model_name = get_model_name_from_path(model_path) 125 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda") 126 | tokenizer.padding_side = 'left' 127 | if not hasattr(tokenizer, 'pad_token_id'): 128 | tokenizer.pad_token_id = tokenizer.eos_token_id 129 | 130 | random.seed(args.seed) 131 | dataset = CaptionDataset( 132 | train=ds_collections[args.dataset]['train'], 133 | test=ds_collections[args.dataset]['test'], 134 | prompt=prompt, 135 | image_processor=image_processor, 136 | few_shot=args.few_shot, 137 | ) 138 | coco_karpathy_test_loader = torch.utils.data.DataLoader( 139 | dataset=dataset, 140 | sampler=InferenceSampler(len(dataset)), 141 | batch_size=args.batch_size, 142 | num_workers=args.num_workers, 143 | pin_memory=True, 144 | drop_last=False, 145 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 146 | ) 147 | 148 | image_ids = [] 149 | captions = [] 150 | for _, (ids, image_tensor, input_ids, attention_mask) in enumerate(tqdm(coco_karpathy_test_loader)): 151 | pred = model.generate( 152 | input_ids=input_ids.cuda(), 153 | attention_mask=attention_mask.cuda(), 154 | images=image_tensor.to(dtype=model.dtype).cuda(), 155 | do_sample=False, 156 | num_beams=1, 157 | max_new_tokens=60, 158 | min_new_tokens=8, 159 | length_penalty=0, 160 | num_return_sequences=1, 161 | use_cache=True, 162 | ) 163 | image_ids.extend(ids) 164 | captions.extend([ 165 | tokenizer.decode(_[input_ids.size(1):].cpu(), 166 | skip_special_tokens=True).strip() for _ in pred 167 | ]) 168 | print(captions[-len(pred):]) 169 | 170 | torch.distributed.barrier() 171 | 172 | world_size = torch.distributed.get_world_size() 173 | merged_ids = [None for _ in range(world_size)] 174 | merged_captions = [None for _ in range(world_size)] 175 | torch.distributed.all_gather_object(merged_ids, image_ids) 176 | torch.distributed.all_gather_object(merged_captions, captions) 177 | 178 | merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)] 179 | merged_captions = [ 180 | _ for _ in itertools.chain.from_iterable(merged_captions) 181 | ] 182 | 183 | if torch.distributed.get_rank() == 0: 184 | print(f"Evaluating 
{args.dataset} ...") 185 | 186 | results = [] 187 | for image_id, caption in zip(merged_ids, merged_captions): 188 | results.append({ 189 | 'image_id': int(image_id), 190 | 'caption': caption, 191 | }) 192 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 193 | results_file = f'{args.dataset}_{time_prefix}.json' 194 | json.dump(results, open(results_file, 'w')) 195 | 196 | coco = COCO(ds_collections[args.dataset]['test']) 197 | coco_result = coco.loadRes(results_file) 198 | coco_eval = COCOEvalCap(coco, coco_result) 199 | coco_eval.evaluate() 200 | 201 | print(coco_eval.eval.items()) 202 | torch.distributed.barrier() -------------------------------------------------------------------------------- /mplug_owl2/evaluate/evaluate_mmbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | from typing import Optional 9 | 10 | import torch 11 | from tqdm import tqdm 12 | from PIL import Image 13 | import pandas as pd 14 | 15 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 16 | from mplug_owl2.conversation import conv_templates, SeparatorStyle 17 | from mplug_owl2.model.builder import load_pretrained_model 18 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 19 | 20 | 21 | ds_collections = { 22 | 'mmbench_dev_20230712': { 23 | 'raw_file': 'mmbench_dev_20230712.tsv', 24 | 'annotation': 'mmbench_dev_20230712.jsonl', 25 | 'max_new_tokens': 10, 26 | }, 27 | 'mmbench_test_20230712': { 28 | 'raw_file': 'mmbench_test_20230712.tsv', 29 | 'annotation': 'mmbench_test_20230712.jsonl', 30 | 'max_new_tokens': 10, 31 | }, 32 | 'mmbench_test_en_20231003': { 33 | 'raw_file': 'mmbench_test_en_20231003.tsv', 34 | 'annotation': 'mmbench_test_en_20231003.jsonl', 35 | 'max_new_tokens': 10, 36 | }, 37 | 'mmbench_test_cn_20231003': { 38 | 'raw_file': 'mmbench_test_cn_20231003.tsv', 39 | 'annotation': 'mmbench_test_cn_20231003.jsonl', 40 | 'max_new_tokens': 10, 41 | }, 42 | 'ccbench_1003': { 43 | 'raw_file': 'ccbench_1003.tsv', 44 | 'annotation': 'ccbench_1003.jsonl', 45 | 'max_new_tokens': 10, 46 | }, 47 | } 48 | 49 | multiple_choices = ['A', 'B', 'C', 'D', 'E'] 50 | 51 | def mapping_to_annotation(results, raw_annotation): 52 | outputs = [] 53 | for result in results: 54 | index, prediction = result['index'], result['prediction'] 55 | row_df = raw_annotation[raw_annotation['index'] == index].squeeze().to_dict() 56 | output = { 57 | "index": index, 58 | "image": row_df['image'], 59 | "question": row_df['question'], 60 | "answer": row_df.get('answer', None), 61 | "options": [y for y in [row_df.get(x, None) for x in 'ABCD'] if isinstance(y, str)], 62 | "prediction": prediction, 63 | "l2-category": row_df['l2-category'] if 'l2-category' in row_df else None 64 | } 65 | outputs.append(output) 66 | return outputs 67 | 68 | 69 | def generate_submission_file(results, raw_annotation): 70 | outputs = [] 71 | for result in results: 72 | index, prediction = result['index'], result['prediction'] 73 | row_df = raw_annotation[raw_annotation['index'] == index].squeeze().to_dict() 74 | output = { 75 | "index": index, 76 | "question": row_df['question'], 77 | "prediction": prediction, 78 | "A": row_df.get('A', None), 79 | "B": row_df.get('B', None), 80 | "C": row_df.get('C', None), 81 | "D": row_df.get('D', None), 82 | } 83 | outputs.append(output) 84 | 
return outputs 85 | 86 | def collate_fn(batches, tokenizer): 87 | 88 | questions = [_['question'] for _ in batches] 89 | indices = [_['index'] for _ in batches] 90 | 91 | image_tensor = [_['image_tensor'] for _ in batches] 92 | 93 | input_ids = [] 94 | for input_text in questions: 95 | input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist()) 96 | input_tokens_max_length = max([len(x) for x in input_ids]) 97 | pad_token_id = tokenizer.pad_token_id 98 | 99 | input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left 100 | input_ids = torch.LongTensor(input_ids) 101 | attention_mask = 1 - input_ids.eq(pad_token_id).long() 102 | 103 | image_tensor = torch.cat(image_tensor, dim=0) 104 | return image_tensor, input_ids, attention_mask, indices 105 | 106 | 107 | class VQADataset(torch.utils.data.Dataset): 108 | 109 | def __init__(self, test, prompt, image_processor): 110 | 111 | self.prompt = prompt 112 | self.image_processor = image_processor 113 | 114 | self.data = open(test).readlines() 115 | 116 | def __len__(self): 117 | return len(self.data) 118 | 119 | def __getitem__(self, idx): 120 | data = json.loads(self.data[idx].strip()) 121 | index = data['index'] 122 | image = data['image'] 123 | hint = data['hint'] if data['hint'] else 'N/A' 124 | question = data['question'] 125 | 126 | choices = data['choices'] 127 | choice_list = [] 128 | for i, c in enumerate(choices): 129 | choice_list.append('{}. {}'.format(multiple_choices[i], c)) 130 | choice_txt = '\n'.join(choice_list) 131 | 132 | image = Image.open(image).convert('RGB') 133 | max_edge = max(image.size) 134 | image = image.resize((max_edge, max_edge)) # Resize here for best performance 135 | image_tensor = process_images([image], self.image_processor) 136 | 137 | return { 138 | 'index': index, 139 | 'image_tensor': image_tensor, 140 | 'question': self.prompt.format(hint, question, choice_txt), 141 | } 142 | 143 | 144 | class InferenceSampler(torch.utils.data.sampler.Sampler): 145 | 146 | def __init__(self, size): 147 | self._size = int(size) 148 | assert size > 0 149 | self._rank = torch.distributed.get_rank() 150 | self._world_size = torch.distributed.get_world_size() 151 | self._local_indices = self._get_local_indices(size, self._world_size, 152 | self._rank) 153 | 154 | @staticmethod 155 | def _get_local_indices(total_size, world_size, rank): 156 | shard_size = total_size // world_size 157 | left = total_size % world_size 158 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 159 | 160 | begin = sum(shard_sizes[:rank]) 161 | end = min(sum(shard_sizes[:rank + 1]), total_size) 162 | return range(begin, end) 163 | 164 | def __iter__(self): 165 | yield from self._local_indices 166 | 167 | def __len__(self): 168 | return len(self._local_indices) 169 | 170 | 171 | if __name__ == '__main__': 172 | 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument('--checkpoint', type=str, default='') 175 | parser.add_argument('--dataset', type=str, default='mmbench_dev_20230712') 176 | parser.add_argument('--batch-size', type=int, default=1) 177 | parser.add_argument('--num-workers', type=int, default=1) 178 | parser.add_argument('--seed', type=int, default=0) 179 | args = parser.parse_args() 180 | 181 | 182 | torch.distributed.init_process_group( 183 | backend='nccl', 184 | world_size=int(os.getenv('WORLD_SIZE', '1')), 185 | rank=int(os.getenv('RANK', '0')), 186 | ) 187 | 188 | 
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 189 | os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('LOCAL_RANK', "0") 190 | 191 | model_path = args.checkpoint 192 | model_name = get_model_name_from_path(model_path) 193 | 194 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda") 195 | tokenizer.padding_side = 'left' 196 | if not hasattr(tokenizer, 'pad_token_id'): 197 | tokenizer.pad_token_id = tokenizer.eos_token_id 198 | 199 | prompt = "USER: <|image|>{}\n{}\n{}\nAnswer with the option’s letter from the given choices directly. ASSISTANT: " 200 | 201 | random.seed(args.seed) 202 | dataset = VQADataset( 203 | test=ds_collections[args.dataset]['annotation'], 204 | prompt=prompt, 205 | image_processor=image_processor, 206 | ) 207 | 208 | dataloader = torch.utils.data.DataLoader( 209 | dataset=dataset, 210 | sampler=InferenceSampler(len(dataset)), 211 | batch_size=args.batch_size, 212 | num_workers=args.num_workers, 213 | pin_memory=True, 214 | drop_last=False, 215 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 216 | ) 217 | 218 | outputs = [] 219 | for _, (image_tensor, input_ids, attention_mask, indices) in enumerate(tqdm(dataloader)): 220 | pred = model.generate( 221 | input_ids=input_ids.cuda(), 222 | attention_mask=attention_mask.cuda(), 223 | images=image_tensor.to(dtype=model.dtype).cuda(), 224 | do_sample=False, 225 | num_beams=1, 226 | max_new_tokens=ds_collections[args.dataset]['max_new_tokens'], 227 | min_new_tokens=1, 228 | length_penalty=1, 229 | num_return_sequences=1, 230 | output_hidden_states=True, 231 | use_cache=True, 232 | ) 233 | answers = [ 234 | tokenizer.decode(_[input_ids.size(1):].cpu(), 235 | skip_special_tokens=True).strip() for _ in pred 236 | ] 237 | 238 | for index, answer in zip(indices, answers): 239 | outputs.append({ 240 | 'index': index, 241 | 'prediction': answer, 242 | }) 243 | 244 | torch.distributed.barrier() 245 | 246 | world_size = torch.distributed.get_world_size() 247 | merged_outputs = [None for _ in range(world_size)] 248 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 249 | 250 | merged_outputs = [json.loads(_) for _ in merged_outputs] 251 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 252 | 253 | if torch.distributed.get_rank() == 0: 254 | print(f"Evaluating {args.dataset} ...") 255 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 256 | results_file = f'{args.dataset}_{time_prefix}_s{args.seed}.json' 257 | json.dump(merged_outputs, open(results_file, 'w'), ensure_ascii=False) 258 | 259 | mapped_result = mapping_to_annotation(merged_outputs, pd.read_csv(ds_collections[args.dataset]['raw_file'], sep='\t')) 260 | 261 | submission_res = generate_submission_file(merged_outputs, pd.read_csv(ds_collections[args.dataset]['raw_file'], sep='\t')) 262 | res_df = pd.DataFrame(submission_res) 263 | metrics_file = f'{args.dataset}_{time_prefix}_s{args.seed}_submission.xlsx' 264 | res_df.to_excel(metrics_file, index=False) 265 | 266 | torch.distributed.barrier() 267 | -------------------------------------------------------------------------------- /mplug_owl2/evaluate/evaluate_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | from typing 
import Optional 9 | 10 | import torch 11 | from tqdm import tqdm 12 | from PIL import Image 13 | 14 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 15 | from mplug_owl2.conversation import conv_templates, SeparatorStyle 16 | from mplug_owl2.model.builder import load_pretrained_model 17 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 18 | 19 | from vqa import VQA 20 | from vqa_eval import VQAEval 21 | 22 | ds_collections = { 23 | 'vqav2_val': { 24 | 'train': 'data/vqav2/vqav2_train.jsonl', 25 | 'test': 'data/vqav2/vqav2_val.jsonl', 26 | 'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json', 27 | 'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json', 28 | 'metric': 'vqa_score', 29 | 'max_new_tokens': 10, 30 | }, 31 | 'vqav2_testdev': { 32 | 'train': 'data/vqav2/vqav2_train.jsonl', 33 | 'test': 'data/vqav2/vqav2_testdev.jsonl', 34 | 'metric': None, 35 | 'max_new_tokens': 10, 36 | }, 37 | 'okvqa_val': { 38 | 'train': 'data/okvqa/okvqa_train.jsonl', 39 | 'test': 'data/okvqa/okvqa_val.jsonl', 40 | 'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json', 41 | 'annotation': 'data/okvqa/mscoco_val2014_annotations.json', 42 | 'metric': 'vqa_score', 43 | 'max_new_tokens': 10, 44 | }, 45 | 'textvqa_val': { 46 | 'train': 'data/textvqa/textvqa_train.jsonl', 47 | 'test': 'data/textvqa/textvqa_val.jsonl', 48 | 'question': 'data/textvqa/textvqa_val_questions.json', 49 | 'annotation': 'data/textvqa/textvqa_val_annotations.json', 50 | 'metric': 'vqa_score', 51 | 'max_new_tokens': 10, 52 | }, 53 | } 54 | 55 | def collate_fn(batches, tokenizer): 56 | 57 | questions = [_['question'] for _ in batches] 58 | question_ids = [_['question_id'] for _ in batches] 59 | annotations = [_['annotation'] for _ in batches] 60 | 61 | image_tensor = [_['image_tensor'] for _ in batches] 62 | 63 | input_ids = [] 64 | for input_text in questions: 65 | input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist()) 66 | input_tokens_max_length = max([len(x) for x in input_ids]) 67 | pad_token_id = tokenizer.pad_token_id 68 | 69 | input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left 70 | input_ids = torch.LongTensor(input_ids) 71 | attention_mask = 1 - input_ids.eq(pad_token_id).long() 72 | 73 | image_tensor = torch.cat(image_tensor, dim=0) 74 | return question_ids, image_tensor, input_ids, attention_mask, annotations 75 | 76 | 77 | class VQADataset(torch.utils.data.Dataset): 78 | 79 | def __init__(self, train, test, prompt, image_processor, few_shot): 80 | self.test = json.load(open(test)) 81 | self.prompt = prompt 82 | self.image_processor = image_processor 83 | 84 | self.few_shot = few_shot 85 | if few_shot > 0: 86 | self.train = open(train).readlines() 87 | 88 | def __len__(self): 89 | return len(self.test) 90 | 91 | def __getitem__(self, idx): 92 | data = self.test[idx] 93 | image, question, question_id, annotation = data['image'], data[ 94 | 'question'], data['question_id'], data.get('answer', None) 95 | image = Image.open(image).convert('RGB') 96 | max_edge = max(image.size) 97 | image = image.resize((max_edge, max_edge)) # Resize here for best performance 98 | image_tensor = process_images([image], self.image_processor) 99 | 100 | return { 101 | 'image_tensor': image_tensor, 102 | 'question': self.prompt.format(question), 103 | 'question_id': question_id, 104 | 'annotation': annotation 105 | 
} 106 | 107 | 108 | class InferenceSampler(torch.utils.data.sampler.Sampler): 109 | 110 | def __init__(self, size): 111 | self._size = int(size) 112 | assert size > 0 113 | self._rank = torch.distributed.get_rank() 114 | self._world_size = torch.distributed.get_world_size() 115 | self._local_indices = self._get_local_indices(size, self._world_size, 116 | self._rank) 117 | 118 | @staticmethod 119 | def _get_local_indices(total_size, world_size, rank): 120 | shard_size = total_size // world_size 121 | left = total_size % world_size 122 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 123 | 124 | begin = sum(shard_sizes[:rank]) 125 | end = min(sum(shard_sizes[:rank + 1]), total_size) 126 | return range(begin, end) 127 | 128 | def __iter__(self): 129 | yield from self._local_indices 130 | 131 | def __len__(self): 132 | return len(self._local_indices) 133 | 134 | 135 | if __name__ == '__main__': 136 | 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('--checkpoint', type=str, default='') 139 | parser.add_argument('--dataset', type=str, default='textvqa_val') 140 | parser.add_argument('--batch-size', type=int, default=1) 141 | parser.add_argument('--num-workers', type=int, default=1) 142 | parser.add_argument('--few-shot', type=int, default=0) 143 | parser.add_argument('--seed', type=int, default=0) 144 | args = parser.parse_args() 145 | 146 | torch.distributed.init_process_group( 147 | backend='nccl', 148 | world_size=int(os.getenv('WORLD_SIZE', '1')), 149 | rank=int(os.getenv('RANK', '0')), 150 | ) 151 | 152 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 153 | 154 | model_path = args.checkpoint 155 | model_name = get_model_name_from_path(model_path) 156 | 157 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda") 158 | tokenizer.padding_side = 'left' 159 | if not hasattr(tokenizer, 'pad_token_id'): 160 | tokenizer.pad_token_id = tokenizer.eos_token_id 161 | 162 | prompt = 'USER: <|image|>{} Answer the question using a single word or phrase. 
ASSISTANT: ' 163 | answer_processor = EvalAIAnswerProcessor() 164 | random.seed(args.seed) 165 | dataset = VQADataset( 166 | train=ds_collections[args.dataset]['train'], 167 | test=ds_collections[args.dataset]['test'], 168 | prompt=prompt, 169 | image_processor=image_processor, 170 | few_shot=args.few_shot, 171 | ) 172 | 173 | dataloader = torch.utils.data.DataLoader( 174 | dataset=dataset, 175 | sampler=InferenceSampler(len(dataset)), 176 | batch_size=args.batch_size, 177 | num_workers=args.num_workers, 178 | pin_memory=True, 179 | drop_last=False, 180 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 181 | ) 182 | 183 | outputs = [] 184 | for _, (question_ids, image_tensor, input_ids, attention_mask, 185 | annotations) in enumerate(tqdm(dataloader)): 186 | pred = model.generate( 187 | input_ids=input_ids.cuda(), 188 | attention_mask=attention_mask.cuda(), 189 | images=image_tensor.to(dtype=model.dtype).cuda(), 190 | do_sample=False, 191 | num_beams=1, 192 | max_new_tokens=ds_collections[args.dataset]['max_new_tokens'], 193 | min_new_tokens=1, 194 | length_penalty=1, 195 | num_return_sequences=1, 196 | output_hidden_states=True, 197 | use_cache=True, 198 | ) 199 | answers = [ 200 | tokenizer.decode(_[input_ids.size(1):].cpu(), 201 | skip_special_tokens=True).strip() for _ in pred 202 | ] 203 | 204 | for question_id, answer, annotation in zip(question_ids, answers, 205 | annotations): 206 | if args.dataset in ['vqav2_val', 'okvqa_val', 'textvqa_val']: 207 | outputs.append({ 208 | 'question_id': question_id, 209 | 'answer': answer, 210 | }) 211 | elif args.dataset == 'vqav2_testdev': 212 | outputs.append({ 213 | 'question_id': question_id, 214 | 'answer': answer_processor(answer), 215 | }) 216 | else: 217 | raise NotImplementedError 218 | 219 | torch.distributed.barrier() 220 | 221 | world_size = torch.distributed.get_world_size() 222 | merged_outputs = [None for _ in range(world_size)] 223 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 224 | 225 | merged_outputs = [json.loads(_) for _ in merged_outputs] 226 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 227 | 228 | if torch.distributed.get_rank() == 0: 229 | print(f"Evaluating {args.dataset} ...") 230 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 231 | results_file = f'{args.dataset}_{time_prefix}_fs{args.few_shot}_s{args.seed}.json' 232 | json.dump(merged_outputs, open(results_file, 'w', encoding='utf-8'), ensure_ascii=False) 233 | 234 | if ds_collections[args.dataset]['metric'] == 'vqa_score': 235 | vqa = VQA(ds_collections[args.dataset]['annotation'], 236 | ds_collections[args.dataset]['question']) 237 | results = vqa.loadRes( 238 | resFile=results_file, 239 | quesFile=ds_collections[args.dataset]['question']) 240 | vqa_scorer = VQAEval(vqa, results, n=2) 241 | vqa_scorer.evaluate() 242 | 243 | print(vqa_scorer.accuracy) 244 | torch.distributed.barrier() 245 | -------------------------------------------------------------------------------- /mplug_owl2/evaluate/mmbench_converter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import io 3 | import base64 4 | import json 5 | from PIL import Image 6 | 7 | ''' 8 | This scripts convert mmbench_dev tsv file to jsonl 9 | ''' 10 | 11 | datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t') 12 | 13 | global_choices = ['A', 'B', 'C', 'D'] 14 | 15 | def decode_base64_to_image(base64_string): 16 | image_data = 
base64.b64decode(base64_string) 17 | image = Image.open(io.BytesIO(image_data)) 18 | return image 19 | 20 | 21 | with open('./data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.jsonl', 'w') as f: 22 | for idx in range(len(datas)): 23 | data = datas.iloc[idx] 24 | 25 | index = int(data['index']) 26 | question = data['question'] 27 | hint = data['hint'] if not pd.isna(data['hint']) else 'N/A' 28 | 29 | choices = [] 30 | for opt in global_choices: 31 | if pd.isna(data[opt]): 32 | continue 33 | choices.append(data[opt]) 34 | 35 | answer = global_choices.index(data['answer']) 36 | 37 | image = decode_base64_to_image(data['image']) 38 | image.save("data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index) 39 | 40 | f.write(json.dumps({ 41 | "index": index, 42 | "image": "data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index, 43 | "hint": hint, 44 | "question": question, 45 | "choices": choices, 46 | "answer": answer, 47 | }) + "\n") -------------------------------------------------------------------------------- /mplug_owl2/evaluate/vqa.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) 2022, salesforce.com, inc. 2 | 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | __author__ = 'aagrawal' 9 | __version__ = '0.9' 10 | 11 | # Interface for accessing the VQA dataset. 12 | 13 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 14 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 15 | 16 | # The following functions are defined: 17 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 18 | # getQuesIds - Get question ids that satisfy given filter conditions. 19 | # getImgIds - Get image ids that satisfy given filter conditions. 20 | # loadQA - Load questions and answers with the specified question ids. 21 | # showQA - Display the specified questions and answers. 22 | # loadRes - Load result file and create result object. 23 | 24 | # Help on each function can be accessed by: "help(COCO.function)" 25 | 26 | import copy 27 | import datetime 28 | import json 29 | 30 | 31 | class VQA: 32 | 33 | def __init__(self, annotation_file=None, question_file=None): 34 | """Constructor of VQA helper class for reading and visualizing 35 | questions and answers. 
36 | 37 | :param annotation_file (str): location of VQA annotation file 38 | :return: 39 | """ 40 | # load dataset 41 | self.dataset = {} 42 | self.questions = {} 43 | self.qa = {} 44 | self.qqa = {} 45 | self.imgToQA = {} 46 | if not annotation_file == None and not question_file == None: 47 | print('loading VQA annotations and questions into memory...') 48 | time_t = datetime.datetime.utcnow() 49 | dataset = json.load(open(annotation_file, 'r')) 50 | questions = json.load(open(question_file, 'r')) 51 | self.dataset = dataset 52 | self.questions = questions 53 | self.createIndex() 54 | 55 | def createIndex(self): 56 | # create index 57 | print('creating index...') 58 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 59 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 60 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 61 | for ann in self.dataset['annotations']: 62 | imgToQA[ann['image_id']] += [ann] 63 | qa[ann['question_id']] = ann 64 | for ques in self.questions['questions']: 65 | qqa[ques['question_id']] = ques 66 | print('index created!') 67 | 68 | # create class members 69 | self.qa = qa 70 | self.qqa = qqa 71 | self.imgToQA = imgToQA 72 | 73 | def info(self): 74 | """Print information about the VQA annotation file. 75 | 76 | :return: 77 | """ 78 | for key, value in self.datset['info'].items(): 79 | print('%s: %s' % (key, value)) 80 | 81 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 82 | """Get question ids that satisfy given filter conditions. default skips 83 | that filter. 84 | 85 | :param imgIds (int array) : get question ids for given imgs 86 | quesTypes (str array) : get question ids for given question types 87 | ansTypes (str array) : get question ids for given answer types 88 | :return: ids (int array) : integer array of question ids 89 | """ 90 | imgIds = imgIds if type(imgIds) == list else [imgIds] 91 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 92 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 93 | 94 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 95 | anns = self.dataset['annotations'] 96 | else: 97 | if not len(imgIds) == 0: 98 | anns = sum( 99 | [ 100 | self.imgToQA[imgId] 101 | for imgId in imgIds if imgId in self.imgToQA 102 | ], 103 | [], 104 | ) 105 | else: 106 | anns = self.dataset['annotations'] 107 | anns = (anns if len(quesTypes) == 0 else 108 | [ann for ann in anns if ann['question_type'] in quesTypes]) 109 | anns = (anns if len(ansTypes) == 0 else 110 | [ann for ann in anns if ann['answer_type'] in ansTypes]) 111 | ids = [ann['question_id'] for ann in anns] 112 | return ids 113 | 114 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 115 | """Get image ids that satisfy given filter conditions. default skips 116 | that filter. 
117 | 118 | :param quesIds (int array) : get image ids for given question ids 119 | quesTypes (str array) : get image ids for given question types 120 | ansTypes (str array) : get image ids for given answer types 121 | :return: ids (int array) : integer array of image ids 122 | """ 123 | quesIds = quesIds if type(quesIds) == list else [quesIds] 124 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 125 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 126 | 127 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 128 | anns = self.dataset['annotations'] 129 | else: 130 | if not len(quesIds) == 0: 131 | anns = sum([ 132 | self.qa[quesId] for quesId in quesIds if quesId in self.qa 133 | ], []) 134 | else: 135 | anns = self.dataset['annotations'] 136 | anns = (anns if len(quesTypes) == 0 else 137 | [ann for ann in anns if ann['question_type'] in quesTypes]) 138 | anns = (anns if len(ansTypes) == 0 else 139 | [ann for ann in anns if ann['answer_type'] in ansTypes]) 140 | ids = [ann['image_id'] for ann in anns] 141 | return ids 142 | 143 | def loadQA(self, ids=[]): 144 | """Load questions and answers with the specified question ids. 145 | 146 | :param ids (int array) : integer ids specifying question ids 147 | :return: qa (object array) : loaded qa objects 148 | """ 149 | if type(ids) == list: 150 | return [self.qa[id] for id in ids] 151 | elif type(ids) == int: 152 | return [self.qa[ids]] 153 | 154 | def showQA(self, anns): 155 | """Display the specified annotations. 156 | 157 | :param anns (array of object): annotations to display 158 | :return: None 159 | """ 160 | if len(anns) == 0: 161 | return 0 162 | for ann in anns: 163 | quesId = ann['question_id'] 164 | print('Question: %s' % (self.qqa[quesId]['question'])) 165 | for ans in ann['answers']: 166 | print('Answer %d: %s' % (ans['answer_id'], ans['answer'])) 167 | 168 | def loadRes(self, resFile, quesFile): 169 | """Load result file and return a result object. 170 | 171 | :param resFile (str) : file name of result file 172 | :return: res (obj) : result api object 173 | """ 174 | res = VQA() 175 | res.questions = json.load(open(quesFile)) 176 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 177 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 178 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 179 | res.dataset['data_subtype'] = copy.deepcopy( 180 | self.questions['data_subtype']) 181 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 182 | 183 | print('Loading and preparing results... ') 184 | time_t = datetime.datetime.utcnow() 185 | anns = json.load(open(resFile)) 186 | assert type(anns) == list, 'results is not an array of objects' 187 | annsQuesIds = [ann['question_id'] for ann in anns] 188 | assert set(annsQuesIds) == set( 189 | self.getQuesIds() 190 | ), 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 
191 | for ann in anns: 192 | quesId = ann['question_id'] 193 | if res.dataset['task_type'] == 'Multiple Choice': 194 | assert ( 195 | ann['answer'] in self.qqa[quesId]['multiple_choices'] 196 | ), 'predicted answer is not one of the multiple choices' 197 | qaAnn = self.qa[quesId] 198 | ann['image_id'] = qaAnn['image_id'] 199 | ann['question_type'] = qaAnn['question_type'] 200 | ann['answer_type'] = qaAnn['answer_type'] 201 | print('DONE (t=%0.2fs)' % 202 | ((datetime.datetime.utcnow() - time_t).total_seconds())) 203 | 204 | res.dataset['annotations'] = anns 205 | res.createIndex() 206 | return res -------------------------------------------------------------------------------- /mplug_owl2/local_serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/local_serve/__init__.py -------------------------------------------------------------------------------- /mplug_owl2/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg -------------------------------------------------------------------------------- /mplug_owl2/local_serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/local_serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /mplug_owl2/local_serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | import argparse 5 | import asyncio 6 | import json 7 | import time 8 | import threading 9 | import uuid 10 | 11 | import requests 12 | import torch 13 | from functools import partial 14 | 15 | from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL 16 | from mplug_owl2.utils import (build_logger, server_error_msg, 17 | pretty_print_semaphore) 18 | from mplug_owl2.model.builder import load_pretrained_model 19 | from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria 20 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 21 | from transformers import TextIteratorStreamer 22 | from threading import Thread 23 | 24 | GB = 1 << 30 25 | 26 | worker_id = str(uuid.uuid4())[:6] 27 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 28 | 29 | class ModelWorker: 30 | def __init__(self, model_path, model_base, model_name, load_8bit, load_4bit, device): 31 | self.worker_id = worker_id 32 | if model_path.endswith("/"): 33 | model_path = model_path[:-1] 34 | if model_name is None: 35 | model_paths = model_path.split("/") 36 | if model_paths[-1].startswith('checkpoint-'): 37 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 38 | else: 39 | self.model_name = model_paths[-1] 40 | else: 41 | self.model_name = model_name 42 | 43 | self.device = device 44 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") 45 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( 46 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device) 47 | self.is_multimodal = True 48 | 49 | @torch.inference_mode() 50 | def generate_stream(self, params): 51 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 52 | 53 | prompt = params["prompt"] 54 | ori_prompt = prompt 55 | images = params.get("images", None) 56 | num_image_tokens = 0 57 | if images is not None and len(images) > 0 and self.is_multimodal: 58 | if len(images) > 0: 59 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 60 | raise ValueError("Number of images does not match number of <|image|> tokens in prompt") 61 | 62 | images = [load_image_from_base64(image) for image in images] 63 | images = process_images(images, image_processor, model.config) 64 | 65 | if type(images) is list: 66 | images = [image.to(self.model.device, dtype=torch.float16) for image in images] 67 | else: 68 | images = images.to(self.model.device, dtype=torch.float16) 69 | 70 | replace_token = DEFAULT_IMAGE_TOKEN 71 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 72 | 73 | num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1) 74 | else: 75 | images = None 76 | image_args = {"images": images} 77 | else: 78 | images = None 79 | image_args = {} 80 | 81 | temperature = float(params.get("temperature", 1.0)) 82 | top_p = float(params.get("top_p", 1.0)) 83 | max_context_length = getattr(model.config, 'max_position_embeddings', 4096) 84 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 85 | stop_str = params.get("stop", None) 86 | do_sample = True if temperature > 0.001 else False 87 | 88 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) 89 | keywords = [stop_str] 90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 91 | streamer = 
TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 92 | 93 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 94 | 95 | if max_new_tokens < 1: 96 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 97 | return 98 | 99 | thread = Thread(target=model.generate, kwargs=dict( 100 | inputs=input_ids, 101 | do_sample=do_sample, 102 | temperature=temperature, 103 | top_p=top_p, 104 | max_new_tokens=max_new_tokens, 105 | streamer=streamer, 106 | stopping_criteria=[stopping_criteria], 107 | use_cache=True, 108 | **image_args 109 | )) 110 | thread.start() 111 | 112 | generated_text = ori_prompt 113 | for new_text in streamer: 114 | generated_text += new_text 115 | if generated_text.endswith(stop_str): 116 | generated_text = generated_text[:-len(stop_str)] 117 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() 118 | 119 | def generate_stream_gate(self, params): 120 | try: 121 | for x in self.generate_stream(params): 122 | yield x 123 | except ValueError as e: 124 | print("Caught ValueError:", e) 125 | ret = { 126 | "text": server_error_msg, 127 | "error_code": 1, 128 | } 129 | yield json.dumps(ret).encode() 130 | except torch.cuda.CudaError as e: 131 | print("Caught torch.cuda.CudaError:", e) 132 | ret = { 133 | "text": server_error_msg, 134 | "error_code": 1, 135 | } 136 | yield json.dumps(ret).encode() 137 | except Exception as e: 138 | print("Caught Unknown Error", e) 139 | ret = { 140 | "text": server_error_msg, 141 | "error_code": 1, 142 | } 143 | yield json.dumps(ret).encode() -------------------------------------------------------------------------------- /mplug_owl2/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 8 | from icecream import ic 9 | 10 | 11 | def load_image_from_base64(image): 12 | return Image.open(BytesIO(base64.b64decode(image))) 13 | 14 | 15 | def expand2square(pil_img, background_color): 16 | width, height = pil_img.size 17 | if width == height: 18 | return pil_img 19 | elif width > height: 20 | result = Image.new(pil_img.mode, (width, width), background_color) 21 | result.paste(pil_img, (0, (width - height) // 2)) 22 | return result 23 | else: 24 | result = Image.new(pil_img.mode, (height, height), background_color) 25 | result.paste(pil_img, ((height - width) // 2, 0)) 26 | return result 27 | 28 | 29 | def process_images(images, image_processor, model_cfg=None): 30 | if model_cfg is not None: 31 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 32 | else: 33 | image_aspect_ratio = 'resize' 34 | new_images = [] 35 | if image_aspect_ratio == 'pad': 36 | for image in images: 37 | image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) 38 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 39 | new_images.append(image) 40 | elif image_aspect_ratio == 'resize': 41 | for image in images: 42 | max_edge = max(image.size) 43 | image = image.resize((max_edge, max_edge)) 44 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 45 | new_images.append(image) 46 | else: 47 | return image_processor(images, 
return_tensors='pt')['pixel_values'] 48 | if all(x.shape == new_images[0].shape for x in new_images): 49 | new_images = torch.stack(new_images, dim=0) 50 | return new_images 51 | 52 | 53 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 54 | prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in 55 | prompt.split(DEFAULT_IMAGE_TOKEN)] 56 | 57 | def insert_separator(X, sep): 58 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] 59 | 60 | input_ids = [] 61 | offset = 0 62 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 63 | offset = 1 64 | input_ids.append(prompt_chunks[0][0]) 65 | 66 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 67 | input_ids.extend(x[offset:]) 68 | 69 | if return_tensors is not None: 70 | if return_tensors == 'pt': 71 | return torch.tensor(input_ids, dtype=torch.long) 72 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 73 | return input_ids 74 | 75 | 76 | def get_model_name_from_path(model_path): 77 | model_path = model_path.strip("/") 78 | model_paths = model_path.split("/") 79 | if model_paths[-1].startswith('checkpoint-'): 80 | return model_paths[-2] + "_" + model_paths[-1] 81 | else: 82 | return model_paths[-1] 83 | 84 | 85 | class KeywordsStoppingCriteria(StoppingCriteria): 86 | def __init__(self, keywords, tokenizer, input_ids): 87 | self.keywords = keywords 88 | self.keyword_ids = [] 89 | self.max_keyword_len = 0 90 | for keyword in keywords: 91 | cur_keyword_ids = tokenizer(keyword).input_ids 92 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 93 | cur_keyword_ids = cur_keyword_ids[1:] 94 | if len(cur_keyword_ids) > self.max_keyword_len: 95 | self.max_keyword_len = len(cur_keyword_ids) 96 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 97 | self.tokenizer = tokenizer 98 | self.start_len = input_ids.shape[1] 99 | 100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 101 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 102 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 103 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 104 | for keyword_id in self.keyword_ids: 105 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 106 | return True 107 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 108 | for keyword in self.keywords: 109 | if keyword in outputs: 110 | return True 111 | return False 112 | -------------------------------------------------------------------------------- /mplug_owl2/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM, MPLUGOwl2QWenForCausalLM 2 | from .configuration_mplug_owl2 import MPLUGOwl2Config,MPLUGOwl2QwenConfig 3 | -------------------------------------------------------------------------------- /mplug_owl2/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | from transformers.models.clip.image_processing_clip import CLIPImageProcessor 22 | import torch 23 | from extensions.sd_smartprocess.mplug_owl2.model import * 24 | from icecream import ic 25 | 26 | from extensions.sd_smartprocess.mplug_owl2 import MPLUGOwl2LlamaForCausalLM 27 | from extensions.sd_smartprocess.mplug_owl2.model import MPLUGOwl2QWenForCausalLM 28 | 29 | 30 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", 31 | device="cuda", **kwargs): 32 | kwargs = {"device_map": device_map, "ignore_mismatched_sizes": False, **kwargs} 33 | 34 | if device != "cuda": 35 | kwargs['device_map'] = {"": device} 36 | 37 | if load_8bit: 38 | kwargs['load_in_8bit'] = True 39 | elif load_4bit: 40 | kwargs['load_in_4bit'] = True 41 | kwargs['quantization_config'] = BitsAndBytesConfig( 42 | load_in_4bit=True, 43 | bnb_4bit_compute_dtype=torch.float16, 44 | bnb_4bit_use_double_quant=True, 45 | bnb_4bit_quant_type='nf4' 46 | ) 47 | else: 48 | kwargs['torch_dtype'] = torch.float16 49 | if 'mplug_owl2' in model_name.lower(): 50 | # Load LLaVA2 model 51 | if 'lora' in model_name.lower() and model_base is None: 52 | warnings.warn( 53 | 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 54 | if 'lora' in model_name.lower() and model_base is not None: 55 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 56 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 57 | print('Loading mPLUG-Owl2 from base model...') 58 | if 'mplug_owl2_1' in model_name.lower(): 59 | model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, 60 | config=lora_cfg_pretrained, **kwargs) 61 | else: 62 | model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, 63 | config=lora_cfg_pretrained, **kwargs) 64 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 65 | if model.lm_head.weight.shape[0] != token_num: 66 | model.lm_head.weight = torch.nn.Parameter( 67 | torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 68 | model.model.embed_tokens.weight = torch.nn.Parameter( 69 | torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 70 | 71 | print('Loading additional mPLUG-Owl2 weights...') 72 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 73 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), 74 | map_location='cpu') 75 | else: 76 | # this is probably from HF Hub 77 | from huggingface_hub import hf_hub_download 78 | def load_from_hf(repo_id, filename, subfolder=None): 79 | cache_file = hf_hub_download( 80 | repo_id=repo_id, 81 | filename=filename, 82 | subfolder=subfolder) 83 | return torch.load(cache_file, map_location='cpu') 84 | 85 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 86 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in 87 | non_lora_trainables.items()} 88 | if any(k.startswith('model.model.') for k in non_lora_trainables): 89 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in 90 | non_lora_trainables.items()} 91 | model.load_state_dict(non_lora_trainables, strict=False) 92 | 93 | from peft import PeftModel 94 | print('Loading LoRA weights...') 95 | model = PeftModel.from_pretrained(model, model_path) 96 | print('Merging LoRA weights...') 97 | model = model.merge_and_unload() 98 | print('Model is loaded...') 99 | elif model_base is not None: 100 | # this may be mm projector only 101 | print('Loading mPLUG-Owl2 from base model...') 102 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 103 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 104 | if 'mplug_owl2_1' in model_name.lower(): 105 | model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, 106 | config=cfg_pretrained, **kwargs) 107 | else: 108 | model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, 109 | config=cfg_pretrained, **kwargs) 110 | else: 111 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True) 112 | if 'mplug_owl2_1' in model_name.lower(): 113 | model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 114 | else: 115 | model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 116 | else: 117 | # Load language model 118 | if model_base is not None: 119 | # PEFT model 120 | from peft import PeftModel 121 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True) 122 | model 
= AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 123 | print(f"Loading LoRA weights from {model_path}") 124 | model = PeftModel.from_pretrained(model, model_path) 125 | print(f"Merging weights") 126 | model = model.merge_and_unload() 127 | print('Convert to FP16...') 128 | model.to(torch.float16) 129 | else: 130 | use_fast = False 131 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True) 132 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 133 | 134 | vision_tower = model.get_model().vision_model 135 | # vision_tower.to(device=device, dtype=torch.float16) 136 | image_processor = CLIPImageProcessor.from_pretrained(model_path) 137 | 138 | if hasattr(model.config, "max_sequence_length"): 139 | context_len = model.config.max_sequence_length 140 | else: 141 | context_len = 2048 142 | 143 | return tokenizer, model, image_processor, context_len 144 | -------------------------------------------------------------------------------- /mplug_owl2/model/configuration_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from transformers import PretrainedConfig 7 | 8 | 9 | class QWenConfig(PretrainedConfig): 10 | model_type = "qwen" 11 | keys_to_ignore_at_inference = ["past_key_values"] 12 | 13 | def __init__( 14 | self, 15 | multiway=False, 16 | vocab_size=151936, 17 | hidden_size=4096, 18 | num_hidden_layers=32, 19 | num_attention_heads=32, 20 | emb_dropout_prob=0.0, 21 | attn_dropout_prob=0.0, 22 | layer_norm_epsilon=1e-6, 23 | initializer_range=0.02, 24 | max_position_embeddings=8192, 25 | scale_attn_weights=True, 26 | use_cache=True, 27 | bf16=False, 28 | fp16=False, 29 | fp32=False, 30 | kv_channels=128, 31 | rotary_pct=1.0, 32 | rotary_emb_base=10000, 33 | use_dynamic_ntk=True, 34 | use_logn_attn=True, 35 | use_flash_attn="auto", 36 | intermediate_size=22016, 37 | no_bias=True, 38 | tie_word_embeddings=False, 39 | use_cache_quantization=False, 40 | use_cache_kernel=False, 41 | softmax_in_fp32=False, 42 | **kwargs, 43 | ): 44 | self.multiway = multiway 45 | self.vocab_size = vocab_size 46 | self.hidden_size = hidden_size 47 | self.intermediate_size = intermediate_size 48 | self.num_hidden_layers = num_hidden_layers 49 | self.num_attention_heads = num_attention_heads 50 | self.emb_dropout_prob = emb_dropout_prob 51 | self.attn_dropout_prob = attn_dropout_prob 52 | self.layer_norm_epsilon = layer_norm_epsilon 53 | self.initializer_range = initializer_range 54 | self.scale_attn_weights = scale_attn_weights 55 | self.use_cache = use_cache 56 | self.max_position_embeddings = max_position_embeddings 57 | self.bf16 = bf16 58 | self.fp16 = fp16 59 | self.fp32 = fp32 60 | self.kv_channels = kv_channels 61 | self.rotary_pct = rotary_pct 62 | self.rotary_emb_base = rotary_emb_base 63 | self.use_dynamic_ntk = use_dynamic_ntk 64 | self.use_logn_attn = use_logn_attn 65 | self.use_flash_attn = use_flash_attn 66 | self.no_bias = no_bias 67 | self.use_cache_quantization = use_cache_quantization 68 | self.use_cache_kernel = use_cache_kernel 69 | self.softmax_in_fp32 = softmax_in_fp32 70 | super().__init__( 71 | tie_word_embeddings=tie_word_embeddings, 72 | **kwargs 73 | ) -------------------------------------------------------------------------------- 
/mplug_owl2/model/modeling_attn_mask_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import torch 17 | 18 | 19 | class AttentionMaskConverter: 20 | """ 21 | A utility attention mask class that allows one to: 22 | - Create a causal 4d mask 23 | - Create a causal 4d mask with slided window 24 | - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, 25 | key_value_length) that can be multiplied with attention scores 26 | 27 | Parameters: 28 | is_causal (`bool`): 29 | Whether the attention mask should be a uni-directional (causal) or bi-directional mask. 30 | 31 | sliding_window (`int`, *optional*): 32 | Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. 33 | """ 34 | 35 | def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): 36 | self.is_causal = is_causal 37 | self.sliding_window = sliding_window 38 | 39 | if self.sliding_window is not None and self.sliding_window <= 0: 40 | raise ValueError( 41 | f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" 42 | ) 43 | 44 | def to_causal_4d( 45 | self, 46 | batch_size: int, 47 | query_length: int, 48 | key_value_length: int, 49 | dtype: torch.dtype = torch.float32, 50 | device: Union[torch.device, "str"] = "cpu", 51 | ) -> torch.Tensor: 52 | """ 53 | Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative 54 | bias to upper right hand triangular matrix (causal mask). 55 | """ 56 | if not self.is_causal: 57 | raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") 58 | 59 | # If shape is not cached, create a new causal mask and cache it 60 | input_shape = (batch_size, query_length) 61 | past_key_values_length = key_value_length - query_length 62 | 63 | # create causal mask 64 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 65 | causal_4d_mask = None 66 | if input_shape[-1] > 1 or self.sliding_window is not None: 67 | causal_4d_mask = self._make_causal_mask( 68 | input_shape, 69 | dtype, 70 | device=device, 71 | past_key_values_length=past_key_values_length, 72 | sliding_window=self.sliding_window, 73 | ) 74 | 75 | return causal_4d_mask 76 | 77 | def to_4d( 78 | self, 79 | attention_mask_2d: torch.Tensor, 80 | query_length: int, 81 | key_value_length: Optional[int] = None, 82 | dtype: torch.dtype = torch.float32, 83 | ) -> torch.Tensor: 84 | """ 85 | Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, 86 | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is 87 | causal, a causal mask will be added. 
88 | """ 89 | input_shape = (attention_mask_2d.shape[0], query_length) 90 | 91 | # create causal mask 92 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 93 | causal_4d_mask = None 94 | if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: 95 | if key_value_length is None: 96 | raise ValueError( 97 | "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." 98 | ) 99 | 100 | past_key_values_length = key_value_length - query_length 101 | causal_4d_mask = self._make_causal_mask( 102 | input_shape, 103 | dtype, 104 | device=attention_mask_2d.device, 105 | past_key_values_length=past_key_values_length, 106 | sliding_window=self.sliding_window, 107 | ) 108 | elif self.sliding_window is not None: 109 | raise NotImplementedError("Sliding window is currently only implemented for causal masking") 110 | 111 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 112 | expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( 113 | attention_mask_2d.device 114 | ) 115 | expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask 116 | 117 | return expanded_4d_mask 118 | 119 | @staticmethod 120 | def _make_causal_mask( 121 | input_ids_shape: torch.Size, 122 | dtype: torch.dtype, 123 | device: torch.device, 124 | past_key_values_length: int = 0, 125 | sliding_window: Optional[int] = None, 126 | ): 127 | """ 128 | Make causal mask used for bi-directional self-attention. 129 | """ 130 | bsz, tgt_len = input_ids_shape 131 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 132 | mask_cond = torch.arange(mask.size(-1), device=device) 133 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 134 | 135 | mask = mask.to(dtype) 136 | 137 | if past_key_values_length > 0: 138 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 139 | 140 | # add lower triangular sliding window mask if necessary 141 | if sliding_window is not None: 142 | diagonal = past_key_values_length - sliding_window + 1 143 | 144 | context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) 145 | mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) 146 | 147 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 148 | 149 | @staticmethod 150 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 151 | """ 152 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
153 | """ 154 | bsz, src_len = mask.size() 155 | tgt_len = tgt_len if tgt_len is not None else src_len 156 | 157 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 158 | 159 | inverted_mask = 1.0 - expanded_mask 160 | 161 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 162 | 163 | 164 | def _prepare_4d_causal_attention_mask( 165 | attention_mask: Optional[torch.Tensor], 166 | input_shape: Union[torch.Size, Tuple, List], 167 | inputs_embeds: torch.Tensor, 168 | past_key_values_length: int, 169 | sliding_window: Optional[int] = None, 170 | ): 171 | """ 172 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 173 | `(batch_size, key_value_length)` 174 | 175 | Args: 176 | attention_mask (`torch.Tensor` or `None`): 177 | A 2D attention mask of shape `(batch_size, key_value_length)` 178 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`): 179 | The input shape should be a tuple that defines `(batch_size, query_length)`. 180 | inputs_embeds (`torch.Tensor`): 181 | The embedded inputs as a torch Tensor. 182 | past_key_values_length (`int`): 183 | The length of the key value cache. 184 | sliding_window (`int`, *optional*): 185 | If the model uses windowed attention, a sliding window should be passed. 186 | """ 187 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) 188 | 189 | key_value_length = input_shape[-1] + past_key_values_length 190 | 191 | # 4d mask is passed through the layers 192 | if attention_mask is not None: 193 | attention_mask = attn_mask_converter.to_4d( 194 | attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype 195 | ) 196 | else: 197 | attention_mask = attn_mask_converter.to_causal_4d( 198 | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device 199 | ) 200 | 201 | return attention_mask 202 | 203 | 204 | def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 205 | """ 206 | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 207 | `(batch_size, key_value_length)` 208 | 209 | Args: 210 | mask (`torch.Tensor` or `None`): 211 | A 2D attention mask of shape `(batch_size, key_value_length)` 212 | dtype (`torch.dtype`): 213 | The torch dtype the created mask shall have. 214 | tgt_len (`int`): 215 | The target length or query length the created mask shall have. 216 | """ 217 | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) 218 | 219 | 220 | def _create_4d_causal_attention_mask( 221 | input_shape: Union[torch.Size, Tuple, List], 222 | dtype: torch.dtype, 223 | device: torch.device, 224 | past_key_values_length: int = 0, 225 | sliding_window: Optional[int] = None, 226 | ): 227 | """ 228 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` 229 | 230 | Args: 231 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`): 232 | The input shape should be a tuple that defines `(batch_size, query_length)`. 233 | dtype (`torch.dtype`): 234 | The torch dtype the created mask shall have. 235 | device (`int`): 236 | The torch device the created mask shall have. 237 | sliding_window (`int`, *optional*): 238 | If the model uses windowed attention, a sliding window should be passed. 
239 | """ 240 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) 241 | 242 | key_value_length = past_key_values_length + input_shape[-1] 243 | attention_mask = attn_mask_converter.to_causal_4d( 244 | input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device 245 | ) 246 | 247 | return attention_mask -------------------------------------------------------------------------------- /mplug_owl2/model/multiway.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.utils.checkpoint 4 | from torch import nn 5 | 6 | 7 | class MultiwayNetwork(nn.Module): 8 | 9 | def __init__(self, module_provider, num_multiway=2, out_features=None): 10 | super(MultiwayNetwork, self).__init__() 11 | 12 | self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)]) 13 | self.out_features=out_features 14 | def forward(self, hidden_states, multiway_indices): 15 | 16 | if len(self.multiway) == 1: 17 | return self.multiway[0](hidden_states) 18 | if self.out_features: 19 | output_hidden_states = torch.empty( 20 | hidden_states.size(0), hidden_states.size(1), self.out_features, 21 | dtype=hidden_states.dtype 22 | ).to(hidden_states.device) 23 | else: 24 | output_hidden_states = torch.empty_like(hidden_states) 25 | for idx, subway in enumerate(self.multiway): 26 | local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True) 27 | hidden = hidden_states[local_indices].unsqueeze(1).contiguous() 28 | if hidden.numel(): 29 | output = subway(hidden) 30 | if isinstance(output, tuple): 31 | output = output[0] 32 | output = output.squeeze(1) 33 | output_hidden_states[local_indices] = output 34 | 35 | return output_hidden_states.contiguous() -------------------------------------------------------------------------------- /mplug_owl2/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type: 7 | assert cfg.model_type == 'mplug_owl2' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "mplug_owl2") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) -------------------------------------------------------------------------------- /mplug_owl2/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/serve/__init__.py -------------------------------------------------------------------------------- /mplug_owl2/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 5 | from mplug_owl2.conversation import conv_templates, SeparatorStyle 6 | from mplug_owl2.model.builder import load_pretrained_model 7 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 8 | 9 | from PIL import Image 10 | 11 | import requests 12 | from PIL import Image 13 | from io import BytesIO 14 | from transformers import TextStreamer 15 | 16 | 17 | def disable_torch_init(): 18 | """ 19 | Disable the redundant torch default initialization to accelerate model creation. 20 | """ 21 | import torch 22 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 23 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 24 | 25 | 26 | def load_image(image_file): 27 | if image_file.startswith('http://') or image_file.startswith('https://'): 28 | response = requests.get(image_file) 29 | image = Image.open(BytesIO(response.content)).convert('RGB') 30 | else: 31 | image = Image.open(image_file).convert('RGB') 32 | return image 33 | 34 | 35 | def main(args): 36 | # Model 37 | disable_torch_init() 38 | 39 | model_name = get_model_name_from_path(args.model_path) 40 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 41 | 42 | conv_mode = "mplug_owl2" 43 | 44 | if args.conv_mode is not None and conv_mode != args.conv_mode: 45 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 46 | else: 47 | args.conv_mode = conv_mode 48 | 49 | conv = conv_templates[args.conv_mode].copy() 50 | roles = conv.roles 51 | 52 | image = load_image(args.image_file) 53 | # Similar operation in model_worker.py 54 | image_tensor = process_images([image], image_processor, args) 55 | if type(image_tensor) is list: 56 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 57 | else: 58 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 59 | 60 | while True: 61 | try: 62 | inp = input(f"{roles[0]}: ") 63 | except EOFError: 64 | inp = "" 65 | if not inp: 66 | print("exit...") 67 | break 68 | 69 | print(f"{roles[1]}: ", end="") 70 | 71 | if image is not None: 72 | # first message 73 | inp = DEFAULT_IMAGE_TOKEN + inp 74 | conv.append_message(conv.roles[0], inp) 75 | image = None 76 | else: 77 | # later messages 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 
80 | prompt = conv.get_prompt() 81 | 82 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 83 | stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2 84 | keywords = [stop_str] 85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 87 | 88 | with torch.inference_mode(): 89 | output_ids = model.generate( 90 | input_ids, 91 | images=image_tensor, 92 | do_sample=True, 93 | temperature=args.temperature, 94 | max_new_tokens=args.max_new_tokens, 95 | streamer=streamer, 96 | use_cache=True, 97 | stopping_criteria=[stopping_criteria]) 98 | 99 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 100 | conv.messages[-1][-1] = outputs 101 | 102 | if args.debug: 103 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 109 | parser.add_argument("--model-base", type=str, default=None) 110 | parser.add_argument("--image-file", type=str, required=True) 111 | parser.add_argument("--device", type=str, default="cuda") 112 | parser.add_argument("--conv-mode", type=str, default=None) 113 | parser.add_argument("--temperature", type=float, default=0.2) 114 | parser.add_argument("--max-new-tokens", type=int, default=512) 115 | parser.add_argument("--load-8bit", action="store_true") 116 | parser.add_argument("--load-4bit", action="store_true") 117 | parser.add_argument("--debug", action="store_true") 118 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 119 | args = parser.parse_args() 120 | main(args) -------------------------------------------------------------------------------- /mplug_owl2/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 
4 | """ 5 | import argparse 6 | import asyncio 7 | import dataclasses 8 | from enum import Enum, auto 9 | import json 10 | import logging 11 | import time 12 | from typing import List, Union 13 | import threading 14 | 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import StreamingResponse 17 | import numpy as np 18 | import requests 19 | import uvicorn 20 | 21 | from mplug_owl2.constants import CONTROLLER_HEART_BEAT_EXPIRATION 22 | from mplug_owl2.utils import build_logger, server_error_msg 23 | 24 | 25 | logger = build_logger("controller", "controller.log") 26 | 27 | 28 | class DispatchMethod(Enum): 29 | LOTTERY = auto() 30 | SHORTEST_QUEUE = auto() 31 | 32 | @classmethod 33 | def from_str(cls, name): 34 | if name == "lottery": 35 | return cls.LOTTERY 36 | elif name == "shortest_queue": 37 | return cls.SHORTEST_QUEUE 38 | else: 39 | raise ValueError(f"Invalid dispatch method") 40 | 41 | 42 | @dataclasses.dataclass 43 | class WorkerInfo: 44 | model_names: List[str] 45 | speed: int 46 | queue_length: int 47 | check_heart_beat: bool 48 | last_heart_beat: str 49 | 50 | 51 | def heart_beat_controller(controller): 52 | while True: 53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 54 | controller.remove_stable_workers_by_expiration() 55 | 56 | 57 | class Controller: 58 | def __init__(self, dispatch_method: str): 59 | # Dict[str -> WorkerInfo] 60 | self.worker_info = {} 61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 62 | 63 | self.heart_beat_thread = threading.Thread( 64 | target=heart_beat_controller, args=(self,)) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, 70 | worker_status: dict): 71 | if worker_name not in self.worker_info: 72 | logger.info(f"Register a new worker: {worker_name}") 73 | else: 74 | logger.info(f"Register an existing worker: {worker_name}") 75 | 76 | if not worker_status: 77 | worker_status = self.get_worker_status(worker_name) 78 | if not worker_status: 79 | return False 80 | 81 | self.worker_info[worker_name] = WorkerInfo( 82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 83 | check_heart_beat, time.time()) 84 | 85 | logger.info(f"Register done: {worker_name}, {worker_status}") 86 | return True 87 | 88 | def get_worker_status(self, worker_name: str): 89 | try: 90 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 91 | except requests.exceptions.RequestException as e: 92 | logger.error(f"Get status fails: {worker_name}, {e}") 93 | return None 94 | 95 | if r.status_code != 200: 96 | logger.error(f"Get status fails: {worker_name}, {r}") 97 | return None 98 | 99 | return r.json() 100 | 101 | def remove_worker(self, worker_name: str): 102 | del self.worker_info[worker_name] 103 | 104 | def refresh_all_workers(self): 105 | old_info = dict(self.worker_info) 106 | self.worker_info = {} 107 | 108 | for w_name, w_info in old_info.items(): 109 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 110 | logger.info(f"Remove stale worker: {w_name}") 111 | 112 | def list_models(self): 113 | model_names = set() 114 | 115 | for w_name, w_info in self.worker_info.items(): 116 | model_names.update(w_info.model_names) 117 | 118 | return list(model_names) 119 | 120 | def get_worker_address(self, model_name: str): 121 | if self.dispatch_method == DispatchMethod.LOTTERY: 122 | worker_names = [] 123 | worker_speeds = [] 124 | for w_name, w_info in self.worker_info.items(): 
125 | if model_name in w_info.model_names: 126 | worker_names.append(w_name) 127 | worker_speeds.append(w_info.speed) 128 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 129 | norm = np.sum(worker_speeds) 130 | if norm < 1e-4: 131 | return "" 132 | worker_speeds = worker_speeds / norm 133 | if True: # Directly return address 134 | pt = np.random.choice(np.arange(len(worker_names)), 135 | p=worker_speeds) 136 | worker_name = worker_names[pt] 137 | return worker_name 138 | 139 | # Check status before returning 140 | while True: 141 | pt = np.random.choice(np.arange(len(worker_names)), 142 | p=worker_speeds) 143 | worker_name = worker_names[pt] 144 | 145 | if self.get_worker_status(worker_name): 146 | break 147 | else: 148 | self.remove_worker(worker_name) 149 | worker_speeds[pt] = 0 150 | norm = np.sum(worker_speeds) 151 | if norm < 1e-4: 152 | return "" 153 | worker_speeds = worker_speeds / norm 154 | continue 155 | return worker_name 156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 157 | worker_names = [] 158 | worker_qlen = [] 159 | for w_name, w_info in self.worker_info.items(): 160 | if model_name in w_info.model_names: 161 | worker_names.append(w_name) 162 | worker_qlen.append(w_info.queue_length / w_info.speed) 163 | if len(worker_names) == 0: 164 | return "" 165 | min_index = np.argmin(worker_qlen) 166 | w_name = worker_names[min_index] 167 | self.worker_info[w_name].queue_length += 1 168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 169 | return w_name 170 | else: 171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 172 | 173 | def receive_heart_beat(self, worker_name: str, queue_length: int): 174 | if worker_name not in self.worker_info: 175 | logger.info(f"Receive unknown heart beat. {worker_name}") 176 | return False 177 | 178 | self.worker_info[worker_name].queue_length = queue_length 179 | self.worker_info[worker_name].last_heart_beat = time.time() 180 | logger.info(f"Receive heart beat. {worker_name}") 181 | return True 182 | 183 | def remove_stable_workers_by_expiration(self): 184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 185 | to_delete = [] 186 | for worker_name, w_info in self.worker_info.items(): 187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 188 | to_delete.append(worker_name) 189 | 190 | for worker_name in to_delete: 191 | self.remove_worker(worker_name) 192 | 193 | def worker_api_generate_stream(self, params): 194 | worker_addr = self.get_worker_address(params["model"]) 195 | if not worker_addr: 196 | logger.info(f"no worker: {params['model']}") 197 | ret = { 198 | "text": server_error_msg, 199 | "error_code": 2, 200 | } 201 | yield json.dumps(ret).encode() + b"\0" 202 | 203 | try: 204 | response = requests.post(worker_addr + "/worker_generate_stream", 205 | json=params, stream=True, timeout=5) 206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 207 | if chunk: 208 | yield chunk + b"\0" 209 | except requests.exceptions.RequestException as e: 210 | logger.info(f"worker timeout: {worker_addr}") 211 | ret = { 212 | "text": server_error_msg, 213 | "error_code": 3, 214 | } 215 | yield json.dumps(ret).encode() + b"\0" 216 | 217 | 218 | # Let the controller act as a worker to achieve hierarchical 219 | # management. This can be used to connect isolated sub networks. 
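    # The status returned here sums "speed" and "queue_length" and unions
    # "model_names" across every registered worker, so a parent controller
    # that registers this one sees it as a single aggregate worker.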
220 | def worker_api_get_status(self): 221 | model_names = set() 222 | speed = 0 223 | queue_length = 0 224 | 225 | for w_name in self.worker_info: 226 | worker_status = self.get_worker_status(w_name) 227 | if worker_status is not None: 228 | model_names.update(worker_status["model_names"]) 229 | speed += worker_status["speed"] 230 | queue_length += worker_status["queue_length"] 231 | 232 | return { 233 | "model_names": list(model_names), 234 | "speed": speed, 235 | "queue_length": queue_length, 236 | } 237 | 238 | 239 | app = FastAPI() 240 | 241 | 242 | @app.post("/register_worker") 243 | async def register_worker(request: Request): 244 | data = await request.json() 245 | controller.register_worker( 246 | data["worker_name"], data["check_heart_beat"], 247 | data.get("worker_status", None)) 248 | 249 | 250 | @app.post("/refresh_all_workers") 251 | async def refresh_all_workers(): 252 | models = controller.refresh_all_workers() 253 | 254 | 255 | @app.post("/list_models") 256 | async def list_models(): 257 | models = controller.list_models() 258 | return {"models": models} 259 | 260 | 261 | @app.post("/get_worker_address") 262 | async def get_worker_address(request: Request): 263 | data = await request.json() 264 | addr = controller.get_worker_address(data["model"]) 265 | return {"address": addr} 266 | 267 | 268 | @app.post("/receive_heart_beat") 269 | async def receive_heart_beat(request: Request): 270 | data = await request.json() 271 | exist = controller.receive_heart_beat( 272 | data["worker_name"], data["queue_length"]) 273 | return {"exist": exist} 274 | 275 | 276 | @app.post("/worker_generate_stream") 277 | async def worker_api_generate_stream(request: Request): 278 | params = await request.json() 279 | generator = controller.worker_api_generate_stream(params) 280 | return StreamingResponse(generator) 281 | 282 | 283 | @app.post("/worker_get_status") 284 | async def worker_api_get_status(request: Request): 285 | return controller.worker_api_get_status() 286 | 287 | 288 | if __name__ == "__main__": 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument("--host", type=str, default="localhost") 291 | parser.add_argument("--port", type=int, default=21001) 292 | parser.add_argument("--dispatch-method", type=str, choices=[ 293 | "lottery", "shortest_queue"], default="shortest_queue") 294 | args = parser.parse_args() 295 | logger.info(f"args: {args}") 296 | 297 | controller = Controller(args.dispatch_method) 298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") -------------------------------------------------------------------------------- /mplug_owl2/serve/examples/Rebecca_(1939_poster)_Small.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/serve/examples/Rebecca_(1939_poster)_Small.jpeg -------------------------------------------------------------------------------- /mplug_owl2/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /mplug_owl2/serve/register_workers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 -------------------------------------------------------------------------------- /mplug_owl2/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | modality_indicators: torch.Tensor, 20 | attention_mask: Optional[torch.Tensor] = None, 21 | position_ids: Optional[torch.Tensor] = None, 22 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 23 | output_attentions: bool = False, 24 | use_cache: bool = False, 25 | padding_mask: bool = None, 26 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 27 | if output_attentions: 28 | warnings.warn( 29 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
30 | ) 31 | 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states, modality_indicators) 41 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states, modality_indicators) 46 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) # shape: (b, num_heads, s, head_dim) 49 | 50 | kv_seq_len = key_states.shape[-2] 51 | if past_key_value is not None: 52 | kv_seq_len += past_key_value[0].shape[-2] 53 | 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | query_states, key_states = apply_rotary_pos_emb( 56 | query_states, key_states, cos, sin, position_ids 57 | ) 58 | 59 | if past_key_value is not None: 60 | # reuse k, v 61 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 62 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 63 | 64 | past_key_value = (key_states, value_states) if use_cache else None 65 | 66 | # repeat k/v heads if n_kv_heads < n_heads 67 | key_states = repeat_kv(key_states, self.num_key_value_groups) 68 | value_states = repeat_kv(value_states, self.num_key_value_groups) 69 | 70 | # Transform the data into the format required by flash attention 71 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 72 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 73 | key_padding_mask = attention_mask 74 | 75 | if key_padding_mask is None: 76 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 77 | cu_q_lens = torch.arange( 78 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 79 | ) 80 | max_s = q_len 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = output.view(bsz, q_len, -1) 85 | else: 86 | qkv = qkv.reshape(bsz, q_len, -1) 87 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 88 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 89 | output_unpad = flash_attn_unpadded_qkvpacked_func( 90 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 91 | ) 92 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 93 | output = pad_input(output_unpad, indices, bsz, q_len) 94 | 95 | return self.o_proj(output), None, past_key_value 96 | 97 | 98 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 99 | # requires the attention mask to be the same as the key_padding_mask 100 | def _prepare_decoder_attention_mask( 101 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 102 | ): 103 | # [bsz, seq_len] 104 | return attention_mask 105 | 106 | 107 | def replace_llama_attn_with_flash_attn(): 108 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 109 | if cuda_major < 8: 110 | warnings.warn( 111 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
112 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 113 | ) 114 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 115 | _prepare_decoder_attention_mask 116 | ) 117 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward -------------------------------------------------------------------------------- /mplug_owl2/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from mplug_owl2.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from mplug_owl2.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() -------------------------------------------------------------------------------- /mplug_owl2/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from mplug_owl2.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 
63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" -------------------------------------------------------------------------------- /process_params.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Dict 3 | 4 | from extensions.sd_smartprocess.file_manager import ImageData 5 | 6 | 7 | @dataclass 8 | class ProcessParams: 9 | auto_save: bool = False 10 | blip_initial_prompt = "a caption for this image is: " 11 | booru_min_score: float = 0.75 12 | caption: bool = False 13 | captioners: Dict[str, bool] = field(default_factory=lambda: []) 14 | char_threshold: float = 0.5 15 | clip_append_artist: bool = False 16 | clip_append_flavor: bool = False 17 | clip_append_medium: bool = False 18 | clip_append_movement: bool = False 19 | clip_append_trending: bool = False 20 | clip_max_flavors: int = 3 21 | clip_use_v2: bool = False 22 | crop: bool = False 23 | crop_mode: str = "smart" 24 | do_backup: bool = False 25 | do_rename: bool = False 26 | dst: str = "" 27 | face_model: str = "Codeformers" 28 | flip: bool = False 29 | idefics2_prompt: str = "Describe this image in one detailed sentence, include the subject, location, style, and type of image." 30 | interrogation_prompt: str = "Describe this image in one detailed sentence." 
31 | insert_subject: bool = False 32 | load_in_4bit: bool = False 33 | load_in_8bit: bool = False 34 | max_clip_tokens: float = 1.0 35 | max_size: int = 1024 36 | max_tokens: int = 75 37 | min_clip_tokens: float = 0.0 38 | new_caption: str = "" 39 | nl_captioners: Dict[str, bool] = field(default_factory=lambda: []) 40 | num_beams: int = 5 41 | pad: bool = False 42 | replace_blip_caption: bool = True 43 | replace_class: bool = False 44 | restore_faces: bool = False 45 | save_caption: bool = False 46 | save_image: bool = False 47 | src_files: List[ImageData] = field(default_factory=lambda: []) 48 | subject: str = "" 49 | subject_class: str = "" 50 | tags_to_ignore: List[str] = field(default_factory=lambda: []) 51 | threshold: float = 0.5 52 | txt_action: str = "ignore" 53 | upscale: bool = False 54 | upscale_max: int = 4096 55 | upscale_mode: str = "ratio" 56 | upscale_ratio: float = 2.0 57 | upscaler_1 = None 58 | upscaler_2 = None 59 | wd14_min_score: float = 0.75 60 | image_path = None 61 | 62 | def clip_params(self): 63 | return { 64 | "min_clip_tokens": self.min_clip_tokens, 65 | "max_clip_tokens": self.max_clip_tokens, 66 | "use_v2": self.clip_use_v2, 67 | "append_flavor": self.clip_append_flavor, 68 | "max_flavors": self.clip_max_flavors, 69 | "append_medium": self.clip_append_medium, 70 | "append_movement": self.clip_append_movement, 71 | "append_artist": self.clip_append_artist, 72 | "append_trending": self.clip_append_trending, 73 | "num_beams": self.num_beams, 74 | "clip_max_flavors": self.clip_max_flavors, 75 | "blip_initial_prompt": self.blip_initial_prompt 76 | } 77 | 78 | def pre_only(self): 79 | self.caption = False 80 | self.upscale = False 81 | self.restore_faces = False 82 | 83 | def cap_only(self): 84 | self.upscale = False 85 | self.restore_faces = False 86 | self.crop = False 87 | self.pad = False 88 | 89 | def post_only(self): 90 | self.caption = False 91 | self.crop = False 92 | self.pad = False 93 | 94 | @classmethod 95 | def from_dict(cls, d): 96 | instance = cls() # Get the singleton instance 97 | for k, v in d.items(): 98 | k = k.replace("sp_", "") # Adjust the attribute name 99 | if k == "class": 100 | k = "subject_class" 101 | if hasattr(instance, k): 102 | setattr(instance, k, v) 103 | return instance 104 | -------------------------------------------------------------------------------- /processors.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import random 3 | from typing import List, Any 4 | 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm 10 | from extensions.sd_smartprocess import super_resolution 11 | from extensions.sd_smartprocess.clipcrop import CropClip 12 | from extensions.sd_smartprocess.interrogators.clip_interrogator import CLIPInterrogator 13 | from extensions.sd_smartprocess.interrogators.booru_interrogator import BooruInterrogator 14 | from modules import shared 15 | 16 | # Base processor 17 | class Processor: 18 | def __init__(self): 19 | printm("Model loaded.") 20 | 21 | # Unload models 22 | def unload(self): 23 | if torch.has_cuda: 24 | torch.cuda.empty_cache() 25 | gc.collect() 26 | printm("Model unloaded.") 27 | 28 | # Process images 29 | def process(self, images: List[Image.Image]) -> List[Any]: 30 | raise Exception("Not Implemented") 31 | 32 | # CLIP Processing 33 | class ClipProcessor(Processor): 34 | def __init__( 35 | self, 36 | clip_use_v2, 37 | clip_append_artist, 38 | 
clip_append_medium, 39 | clip_append_movement, 40 | clip_append_flavor, 41 | clip_append_trending, 42 | num_beams, 43 | min_clip_tokens, 44 | max_clip_tokens, 45 | max_flavors 46 | ): 47 | self.description = "Processing CLIP" 48 | if shared.interrogator is not None: 49 | shared.interrogator.unload() 50 | 51 | self.max_flavors = max_flavors 52 | shared.state.textinfo = "Loading CLIP Model..." 53 | self.model = CLIPInterrogator( 54 | clip_use_v2, 55 | clip_append_artist, 56 | clip_append_medium, 57 | clip_append_movement, 58 | clip_append_flavor, 59 | clip_append_trending, 60 | num_beams, 61 | min_clip_tokens, 62 | max_clip_tokens 63 | ) 64 | super().__init__() 65 | 66 | def process(self, images: List[Image.Image], short: bool = False) -> List[str]: 67 | output = [] 68 | shared.state.job_count = len(images) 69 | shared.state.textinfo = f"{self.description}..." 70 | for img in tqdm(images, desc=self.description): 71 | short_caption = self.model.interrogate(img, short=short, max_flavors=self.max_flavors) 72 | output.append(short_caption) 73 | shared.state.current_image = img 74 | shared.state.job_no += 1 75 | return output 76 | 77 | def unload(self): 78 | if self.model.clip_model: 79 | del self.model.clip_model 80 | if self.model.blip_model: 81 | del self.model.blip_model 82 | super().unload() 83 | 84 | # Danbooru Processing 85 | class BooruProcessor(Processor): 86 | def __init__(self, min_score: float): 87 | self.description = "Processing Danbooru" 88 | shared.state.textinfo = "Loading DeepDanbooru Model..." 89 | self.model = BooruInterrogator() 90 | self.min_score = min_score 91 | super().__init__() 92 | 93 | def process(self, images: List[Image.Image]) -> List[List[str]]: 94 | output = [] 95 | shared.state.job_count = len(images) 96 | shared.state.textinfo = f"{self.description}..." 97 | for img in tqdm(images, desc=self.description): 98 | out_tags = [] 99 | tags = self.model.interrogate(img) 100 | for tag in sorted(tags, key=tags.get, reverse=True): 101 | if tags[tag] >= self.min_score: 102 | out_tags.append(tag) 103 | output.append(out_tags) 104 | shared.state.job_no += 1 105 | return output 106 | def unload(self): 107 | self.model.unload() 108 | super().unload() 109 | 110 | # WD14 Processing 111 | 112 | # Crop Processing 113 | class CropProcessor(Processor): 114 | def __init__(self, subject_class: str, pad: bool, crop: bool): 115 | self.description = "Cropping" 116 | if crop: 117 | shared.state.textinfo = "Loading CROP Model..." 118 | self.model = CropClip() if crop else None 119 | self.subject_class = subject_class 120 | self.pad = pad 121 | self.crop = crop 122 | super().__init__() 123 | 124 | def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]: 125 | output = [] 126 | shared.state.job_count = len(images) 127 | shared.state.textinfo = f"{self.description}..." 128 | for img, caption in tqdm(zip(images, captions), desc=self.description): 129 | cropped = self._process_img(img, caption) 130 | output.append(cropped) 131 | shared.state.job_no += 1 132 | return output 133 | 134 | 135 | def _process_img(self, img, short_caption): 136 | if self.subject_class is not None and self.subject_class != "": 137 | short_caption = self.subject_class 138 | 139 | src_ratio = img.width / img.height 140 | 141 | # Pad image before cropping?
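# When self.pad is set, a non-square image is first pasted centered onto a square canvas (side length equal to its longer edge, black-filled by default) so that the CLIP-guided crop below starts from a 1:1 frame.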
142 | if src_ratio != 1 and self.pad: 143 | if img.width > img.height: 144 | pad_width = img.width 145 | pad_height = img.width 146 | else: 147 | pad_width = img.height 148 | pad_height = img.height 149 | res = Image.new("RGB", (pad_width, pad_height)) 150 | res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2)) 151 | img = res 152 | 153 | if self.crop: 154 | # Do the actual crop clip 155 | im_data = self.model.get_center(img, prompt=short_caption) 156 | crop_width = im_data[1] - im_data[0] 157 | center_x = im_data[0] + (crop_width / 2) 158 | crop_height = im_data[3] - im_data[2] 159 | center_y = im_data[2] + (crop_height / 2) 160 | crop_ratio = crop_width / crop_height 161 | dest_ratio = 1 162 | tgt_width = crop_width 163 | tgt_height = crop_height 164 | 165 | if crop_ratio != dest_ratio: 166 | if crop_width > crop_height: 167 | tgt_height = crop_width / dest_ratio 168 | tgt_width = crop_width 169 | else: 170 | tgt_width = crop_height / dest_ratio 171 | tgt_height = crop_height 172 | 173 | # Reverse the above if dest is too big 174 | if tgt_width > img.width or tgt_height > img.height: 175 | if tgt_width > img.width: 176 | tgt_width = img.width 177 | tgt_height = tgt_width / dest_ratio 178 | else: 179 | tgt_height = img.height 180 | tgt_width = tgt_height / dest_ratio 181 | 182 | tgt_height = int(tgt_height) 183 | tgt_width = int(tgt_width) 184 | left = max(center_x - (tgt_width / 2), 0) 185 | right = min(center_x + (tgt_width / 2), img.width) 186 | top = max(center_y - (tgt_height / 2), 0) 187 | bottom = min(center_y + (tgt_height / 2), img.height) 188 | img = img.crop((left, top, right, bottom)) 189 | return img 190 | 191 | def unload(self): 192 | if self.model is not None: 193 | self.model.unload() 194 | super().unload() 195 | 196 | # Upscale Processing 197 | class UpscaleProcessor(Processor): 198 | def __init__(self): 199 | self.description = "Upscaling" 200 | shared.state.textinfo = "Loading Stable-Diffusion Upscaling Model..." 201 | self.sampler, self.model = super_resolution.initialize_model() 202 | super().__init__() 203 | 204 | def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]: 205 | output = [] 206 | shared.state.job_count = len(images) 207 | shared.state.textinfo = f"{self.description}..." 
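# Each image is run through the Stable Diffusion x4 upscaler with its caption as the conditioning prompt and a fresh random seed; read against super_resolution.predict(), the positional arguments below map to steps=75, num_samples=1, scale=10, eta=0 and noise_level=20.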
208 | for img, caption in tqdm(zip(images, captions), desc=self.description): 209 | seed = int(random.randrange(2147483647)) 210 | img = super_resolution.predict(self.sampler, img, caption, 75, 1, 10, seed, 0, 20) 211 | output.append(img) 212 | shared.state.job_no += 1 213 | return output 214 | 215 | def unload(self): 216 | del self.sampler 217 | del self.model 218 | super().unload() 219 | 220 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipython==8.18.0 2 | spandrel>=0.3.1 3 | seaborn==0.13.2 4 | tensorflow 5 | open_clip_torch 6 | icecream==2.1.3 7 | chardet 8 | cchardet 9 | peft 10 | flask 11 | ruamel.yaml 12 | markdown2 13 | sconf 14 | tensorboardX 15 | transformers>=4.40.0 -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | #sp_src { 2 | width: calc(100% - 84px); 3 | min-width: calc(100% - 84px); 4 | } 5 | 6 | #sp_clear_src, #sp_load_src { 7 | margin-top: 20px; 8 | height: 36px; 9 | } 10 | 11 | #sp_top_btn_row { 12 | padding-top: 20px; 13 | } 14 | 15 | #sp_clear_src, #sp_load_src { 16 | flex: none; 17 | max-width: 42px; 18 | min-width: 42px !important; 19 | height: 42px; 20 | line-height: 1em; 21 | border-radius: 0.5em; 22 | } 23 | 24 | #sp_backup_files { 25 | max-width: 120px !important; 26 | padding-top: 5px !important; 27 | margin-left: 15px; 28 | margin-right: 15px; 29 | min-width: 120px !important; 30 | } 31 | 32 | #sp_file_options { 33 | margin-left: 25%; 34 | } 35 | 36 | .inc_exc_btn { 37 | margin-top: 20px !important; 38 | } 39 | 40 | #sp_progressbar.hide-container, #sp_progress.hide-container, #sp_error.hide-container { 41 | min-height: 0; 42 | display: none; 43 | } 44 | 45 | .short_check { 46 | min-width: 120px !important; 47 | max-width: 120px !important; 48 | } -------------------------------------------------------------------------------- /super_resolution.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os.path 3 | 4 | import safetensors.torch 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | from omegaconf import OmegaConf 9 | from einops import repeat, rearrange 10 | from pytorch_lightning import seed_everything 11 | 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion 14 | from ldm.util import exists, instantiate_from_config 15 | 16 | from modules import shared, modelloader, sd_hijack 17 | 18 | torch.set_grad_enabled(False) 19 | 20 | 21 | def initialize_model(): 22 | model_dir = os.path.join(shared.models_path, "Upscale") 23 | if not os.path.exists(model_dir): 24 | os.makedirs(model_dir) 25 | 26 | model_name = os.path.join(model_dir, "x4-upscaler-ema.ckpt") 27 | 28 | model_url = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.ckpt" 29 | 30 | model_dir = os.path.join(shared.models_path, "Upscaler") 31 | model_file = modelloader.load_models(model_dir, model_url, None, '.ckpt', model_name) 32 | model_file = model_file[0] 33 | model_config = os.path.join(shared.script_path, "extensions", "sd_smartprocess", "configs", "upscaler.yaml") 34 | 35 | config = OmegaConf.load(model_config) 36 | model = instantiate_from_config(config.model) 37 | model_model = torch.load(model_file) 38 | 
model.load_state_dict(model_model["state_dict"], strict=False) 39 | model.half() 40 | 41 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 42 | model = model.to(device) 43 | sd_hijack.model_hijack.hijack(model) # apply optimization 44 | 45 | if shared.cmd_opts.opt_channelslast: 46 | model = model.to(memory_format=torch.channels_last) 47 | 48 | model.eval() 49 | sampler = DDIMSampler(model) 50 | return sampler, model 51 | 52 | 53 | def make_batch_sd( 54 | image, 55 | txt, 56 | device, 57 | num_samples=1, 58 | ): 59 | image = np.array(image.convert("RGB")) 60 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 61 | batch = { 62 | "lr": rearrange(image, 'h w c -> 1 c h w'), 63 | "txt": num_samples * [txt], 64 | } 65 | batch["lr"] = repeat(batch["lr"].to(device=device), 66 | "1 ... -> n ...", n=num_samples) 67 | return batch 68 | 69 | 70 | def make_noise_augmentation(model, batch, noise_level=None): 71 | x_low = batch[model.low_scale_key] 72 | x_low = x_low.to(memory_format=torch.contiguous_format).float() 73 | x_aug, noise_level = model.low_scale_model(x_low, noise_level) 74 | return x_aug, noise_level 75 | 76 | 77 | def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None): 78 | device = torch.device( 79 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 80 | model = sampler.model 81 | seed_everything(seed) 82 | prng = np.random.RandomState(seed) 83 | start_code = prng.randn(num_samples, model.channels, h, w) 84 | start_code = torch.from_numpy(start_code).to( 85 | device=device, dtype=torch.float32) 86 | 87 | with torch.no_grad(),\ 88 | torch.autocast("cuda"): 89 | batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) 90 | print(f"Batch: {batch}") 91 | c = model.cond_stage_model.encode(batch["txt"]) 92 | c_cat = list() 93 | if isinstance(model, LatentUpscaleFinetuneDiffusion): 94 | for ck in model.concat_keys: 95 | cc = batch[ck] 96 | if exists(model.reshuffle_patch_size): 97 | assert isinstance(model.reshuffle_patch_size, int) 98 | cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', 99 | p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size) 100 | c_cat.append(cc) 101 | c_cat = torch.cat(c_cat, dim=1) 102 | # cond 103 | cond = {"c_concat": [c_cat], "c_crossattn": [c]} 104 | # uncond cond 105 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 106 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} 107 | elif isinstance(model, LatentUpscaleDiffusion): 108 | x_augment, noise_level = make_noise_augmentation( 109 | model, batch, noise_level) 110 | cond = {"c_concat": [x_augment], 111 | "c_crossattn": [c], "c_adm": noise_level} 112 | # uncond cond 113 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 114 | uc_full = {"c_concat": [x_augment], "c_crossattn": [ 115 | uc_cross], "c_adm": noise_level} 116 | else: 117 | raise NotImplementedError() 118 | 119 | shape = [model.channels, h, w] 120 | samples, intermediates = sampler.sample( 121 | steps, 122 | num_samples, 123 | shape, 124 | cond, 125 | verbose=False, 126 | eta=eta, 127 | unconditional_guidance_scale=scale, 128 | unconditional_conditioning=uc_full, 129 | x_T=start_code, 130 | callback=callback 131 | ) 132 | with torch.no_grad(): 133 | x_samples_ddim = model.decode_first_stage(samples) 134 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 135 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 136 | return 
[Image.fromarray(img.astype(np.uint8)) for img in result] 137 | 138 | 139 | def pad_image(input_image): 140 | pad_w, pad_h = np.max(((2, 2), np.ceil( 141 | np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size 142 | im_padded = Image.fromarray( 143 | np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) 144 | return im_padded 145 | 146 | 147 | def predict(sampler, input_image, prompt, steps, num_samples, scale, seed, eta, noise_level): 148 | init_image = input_image.convert("RGB") 149 | image = pad_image(init_image) # resize to integer multiple of 32 150 | width, height = image.size 151 | 152 | noise_level = torch.Tensor( 153 | num_samples * [noise_level]).to(sampler.model.device).long() 154 | sampler.make_schedule(steps, ddim_eta=eta, verbose=True) 155 | result = paint( 156 | sampler=sampler, 157 | image=image, 158 | prompt=prompt, 159 | seed=seed, 160 | scale=scale, 161 | h=height, w=width, steps=steps, 162 | num_samples=num_samples, 163 | callback=None, 164 | noise_level=noise_level 165 | ) 166 | return result 167 | 168 | def super_resolution(self, images, steps=50, target_scale=2, half_attention=False): 169 | model = self.load_model_from_config(half_attention) 170 | gc.collect() 171 | if torch.cuda.is_available: 172 | torch.cuda.empty_cache() 173 | 174 | # Run settings 175 | diffusion_steps = int(steps) 176 | eta = 1.0 177 | output = [] 178 | for image in images: 179 | im_og = image 180 | width_og, height_og = im_og.size 181 | # If we can adjust the max upscale size, then the 4 below should be our variable 182 | down_sample_rate = target_scale / 4 183 | wd = width_og * down_sample_rate 184 | hd = height_og * down_sample_rate 185 | width_downsampled_pre = int(np.ceil(wd)) 186 | height_downsampled_pre = int(np.ceil(hd)) 187 | 188 | if down_sample_rate != 1: 189 | print( 190 | f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') 191 | im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) 192 | else: 193 | print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") 194 | 195 | # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts 196 | pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size 197 | im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) 198 | 199 | logs = self.run(model["model"], im_padded, diffusion_steps, eta) 200 | 201 | sample = logs["sample"] 202 | sample = sample.detach().cpu() 203 | sample = torch.clamp(sample, -1., 1.) 204 | sample = (sample + 1.) / 2. 
* 255 205 | sample = sample.numpy().astype(np.uint8) 206 | sample = np.transpose(sample, (0, 2, 3, 1)) 207 | a = Image.fromarray(sample[0]) 208 | 209 | # remove padding 210 | a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4)) 211 | output.append(a) 212 | 213 | del model 214 | gc.collect() 215 | if torch.cuda.is_available: 216 | torch.cuda.empty_cache() 217 | 218 | return output 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /upscalers/spandrel/spandrel_srformer_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from extensions.sd_smartprocess.upscalers.spandrel.spandrel_upscaler_base import SpandrelUpscaler 4 | from modules import modelloader 5 | from modules.upscaler import UpscalerData 6 | 7 | model_urls = [ 8 | "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth", 9 | 10 | ] 11 | 12 | models = { 13 | "SRformer 4xFF": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4xFrankendata_FullDegradation_g.pth", 14 | "SRFormer 4xNomos8kSC": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4xNomos8kSC_SRFormer.pth", 15 | "SRFormer FP": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/FrankendataPretrainer_SRFormer_g.pth", 16 | "SwinIR x8": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth", 17 | "4xLDSIR": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4xLSDIR_DAT.pth", 18 | "ClearRealityV1 4x": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4x-ClearRealityV1_SPAN.pth" 19 | } 20 | 21 | 22 | class SpandrelSRFormerModel(SpandrelUpscaler): 23 | scale = 4 24 | name = "SRFormer" 25 | 26 | def __init__(self, create_dirs=False): 27 | super().__init__(create_dirs) 28 | self.name = "SRFormer" 29 | self.scale = 4 30 | user_models = self.find_models(ext_filter=[".pth"]) 31 | self.scalers = [] 32 | added_models = [] 33 | for file in user_models: 34 | model_name = os.path.basename(file) 35 | display_name = modelloader.friendly_name(file) 36 | for pre_name, model_url in models.items(): 37 | if model_name in model_url: 38 | display_name = pre_name 39 | self.scalers.append(UpscalerData(display_name, file, self)) 40 | added_models.append(display_name) 41 | break 42 | if display_name not in added_models: 43 | self.scalers.append(UpscalerData(display_name, file, self)) 44 | added_models.append(display_name) 45 | for model_name, model_url in models.items(): 46 | if model_name not in added_models: 47 | file_name = os.path.basename(model_url) 48 | model_path = modelloader.load_file_from_url(model_url, model_dir=self.model_path, file_name=file_name) 49 | self.scalers.append(UpscalerData(model_name, model_path, self)) 50 | -------------------------------------------------------------------------------- /upscalers/spandrel/spandrel_upscaler_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import re 6 | import time 7 | import traceback 8 | import zipfile 9 | from abc import abstractmethod 10 | from pathlib import Path 11 | from urllib.request import urlretrieve 12 | 13 | import PIL 14 | import cv2 15 | import numpy as np 16 | import torch 17 | from PIL import Image 18 | from spandrel import ( 19 | ImageModelDescriptor, 20 | ModelDescriptor, 
ModelLoader, 21 | ) 22 | 23 | from modules import shared 24 | from modules.upscaler import Upscaler, UpscalerData, NEAREST 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def convert_google_drive_link(url: str) -> str: 30 | pattern = re.compile( 31 | r"^https://drive.google.com/file/d/([a-zA-Z0-9_\-]+)/view(?:\?.*)?$" 32 | ) 33 | m = pattern.match(url) 34 | if not m: 35 | return url 36 | file_id = m.group(1) 37 | return "https://drive.google.com/uc?export=download&confirm=1&id=" + file_id 38 | 39 | 40 | def download_file(url: str, filename: Path | str) -> None: 41 | filename = Path(filename) 42 | filename.parent.mkdir(exist_ok=True) 43 | 44 | url = convert_google_drive_link(url) 45 | 46 | temp_filename = filename.with_suffix(f".part-{int(time.time())}") 47 | 48 | try: 49 | logger.info("Downloading %s to %s", url, filename) 50 | path, _ = urlretrieve(url, filename=temp_filename) 51 | temp_filename.rename(filename) 52 | finally: 53 | try: 54 | temp_filename.unlink() 55 | except FileNotFoundError: 56 | pass 57 | 58 | 59 | def extract_file_from_zip( 60 | zip_path: Path | str, 61 | rel_model_path: str, 62 | filename: Path | str, 63 | ): 64 | filename = Path(filename) 65 | filename.parent.mkdir(exist_ok=True) 66 | 67 | with zipfile.ZipFile(zip_path, "r") as zip_ref: 68 | with open(filename, "wb") as f: 69 | f.write(zip_ref.read(rel_model_path)) 70 | 71 | 72 | def image_to_tensor(img: np.ndarray, device: str, half) -> torch.Tensor: 73 | img = img.astype(np.float32) / 255.0 74 | if img.ndim == 2: 75 | img = np.expand_dims(img, axis=2) 76 | if img.shape[2] == 1: 77 | pass 78 | else: 79 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 80 | img = np.transpose(img, (2, 0, 1)) 81 | tensor = torch.from_numpy(img).to(device) 82 | if half is not None: 83 | tensor = tensor.to(half) 84 | return tensor.unsqueeze(0) 85 | 86 | 87 | def tensor_to_image(tensor: torch.Tensor) -> np.ndarray: 88 | image = tensor.cpu().squeeze().numpy() 89 | image = np.transpose(image, (1, 2, 0)) 90 | image = np.clip((image * 255.0).round(), 0, 255) 91 | image = image.astype(np.uint8) 92 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 93 | return image 94 | 95 | 96 | def image_inference_tensor( 97 | model: ImageModelDescriptor, tensor: torch.Tensor 98 | ) -> torch.Tensor: 99 | model.eval() 100 | with torch.no_grad(): 101 | return model(tensor) 102 | 103 | 104 | def image_inference(model: ImageModelDescriptor, image: np.ndarray, device: str, half: str) -> np.ndarray: 105 | return tensor_to_image(image_inference_tensor(model, image_to_tensor(image, device, half))) 106 | 107 | 108 | def get_h_w_c(image: np.ndarray) -> tuple[int, int, int]: 109 | if len(image.shape) == 2: 110 | return image.shape[0], image.shape[1], 1 111 | return image.shape[0], image.shape[1], image.shape[2] 112 | 113 | 114 | class SpandrelUpscaler(Upscaler): 115 | model_url = "" 116 | model_type = "" 117 | model_file = "" 118 | scale = 4 119 | 120 | def __init__(self, create_dirs=False): 121 | super().__init__(create_dirs) 122 | self.name = "Spandrel" 123 | self.scale = 1 124 | self.scalers = [] 125 | 126 | def do_upscale(self, img: PIL.Image, selected_model: str): 127 | self.load_model(selected_model) 128 | return self.internal_upscale(img) 129 | 130 | def load_model(self, path: str): 131 | self.model = ModelLoader().load_from_file(path) 132 | print(f"Model size reqs: {self.model.size_requirements}") 133 | device = "cuda" if torch.cuda.is_available() else "cpu" 134 | self.model.to(device) 135 | if self.model.supports_half: 136 | self.model.to(torch.half) 
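# supports_half is reported by the spandrel model descriptor, so the weights are cast to fp16 only for architectures that can handle it; image_to_tensor() above applies the same dtype to the input when a half value is passed in.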
137 | self.model.eval() 138 | 139 | def preprocess(self, image: Image) -> Image: 140 | square = self.model.size_requirements.square 141 | minimum = self.model.size_requirements.minimum 142 | multiple_of = self.model.size_requirements.multiple_of 143 | if square: 144 | # Pad the shorter side to make the image square 145 | size = max(image.width, image.height) 146 | new_image = Image.new("RGB", (size, size)) 147 | new_image.paste(image, ((size - image.width) // 2, (size - image.height) // 2)) 148 | image = new_image 149 | if minimum > 1: 150 | size = max(image.width, image.height) 151 | if size < minimum: 152 | new_width = int(minimum * image.width / size) 153 | new_height = int(minimum * image.height / size) 154 | image = image.resize((new_width, new_height), resample=NEAREST) 155 | if multiple_of > 1: 156 | new_width = int(multiple_of * image.width // multiple_of) 157 | new_height = int(multiple_of * image.height // multiple_of) 158 | image = image.resize((new_width, new_height), resample=NEAREST) 159 | return image 160 | 161 | def postprocess(self, image: Image, original_width: int, original_height: int) -> Image: 162 | square = self.model.size_requirements.square 163 | if square: 164 | original_aspect_ratio = original_width / original_height 165 | current_aspect_ratio = image.width / image.height 166 | 167 | if current_aspect_ratio > original_aspect_ratio: 168 | # Image is wider than the original, crop width 169 | new_width = int(original_aspect_ratio * image.height) 170 | left = (image.width - new_width) // 2 171 | image = image.crop((left, 0, left + new_width, image.height)) 172 | elif current_aspect_ratio < original_aspect_ratio: 173 | # Image is taller than the original, crop height 174 | new_height = int(image.width / original_aspect_ratio) 175 | top = (image.height - new_height) // 2 176 | image = image.crop((0, top, image.width, top + new_height)) 177 | 178 | return image 179 | 180 | def internal_upscale(self, image: Image): 181 | original_width = image.width 182 | original_height = image.height 183 | needs_preprocess = self.model.size_requirements.check(image.width, image.height) 184 | if not needs_preprocess: 185 | image = self.preprocess(image) 186 | # Convert image to cv2 format 187 | image = np.array(image) 188 | image_h, image_w, image_c = get_h_w_c(image) 189 | 190 | if self.model.input_channels == 1 and image_c == 3: 191 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 192 | 193 | device = "cuda" if torch.cuda.is_available() else "cpu" 194 | 195 | # If the image size is already greater than 2048, we'll likely OOM on GPU, so do it on CPU 196 | if image_h > 2048 or image_w > 2048: 197 | device = "cpu" 198 | 199 | try: 200 | self.model.to(device) 201 | half = None 202 | if self.model.supports_half: 203 | half = torch.half 204 | output = image_inference(self.model, image, device, half) 205 | # Convert output to PIL format 206 | output = Image.fromarray(output) 207 | if needs_preprocess: 208 | output = self.postprocess(output, original_width, original_height) 209 | return output 210 | except Exception as e: 211 | print(f"Failed to upscale image: {e}") 212 | traceback.print_exc() 213 | return Image.fromarray(image) 214 | 215 | def unload(self): 216 | try: 217 | del self.model 218 | except: 219 | pass 220 | self.model = None 221 | 222 | def load(self): 223 | pass 224 | --------------------------------------------------------------------------------