├── .gitattributes
├── .gitignore
├── README.md
├── characters
│   ├── nsfw.txt
│   └── sfw.txt
├── clipcrop.py
├── configs
│   └── upscaler.yaml
├── dbimutils.py
├── file_manager.py
├── install.py
├── interrogators
│   ├── blip_interrogator.py
│   ├── booru_interrogator.py
│   ├── data
│   │   ├── artists.txt
│   │   ├── flavors.txt
│   │   ├── mediums.txt
│   │   └── movements.txt
│   ├── idefics2_interrogator.py
│   ├── interrogator.py
│   ├── llava2_interrogator.py
│   ├── moondream_interrogator.py
│   └── wolf_interrogator.py
├── javascript
│   └── smart_process.js
├── model_download.py
├── mplug_owl2
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── evaluate
│   │   ├── EVALUATION.md
│   │   ├── __init__.py
│   │   ├── evaluate_caption.py
│   │   ├── evaluate_mmbench.py
│   │   ├── evaluate_mme.py
│   │   ├── evaluate_mmmu.py
│   │   ├── evaluate_vqa.py
│   │   ├── mmbench_converter.py
│   │   ├── vqa.py
│   │   └── vqa_eval.py
│   ├── local_serve
│   │   ├── __init__.py
│   │   ├── examples
│   │   │   ├── Rebecca_(1939_poster)_Small.jpeg
│   │   │   └── extreme_ironing.jpg
│   │   ├── local_web_server.py
│   │   └── model_worker.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── configuration_mplug_owl2.py
│   │   ├── configuration_qwen.py
│   │   ├── convert_mplug_owl2_weight_to_hf.py
│   │   ├── modeling_attn_mask_utils.py
│   │   ├── modeling_llama2.py
│   │   ├── modeling_mplug_owl2.py
│   │   ├── modeling_qwen.py
│   │   ├── multiway.py
│   │   ├── utils.py
│   │   └── visual_encoder.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── Rebecca_(1939_poster)_Small.jpeg
│   │   │   └── extreme_ironing.jpg
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   └── register_workers.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── mplug_owl2_trainer.py
│   │   ├── train.py
│   │   └── train_mem.py
│   └── utils.py
├── process_params.py
├── processors.py
├── requirements.txt
├── scripts
│   └── process_main.py
├── smartprocess.py
├── style.css
├── super_resolution.py
└── upscalers
    └── spandrel
        ├── spandrel_srformer_model.py
        └── spandrel_upscaler_base.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
141 | # pytype static type analyzer
142 | .pytype/
143 |
144 | # Cython debug symbols
145 | cython_debug/
146 |
147 | # PyCharm
148 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
150 | # and can be added to the global gitignore or merged into this file. For a more nuclear
151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152 | #.idea/
153 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion WebUI Smart Pre-Processing Extension
2 |
3 | ## What is this??
4 |
5 | As the name would imply, this is an extension for
6 | the [Stable-Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) by @Automatic1111
7 |
8 | ## What does this do?
9 |
10 | It does a few things, actually.
11 |
12 | For starters, it utilizes a combination of BLIP/CLIP and YOLOv5 to provide "smart cropping" for images. The primary
13 | subject of each image is identified, the center of that subject is determined, and then the application tries its best
14 | to crop the image so as to keep as much of the subject as possible within the specified dimensions.
15 |
16 | Second, it allows storing the determined image caption directly in the image filename, versus having to create a txt
17 | file alongside every image. You can still create a txt file, use existing captions, or not do any captioning at all.
18 |
19 | Third, I've provided face restoration and upscaling options for input images. You can select from GFPGAN and CodeFormer
20 | for face restoration, and any of the provided upscalers from the "Extras" tab to refine/smooth/add detail to your final
21 | output images.
22 |
23 | Last, but not least, it offers a rudimentary way to swap the "class" of a captioned image with your specific subject
24 | keyword. So, if you're trying to train a subject called "xyz" and "xyz" is a dog, you can easily swap "dog" (and "a
25 | dog") with "xyz" in your captions. Neato!
26 |
27 | ## Smart Cropping
28 |
29 | As I said above, smart cropping utilizes a combination of YOLOv5 object recognition and BLIP/CLIP (and DeepDanbooru)
30 | captioning to automatically determine the most prominent subject in a photo, and automatically crop to that subject as
31 | completely as possible. You can also specify a particular subject (dog/cat/woman/house) for the software to find, and skip
32 | the YOLOv5 detection entirely.
33 |
34 |
35 |
36 | If a subject is not found, the image will be downscaled and cropped from the center.
37 |
38 | ## Smart Captioning
39 |
40 | This uses all the same features as set in user preferences, with the additional options to save to txt or append to the
41 | image file name.
42 |
43 | Additionally, you can swap the generic "class" of the image with a specific subject keyword. This feature may not be
44 | perfect in all cases, but it should still go a long way in speeding up the captioning process.
45 |
46 | You can also specify a maximum caption length, which will split the caption by spaces and append words until the maximum
47 | length is reached.
48 |
49 | ## Post Processing
50 |
51 | It's basically a simplified version of the "Extras" tab. The idea is that you can do facial restoration and/or use a
52 | model like SwinIR or LDSR to smooth or add details to an image. If an image ends up upscaled beyond the target
53 | crop size, it will be downscaled back to the original size.
54 |
55 |
--------------------------------------------------------------------------------
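The class-swap and caption-length behavior the README describes can be sketched in a few lines. The helpers below are illustrative only and are not the extension's actual API; the maximum caption length is interpreted here as a character count.

```python
# Illustrative sketch of the caption handling described in the README.
# swap_class and trim_caption are hypothetical helpers, not the extension's real functions.

def swap_class(caption: str, subject_class: str, keyword: str) -> str:
    """Replace the generic class (e.g. 'a dog' / 'dog') with the subject keyword (e.g. 'xyz')."""
    caption = caption.replace(f"a {subject_class}", keyword)
    return caption.replace(subject_class, keyword)

def trim_caption(caption: str, max_length: int) -> str:
    """Split the caption by spaces and append words until the maximum length would be exceeded."""
    out = ""
    for word in caption.split(" "):
        candidate = f"{out} {word}".strip()
        if len(candidate) > max_length:
            break
        out = candidate
    return out

print(swap_class("a photo of a dog sitting on grass", "dog", "xyz"))
# -> "a photo of xyz sitting on grass"
print(trim_caption("a photo of xyz sitting on grass at sunset", 25))
# -> "a photo of xyz sitting on"
```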
/characters/nsfw.txt:
--------------------------------------------------------------------------------
1 | Human:
2 | Human: Give a short one sentence caption for this image with NO punctuation. DO NOT USE ANY PUNCTUATION OR COMMAS. DO NOT USE COMMAS!!!
3 | AI:
--------------------------------------------------------------------------------
/characters/sfw.txt:
--------------------------------------------------------------------------------
1 | Revised and SFW prompt:
2 |
3 | Act like a stable diffusion SD prompt generator, your focus is to convey warm and playful scenarios in vivid detail, include lush descriptions, the quality of your prompts must be only the best creative and varied. Your subjects must be adults only. When you see "young" in a prompt, assume the subject is in their early 20s. The SD prompts must be formatted in a way the AI can understand.
4 |
5 | Stable Diffusion prompts are made of components which are comprised of keywords separated by commas, keywords can be single words or multi-word keywords, and they have a specific order. A typical format for the components looks like this: [Adjectives], [Type], [Framing], [Shot], [subject], [Descriptor], [Descriptor], [Descriptor], [Expression], [Pose], [Action], [Environment], [Details], [Lighting], [Medium], [Aesthetics], [Visual], [Clothing].
6 |
7 | Here are some keywords commonly used for each of the components, but always mix them with new ones that are coherent to each component:
8 |
9 | Adjectives: Masterpiece, award-winning, Most beautiful, Epic, Majestic, best quality, fantastic, highly detailed, ultra-detailed, 8K, realistic
10 | Type: photo, selfie, 8k photo, professional photo
11 | Framing: Dutch angle, Wide Angle, low angle, high angle, perspective, on eye level, POV, close-up
12 | Shot: mid shot, full shot, portrait, from behind, long shot, cowboy shot, from above, from below
13 | Subject: woman, man
14 | Descriptor: elegant, well-dressed, casual attire, formal attire, stylish, modern, vintage, sporty, artistic, natural, sophisticated, trendy, chic
15 | Expression: satisfied, happy, excited, surprised, mouth open, eyes closed, sleeping, smiling, frowning, crying
16 | Pose: seated, standing, laying down, hands raised, walking, jumping, dancing, running
17 | Action: laughing, chatting, cooking, reading, painting, dancing, sleeping, crying, sitting, standing, lying down
18 | Environment: park, cafe, library, office, beach, mall, rooftop, forest, kitchen, bedroom, outer space, mountains, cityscape
19 | Details: Cloudless sky, glittering night, sparkling rain, shining lights, obscure darkness, red lights, neon lighting
20 | Lighting: light, dim light, two-tone lighting, dynamic lighting, rim light, studio light, diffuse lighting, dark lighting, soft lighting, fantasy lighting, natural light
21 | Medium: photograph, painting, illustration, sketch, render
22 | Aesthetics: vintage, retro, modern, rustic, minimalist, bohemian, industrial, romantic, fantasy, sci-fi, noir, gothic, ethereal
23 | Visual: contrast, cyan hue, Kodachrome, Fujifilm Superia, warm colours, saturation, vibrance, filters coolness, chromatic aberration, cinematic, RAW color
24 | Clothing: casual, formal, vintage, modern, sporty, chic, elegant, streetwear, boho, classic, punk, gothic, trendy
25 |
26 | Use the components to build coherent prompts. Use these keywords but also create your own. Generate variations of the keywords that are coherent to each component and fit the instruction. Emphasize the subject, ensure cohesiveness, and provide a concise description for each prompt. Only reply with full single prompts separated by a line break, do not add a numbered list, quotes, or a section breakdown.
27 |
28 | For example:
29 |
30 | [a photo of a beautiful woman standing outside at night, masterpiece, professional photo, Dutch angle, full shot, woman, elegant, stylish, happy, standing, laughing, park, glittering night, light, photograph, vintage, contrast, casual]
31 |
32 | Human:
33 | Human: Caption this image.
34 | AI:
--------------------------------------------------------------------------------
/clipcrop.py:
--------------------------------------------------------------------------------
1 | # Original project: https://github.com/Vishnunkumar/clipcrop/blob/main/clipcrop/clipcrop.py
2 | import os.path
3 | import sys
4 |
5 | import cv2
6 | import numpy
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from clip import clip
11 |
12 | import modules.paths
13 | from modules import shared, modelloader
14 |
15 |
16 | def clip_boxes(boxes, shape):
17 | # Clip boxes (xyxy) to image shape (height, width)
18 | if isinstance(boxes, torch.Tensor): # faster individually
19 | boxes[:, 0].clamp_(0, shape[1]) # x1
20 | boxes[:, 1].clamp_(0, shape[0]) # y1
21 | boxes[:, 2].clamp_(0, shape[1]) # x2
22 | boxes[:, 3].clamp_(0, shape[0]) # y2
23 | else: # np.array (faster grouped)
24 | boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
25 | boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
26 |
27 |
28 | def find_position(parent: Image, child: Image):
29 | w = child.width
30 | h = child.height
31 | res = cv2.matchTemplate(np.array(parent), np.array(child), cv2.TM_CCOEFF_NORMED)
32 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
33 | # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take minimum
34 | top_left = max_loc
35 | center_x = top_left[0] + (w / 2)
36 | center_y = top_left[1] + (h / 2)
37 | return center_x, center_y
38 |
39 |
40 | class CropClip:
41 | def __init__(self):
42 | # Model
43 | model_name = 'yolov5m6v7.pt'
44 | model_url = 'https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m6.pt'
45 | model_dir = os.path.join(modules.paths.models_path, "yolo")
46 | model_path = modelloader.load_models(model_dir, model_url, None, '.pt', model_name)
47 | was_safe_unpickle = shared.cmd_opts.disable_safe_unpickle
48 | shared.cmd_opts.disable_safe_unpickle = True
49 | self.model = torch.hub.load('ultralytics/yolov5', 'custom', model_path[0])
50 | self.clip = None
51 | self.preprocess = None
52 | shared.cmd_opts.disable_safe_unpickle = was_safe_unpickle
53 | # Prevent BLIP crossfire breakage
54 | try:
55 | del sys.modules['models']
56 | except:
57 | pass
58 |
59 | def get_center(self, image: Image, prompt: str):
60 | # Load image into YOLO parser
61 | results = self.model(image) # includes NMS
62 | # Crop each image result to an array
63 | cropped = results.crop(False)
64 | l = []
65 | for crop in cropped:
66 | l.append(Image.fromarray(crop["im"]))
67 | if len(l) == 0:
68 | l = [image]
69 | device = shared.device
70 | # Lazily load CLIP the first time we need to score the cropped YOLO images
71 | if not self.clip or not self.preprocess:
72 | self.clip, self.preprocess = clip.load("ViT-B/32", device=device)
73 | images = torch.stack([self.preprocess(im) for im in l]).to(device)
74 | with torch.no_grad():
75 | image_features = self.clip.encode_image(images)
76 | image_features /= image_features.norm(dim=-1, keepdim=True)
77 |
78 | image_features.cpu().numpy()
79 | image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
80 | image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
81 |
82 | images = [self.preprocess(im) for im in l]
83 | image_input = torch.tensor(np.stack(images)).cuda()
84 | image_input -= image_mean[:, None, None]
85 | image_input /= image_std[:, None, None]
86 | with torch.no_grad():
87 | image_features = self.clip.encode_image(image_input).float()
88 | image_features /= image_features.norm(dim=-1, keepdim=True)
89 |
90 | def similarity_top(similarity_list, N):
91 | results = zip(range(len(similarity_list)), similarity_list)
92 | results = sorted(results, key=lambda x: x[1], reverse=True)
93 | top_images = []
94 | scores = []
95 | for index, score in results[:N]:
96 | scores.append(score)
97 | top_images.append(l[index])
98 | return scores, top_images
99 |
100 | # @title Crop
101 | with torch.no_grad():
102 | # Encode and normalize the description using CLIP
103 | text_encoded = self.clip.encode_text(clip.tokenize(prompt).to(device))
104 | text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
105 |
106 | # Retrieve the description vector and the photo vectors
107 | similarity = text_encoded.cpu().numpy() @ image_features.cpu().numpy().T
108 | similarity = similarity[0]
109 | scores, imgs = similarity_top(similarity, N=3)
110 | max_area = 0
111 | out = None
112 | for img in imgs:
113 | img_area = img.width * img.height
114 | if img_area > max_area:
115 | max_area = img_area
116 | out = img
117 |
118 | if not out:
119 | out = image
120 | res = cv2.matchTemplate(numpy.array(image), numpy.array(out), cv2.TM_SQDIFF)
121 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
122 | # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take minimum
123 | top_left = min_loc
124 | bottom_right = (top_left[0] + out.width, top_left[1] + out.height)
125 | return [top_left[0], bottom_right[0], top_left[1], bottom_right[1]]
126 |
127 | def unload(self):
128 | if self.model is not None:
129 | self.model.to('cpu')
130 | if self.clip:
131 | self.clip.to('cpu')
132 | # NOTE: self.preprocess is a torchvision Compose; it has no .to() and
133 | # does not need to be moved between devices
134 |
135 | def load(self):
136 | if self.model is not None:
137 | self.model.to(shared.device)
138 | if self.clip:
139 | self.clip.to(shared.device)
140 | # NOTE: self.preprocess is a torchvision Compose; it has no .to() and
141 | # does not need to be moved between devices
142 |
--------------------------------------------------------------------------------
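`clipcrop.py` ranks the YOLO crops by their CLIP similarity to a text prompt and keeps the best match. The sketch below reproduces just that ranking step with the same openai `clip` package, independent of the WebUI modules; the crop file names and the prompt are placeholders.

```python
# Minimal sketch of the CLIP ranking step used by CropClip.get_center:
# score each candidate crop against a text prompt and keep the best match.
# Assumes the openai `clip` package is installed; image paths are placeholders.
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

crops = [Image.open(p) for p in ["crop_0.png", "crop_1.png"]]  # e.g. YOLO detections
prompt = "a dog"

with torch.no_grad():
    image_input = torch.stack([preprocess(im) for im in crops]).to(device)
    image_features = model.encode_image(image_input)
    image_features /= image_features.norm(dim=-1, keepdim=True)

    text_features = model.encode_text(clip.tokenize([prompt]).to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Cosine similarity between the prompt and every crop; the highest score wins
    similarity = (text_features @ image_features.T).squeeze(0)
    best = int(similarity.argmax())

print(f"Best matching crop: index {best} (score {similarity[best]:.3f})")
```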
/configs/upscaler.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
4 | params:
5 | parameterization: "v"
6 | low_scale_key: "lr"
7 | linear_start: 0.0001
8 | linear_end: 0.02
9 | num_timesteps_cond: 1
10 | log_every_t: 200
11 | timesteps: 1000
12 | first_stage_key: "jpg"
13 | cond_stage_key: "txt"
14 | image_size: 128
15 | channels: 4
16 | cond_stage_trainable: false
17 | conditioning_key: "hybrid-adm"
18 | monitor: val/loss_simple_ema
19 | scale_factor: 0.08333
20 | use_ema: False
21 |
22 | low_scale_config:
23 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
24 | params:
25 | noise_schedule_config: # image space
26 | linear_start: 0.0001
27 | linear_end: 0.02
28 | max_noise_level: 350
29 |
30 | unet_config:
31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32 | params:
33 | use_checkpoint: True
34 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
35 | image_size: 128
36 | in_channels: 7
37 | out_channels: 4
38 | model_channels: 256
39 | attention_resolutions: [ 2,4,8]
40 | num_res_blocks: 2
41 | channel_mult: [ 1, 2, 2, 4]
42 | disable_self_attentions: [True, True, True, False]
43 | disable_middle_self_attn: False
44 | num_heads: 8
45 | use_spatial_transformer: True
46 | transformer_depth: 1
47 | context_dim: 1024
48 | legacy: False
49 | use_linear_in_transformer: True
50 |
51 | first_stage_config:
52 | target: ldm.models.autoencoder.AutoencoderKL
53 | params:
54 | embed_dim: 4
55 | ddconfig:
56 | # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
57 | double_z: True
58 | z_channels: 4
59 | resolution: 256
60 | in_channels: 3
61 | out_ch: 3
62 | ch: 128
63 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
64 | num_res_blocks: 2
65 | attn_resolutions: [ ]
66 | dropout: 0.0
67 |
68 | lossconfig:
69 | target: torch.nn.Identity
70 |
71 | cond_stage_config:
72 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
73 | params:
74 | freeze: True
75 | layer: "penultimate"
76 |
--------------------------------------------------------------------------------
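`configs/upscaler.yaml` is an LDM-style config for the Stable Diffusion latent upscaler. A rough sketch of how such a config is typically loaded follows; it assumes `omegaconf` and the `ldm` package from the Stability AI stablediffusion repo are importable, and the checkpoint path is a placeholder.

```python
# Sketch of consuming an LDM-style config such as configs/upscaler.yaml.
# Assumes omegaconf and the `ldm` package (stablediffusion repo) are available;
# the checkpoint file name is a placeholder.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/upscaler.yaml")
print(config.model.target)                                  # ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
print(config.model.params.unet_config.params.in_channels)   # 7

model = instantiate_from_config(config.model)
state = torch.load("x4-upscaler-ema.ckpt", map_location="cpu")  # placeholder checkpoint
model.load_state_dict(state["state_dict"], strict=False)
model.eval()
```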
/dbimutils.py:
--------------------------------------------------------------------------------
1 | # DanBooru IMage Utility functions, borrowed from
2 | # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/dbimutils.py
3 |
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
10 | if img.endswith(".gif"):
11 | img = Image.open(img)
12 | img = img.convert("RGB")
13 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
14 | else:
15 | img = cv2.imread(img, flag)
16 | return img
17 |
18 |
19 | def smart_24bit(img):
20 | if img.dtype is np.dtype(np.uint16):
21 | img = (img / 257).astype(np.uint8)
22 |
23 | if len(img.shape) == 2:
24 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
25 | elif img.shape[2] == 4:
26 | trans_mask = img[:, :, 3] == 0
27 | img[trans_mask] = [255, 255, 255, 255]
28 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
29 | return img
30 |
31 |
32 | def make_square(img, target_size):
33 | old_size = img.shape[:2]
34 | desired_size = max(old_size)
35 | desired_size = max(desired_size, target_size)
36 |
37 | delta_w = desired_size - old_size[1]
38 | delta_h = desired_size - old_size[0]
39 | top, bottom = delta_h // 2, delta_h - (delta_h // 2)
40 | left, right = delta_w // 2, delta_w - (delta_w // 2)
41 |
42 | color = [255, 255, 255]
43 | new_im = cv2.copyMakeBorder(
44 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
45 | )
46 | return new_im
47 |
48 |
49 | def smart_resize(img, size):
50 | # Assumes the image has already gone through make_square
51 | if img.shape[0] > size:
52 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
53 | elif img.shape[0] < size:
54 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
55 | return img
56 |
--------------------------------------------------------------------------------
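A typical use of these helpers is the square-pad-and-resize pipeline that WD14-style taggers expect. The sketch below assumes the extension is importable as `extensions.sd_smartprocess` (as elsewhere in this repo); the file name and the 448-pixel input size are illustrative.

```python
# Example of the dbimutils preprocessing chain used before tagging:
# read -> flatten alpha/16-bit -> pad to square -> resize to the tagger's input size.
from extensions.sd_smartprocess import dbimutils

img = dbimutils.smart_imread("example.png")   # BGR ndarray; GIFs are handled via PIL
img = dbimutils.smart_24bit(img)              # drop alpha / convert 16-bit to 8-bit
img = dbimutils.make_square(img, 448)         # pad with white to a square canvas
img = dbimutils.smart_resize(img, 448)        # resize to 448x448 for the tagger
print(img.shape)                              # (448, 448, 3)
```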
/file_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from typing import List
4 |
5 | import PIL
6 | from PIL.Image import Image
7 |
8 |
9 |
10 | def clean_string(s):
11 | """
12 | Remove non-alphanumeric characters except spaces, and normalize spacing.
13 | Args:
14 | s: The string to clean.
15 |
16 | Returns: A cleaned string.
17 | """
18 | # Remove non-alphanumeric characters except spaces
19 | cleaned = re.sub(r'[^a-zA-Z0-9\s]', '', s)
20 | # Check for a sentence with just the same word repeated
21 | if len(set(cleaned.split())) == 1:
22 | cleaned = cleaned.split()[0]
23 | # Replace multiple spaces with a single space
24 | cleaned = re.sub(r'\s+', ' ', cleaned).strip()
25 | return cleaned
26 |
27 |
28 | class ImageData:
29 | image_path: str = ""
30 | temp_image_path: str = ""
31 | caption: str = ""
32 | tags: List[str] = []
33 | selected: bool = False
34 | filtered: bool = False
35 | id = None
36 |
37 | def __init__(self, image_path):
38 | self.image_path = image_path
39 | self.caption = self.read_caption()
40 | self.tags = self.split_caption()
41 | # Generate a random id
42 | self.id = os.urandom(32).hex()
43 |
44 | def read_caption(self):
45 | existing_caption_txt_filename = os.path.splitext(self.image_path)[0] + '.txt'
46 | if os.path.exists(existing_caption_txt_filename):
47 | with open(existing_caption_txt_filename, 'r', encoding="utf8") as file:
48 | existing_caption_txt = file.read()
49 | existing_caption_txt = existing_caption_txt.strip()
50 | else:
51 | image_name = os.path.splitext(os.path.basename(self.image_path))[0]
52 | existing_caption_txt = clean_string(image_name)
53 | return existing_caption_txt
54 |
55 | def split_caption(self):
56 | tags = self.caption.split(",")
57 | tags = [tag.strip() for tag in tags]
58 | tags = [tag for tag in tags if tag != ""]
59 | return tags
60 |
61 | def update_image(self, image: Image, save_file: bool = False):
62 | if save_file:
63 | img_path = os.path.splitext(self.image_path)[0] + '.png'
64 | if img_path != self.image_path and os.path.exists(self.image_path):
65 | os.remove(self.image_path)
66 | self.image_path = img_path
67 | else:
68 | self.temp_image_path = os.path.splitext(self.image_path)[0] + '_temp.png'
69 | img_path = self.temp_image_path
70 | image.save(img_path)
71 |
72 | def update_caption(self, caption: str, save_file: bool = False):
73 | if save_file:
74 | caption_txt_filename = os.path.splitext(self.image_path)[0] + '.txt'
75 | with open(caption_txt_filename, 'w', encoding="utf8") as file:
76 | file.write(caption)
77 | self.caption = caption
78 | self.tags = self.split_caption()
79 |
80 | def get_image(self):
81 | return PIL.Image.open(self.image_path).convert("RGB")
82 |
83 |
84 | class FileManager:
85 | file_path: str = ""
86 | _instance = None
87 | files: List[ImageData] = []
88 | included_tags: List[str] = []
89 | excluded_tags: List[str] = []
90 | included_strings: List[str] = []
91 | excluded_strings: List[str] = []
92 | current_image = None
93 |
94 | def __init__(self):
95 | self.files = []
96 |
97 | def __new__(cls):
98 | if FileManager._instance is None:
99 | FileManager._instance = object.__new__(cls)
100 | return FileManager._instance
101 |
102 | def clear(self):
103 | self.files = []
104 | self.included_tags = []
105 | self.excluded_tags = []
106 | self.included_strings = []
107 | self.excluded_strings = []
108 | self.current_image = None
109 |
110 | def load_files(self):
111 | from extensions.sd_smartprocess.smartprocess import is_image
112 |
113 | self.clear()
114 | # Walk through all files in the directory that contains the images
115 | for root, dirs, files in os.walk(self.file_path):
116 | for file in files:
117 | file = os.path.join(root, file)
118 | if is_image(file) and "_backup" not in file:
119 | image_data = ImageData(file)
120 | self.files.append(image_data)
121 | self.update_filters()
122 |
123 | def filtered_files(self, for_gallery: bool = False):
124 | if not for_gallery:
125 | return [file for file in self.files if file.filtered]
126 | else:
127 | return [(file.image_path, file.caption) for file in self.files if file.filtered]
128 |
129 | def all_files(self, for_gallery: bool = False) -> List[ImageData]:
130 | if not for_gallery:
131 | return self.files
132 | else:
133 | return [(file.image_path, file.caption) for file in self.files]
134 |
135 | def selected_files(self, for_gallery: bool = False) -> List[ImageData]:
136 | if not for_gallery:
137 | return [file for file in self.files if file.selected]
138 | else:
139 | return [(file.image_path, file.caption) for file in self.files if file.selected]
140 |
141 | def filtered_and_selected_files(self, for_gallery: bool = False) -> List[ImageData]:
142 | if not for_gallery:
143 | return [file for file in self.files if file.filtered and file.selected]
144 | else:
145 | return [(file.image_path, file.caption) for file in self.files if file.selected and file.filtered]
146 |
147 | def update_files(self, files: List[ImageData]):
148 | for file in files:
149 | self.update_file(file)
150 |
151 | def update_file(self, file: ImageData):
152 | # Search for the file with the same ID and update it if found
153 | for i, existing_file in enumerate(self.files):
154 | if existing_file.id == file.id:
155 | self.files[i] = file
156 | break
157 | else:
158 | # The file was not found in the list, append it
159 | self.files.append(file)
160 |
161 | def update_filters(self):
162 | """
163 | Filters a collection of files based on specified inclusion and exclusion criteria for tags and filter strings.
164 |
165 | The function filters files based on tags extracted from file captions and matches these tags against specified
166 | inclusion and exclusion criteria. The criteria can be a set of plain tags, wildcard patterns, or regex expressions.
167 |
168 | Inputs (instance attributes, not parameters):
169 | - included_tags (list): Tags required to include a file.
170 | - excluded_tags (list): Tags that lead to exclusion of a file.
171 | - included_strings (list): Filter strings for inclusion; supports wildcards and regex.
172 | - excluded_strings (list): Filter strings for exclusion; supports wildcards and regex.
173 | - files (list): List of ImageData objects; each file's comma-separated caption supplies its tags.
174 | 
175 | Returns:
176 | - None. Each file's `filtered` flag is set to True if at least one of its tags survives the
177 | include/exclude criteria, and False otherwise; use filtered_files() to retrieve the matches.
178 | 
179 | Matching rules:
180 | - Entries in included_tags / excluded_tags are compared against whole tags.
181 | - Entries in included_strings / excluded_strings may be plain substrings, wildcard patterns
182 | ('*' matches any sequence of characters), or regular expressions.
183 | 
184 | Examples:
185 | 1. Plain Tag Matching:
186 | - To include files with a tag 'holiday', add 'holiday' to included_tags.
187 | - To exclude files with a tag 'work', add 'work' to excluded_tags.
188 | 
189 | 2. Wildcard Patterns:
190 | - To include files with tags starting with 'trip', add 'trip*' to included_strings.
191 | - To exclude files with tags ending with '2023', add '*2023' to excluded_strings.
192 | 
193 | 3. Regex Expressions:
194 | - To include files with tags that have any number followed by 'days', add '\\d+days' to included_strings; to exclude date-like tags (e.g., '2023-04-01'), add '\\d{4}-\\d{2}-\\d{2}' to excluded_strings.
195 |
196 | Note:
197 | - The function treats tags as case-sensitive.
198 | - Wildcard '*' matches any sequence of characters (including none).
199 | - Regex patterns should follow Python's 're' module syntax.
200 | """
201 |
202 | def matches_pattern(pattern, string):
203 | # Convert wildcard to regex if necessary
204 | if '*' in pattern:
205 | pattern = '^' + pattern.replace('*', '.*') + '$'
206 | return re.match(pattern, string) is not None
207 | else:
208 | if " " not in pattern:
209 | parts = string.split(" ")
210 | return pattern in parts
211 | else:
212 | return pattern in string
213 |
214 | def should_include(tag, filter_tags, filter_strings):
215 | if len(filter_tags) == 0 and len(filter_strings) == 0:
216 | return True
217 | tag_match = False
218 | filter_match = False
219 | if tag in filter_tags and len(filter_tags) > 0:
220 | tag_match = True
221 | if len(filter_strings) > 0:
222 | for filter_string in filter_strings:
223 | if matches_pattern(filter_string, tag):
224 | filter_match = True
225 | break
226 | return tag_match or filter_match
227 |
228 | def should_exclude(tag, filter_tags, filter_strings):
229 | if tag in filter_tags:
230 | return True
231 | for filter_string in filter_strings:
232 | if matches_pattern(filter_string, tag):
233 | return True
234 | return False
235 |
236 | files = self.files
237 | updated_files = []
238 | for file in files:
239 | tags = file.tags
240 | out_tags = []
241 | for tag in tags:
242 | include = True
243 | if should_exclude(tag, self.excluded_tags, self.excluded_strings):
244 | include = False
245 | elif not should_include(tag, self.included_tags, self.included_strings):
246 | include = False
247 | if include:
248 | out_tags.append(tag)
249 | file.filtered = len(out_tags) > 0
250 | updated_files.append(file)
251 | self.files = updated_files
252 |
253 | def all_captions(self):
254 | return [file.caption for file in self.files]
255 |
--------------------------------------------------------------------------------
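The filter attributes documented in `update_filters` can be exercised directly. The sketch below assumes the extension package is importable; the image path is a placeholder and the caption is set in memory (`save_file=False`), so nothing touches disk.

```python
# Sketch of driving FileManager's tag filters, mirroring the update_filters docstring.
from extensions.sd_smartprocess.file_manager import FileManager, ImageData

fm = FileManager()
fm.clear()

img = ImageData("photos/beach.jpg")                          # placeholder path
img.update_caption("woman, beach, sunset, trip2023", save_file=False)
fm.files.append(img)

fm.excluded_strings = ["*2023"]   # drop tags ending in 2023 (wildcard pattern)
fm.included_tags = []             # no include criteria: everything else passes
fm.update_filters()

print([f.image_path for f in fm.filtered_files()])           # file keeps surviving tags, so it is listed
```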
/install.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from launch import run, git_clone, repo_dir
5 |
6 | name = "Smart Crop"
7 | req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
8 | print(f"loading {name} reqs from {req_file}")
9 | run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements.",
10 | f"Couldn't install {name} requirements.")
11 | print("Cloning BLIP repo...")
12 | blip_repo = "https://github.com/pharmapsychotic/BLIP"
13 | git_clone(blip_repo, repo_dir('BLIP'), "BLIP")
14 |
15 | # git clone https://github.com/X-PLUG/mPLUG-Owl.git
16 | # cd mPLUG-Owl/mPLUG-Owl2
17 | # mplug_repo = "https://github.com/X-PLUG/mPLUG-Owl.git"
18 | # mplug_owl_path = repo_dir('mplug_owl_src')
19 | # git_clone(mplug_repo, mplug_owl_path, "mPLUG-Owl")
20 | # mplug_owl2_sub1_path = os.path.join(mplug_owl_path, "mPLUG-Owl", "mplug_owl")
21 | # mplug_owl2_sub2_path = os.path.join(mplug_owl_path, "mPLUG-Owl2", "mplug_owl2")
22 | #
23 | # sys.path.append(mplug_owl2_sub1_path)
24 | # sys.path.append(mplug_owl2_sub2_path)
25 | #
26 |
--------------------------------------------------------------------------------
/interrogators/blip_interrogator.py:
--------------------------------------------------------------------------------
1 | from PIL.Image import Image
2 |
3 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
4 | from transformers import AutoProcessor, Blip2ForConditionalGeneration
5 | import torch
6 |
7 | from extensions.sd_smartprocess.model_download import fetch_model
8 | from extensions.sd_smartprocess.process_params import ProcessParams
9 |
10 |
11 | class BLIPInterrogator(Interrogator):
12 | _instance = None # Class variable to hold the singleton instance
13 | params = {"max_tokens": 75, "load_in_8bit": False}
14 |
15 | def __init__(self, params: ProcessParams):
16 | super().__init__(params)
17 | self.processor = None
18 | self.model = None
19 | self.current_device = None
20 | self.load_8bit = params.load_in_8bit
21 |
22 | def __new__(cls, params: ProcessParams):
23 | if cls._instance is None:
24 | cls._instance = super(BLIPInterrogator, cls).__new__(cls)
25 | try:
26 | cls._instance._init(params)
27 | except Exception as e:
28 | cls._instance = None
29 | raise e
30 | cls.initial_prompt = params.blip_initial_prompt
31 | cls._load_8bit = params.load_in_8bit
32 | return cls._instance
33 |
34 | def _init(self, params: ProcessParams):
35 | super().__init__(params)
36 | self.model = None
37 | self.processor = None
38 | self.load_8bit = params.load_in_8bit
39 |
40 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
41 | self.load()
42 | self.params = params
43 | max_tokens = params.max_tokens
44 | inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16)
45 | generated_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
46 | generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
47 | if unload:
48 | self.unload()
49 | return generated_text
50 |
51 | def unload(self):
52 | if self.current_device != "cpu":
53 | try:
54 | self.model.to("cpu")
55 | self.current_device = "cpu"
56 | except:
57 | pass
58 |
59 | def load(self):
60 | try:
61 | if self.model is None:
62 | model_path = fetch_model('Salesforce/blip2-opt-6.7b', "blip2")
63 | self.processor = AutoProcessor.from_pretrained(model_path)
64 | self.model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, load_in_8bit=self.load_8bit)
65 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
66 | if self.device != self.current_device:
67 | self.model.to(self.device)
68 | self.current_device = self.device
69 | except Exception as e:
70 | print(f"Error loading BLIP2 model: {e}")
71 |
--------------------------------------------------------------------------------
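For reference, the captioning flow that `BLIPInterrogator` wraps can be reproduced with `transformers` alone. The sketch below uses the smaller `Salesforce/blip2-opt-2.7b` checkpoint purely as an example (the class above fetches `blip2-opt-6.7b`), assumes a CUDA GPU for the float16 weights, and uses a placeholder image path.

```python
# Standalone sketch of the BLIP-2 captioning flow wrapped by BLIPInterrogator above.
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to(device)

image = Image.open("example.png").convert("RGB")            # placeholder image
inputs = processor(image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=75)
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(caption)
```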
/interrogators/booru_interrogator.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import torch
5 | from PIL.Image import Image
6 |
7 | import modules.deepbooru
8 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator, re_special
9 | from extensions.sd_smartprocess.process_params import ProcessParams
10 | from modules import images, devices
11 |
12 |
13 | class BooruInterrogator(Interrogator):
14 | params = {"min_score": 0.75}
15 |
16 | def __init__(self, params: ProcessParams) -> None:
17 | super().__init__(params)
18 | self.tags = None
19 | self.booru = modules.deepbooru.DeepDanbooru()
20 | self.model = self.booru.model
21 |
22 | def unload(self):
23 | self.booru.stop()
24 |
25 | def load(self):
26 | self.booru.start()
27 | pass
28 |
29 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
30 | self.load()
31 | self.params = params
32 | min_score = params.booru_min_score
33 | pic = images.resize_image(2, image, 512, 512)
34 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
35 |
36 | with torch.no_grad(), devices.autocast():
37 | # Move the model to the same device as the input tensor
38 | self.model.to(devices.device)
39 | # Convert input to the correct type (half-precision if needed)
40 | x = torch.from_numpy(a).to(devices.device).type_as(self.model.n_Conv_0.weight)
41 | # Forward pass through the model
42 | y = self.model(x)[0].detach().cpu().numpy()
43 |
44 | probability_dict = {}
45 |
46 | for tag, probability in zip(self.model.tags, y):
47 | if tag.startswith("rating:"):
48 | continue
49 |
50 | probability_dict[tag] = probability
51 |
52 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
53 |
54 | output = {}
55 | for tag in tags:
56 | probability = probability_dict[tag]
57 | tag_outformat = tag
58 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
59 | output[tag_outformat] = probability
60 | out_tags = []
61 | for tag in sorted(output, key=output.get, reverse=True):
62 | if output[tag] >= min_score:
63 | # print(f"DBTag {tag} score is {tags[tag]}")
64 | out_tags.append(tag)
65 | output = ", ".join(out_tags)
66 | if unload:
67 | self.unload()
68 | return output
69 |
--------------------------------------------------------------------------------
/interrogators/data/mediums.txt:
--------------------------------------------------------------------------------
1 | a 3D render
2 | a black and white photo
3 | a bronze sculpture
4 | a cartoon
5 | a cave painting
6 | a character portrait
7 | a charcoal drawing
8 | a child's drawing
9 | a color pencil sketch
10 | a colorized photo
11 | a comic book panel
12 | a computer rendering
13 | a cross stitch
14 | a cubist painting
15 | a detailed drawing
16 | a detailed matte painting
17 | a detailed painting
18 | a diagram
19 | a digital painting
20 | a digital rendering
21 | a drawing
22 | a fine art painting
23 | a flemish Baroque
24 | a gouache
25 | a hologram
26 | a hyperrealistic painting
27 | a jigsaw puzzle
28 | a low poly render
29 | a macro photograph
30 | a manga drawing
31 | a marble sculpture
32 | a matte painting
33 | a microscopic photo
34 | a mid-nineteenth century engraving
35 | a minimalist painting
36 | a mosaic
37 | a painting
38 | a pastel
39 | a pencil sketch
40 | a photo
41 | a photocopy
42 | a photorealistic painting
43 | a picture
44 | a pointillism painting
45 | a polaroid photo
46 | a pop art painting
47 | a portrait
48 | a poster
49 | a raytraced image
50 | a renaissance painting
51 | a screenprint
52 | a screenshot
53 | a silk screen
54 | a sketch
55 | a statue
56 | a still life
57 | a stipple
58 | a stock photo
59 | a storybook illustration
60 | a surrealist painting
61 | a surrealist sculpture
62 | a tattoo
63 | a tilt shift photo
64 | a watercolor painting
65 | a wireframe diagram
66 | a woodcut
67 | an abstract drawing
68 | an abstract painting
69 | an abstract sculpture
70 | an acrylic painting
71 | an airbrush painting
72 | an album cover
73 | an ambient occlusion render
74 | an anime drawing
75 | an art deco painting
76 | an art deco sculpture
77 | an engraving
78 | an etching
79 | an illustration of
80 | an impressionist painting
81 | an ink drawing
82 | an oil on canvas painting
83 | an oil painting
84 | an ultrafine detailed painting
85 | chalk art
86 | computer graphics
87 | concept art
88 | cyberpunk art
89 | digital art
90 | egyptian art
91 | graffiti art
92 | lineart
93 | pixel art
94 | poster art
95 | vector art
--------------------------------------------------------------------------------
/interrogators/data/movements.txt:
--------------------------------------------------------------------------------
1 | abstract art
2 | abstract expressionism
3 | abstract illusionism
4 | academic art
5 | action painting
6 | aestheticism
7 | afrofuturism
8 | altermodern
9 | american barbizon school
10 | american impressionism
11 | american realism
12 | american romanticism
13 | american scene painting
14 | analytical art
15 | antipodeans
16 | arabesque
17 | arbeitsrat für kunst
18 | art & language
19 | art brut
20 | art deco
21 | art informel
22 | art nouveau
23 | art photography
24 | arte povera
25 | arts and crafts movement
26 | ascii art
27 | ashcan school
28 | assemblage
29 | australian tonalism
30 | auto-destructive art
31 | barbizon school
32 | baroque
33 | bauhaus
34 | bengal school of art
35 | berlin secession
36 | black arts movement
37 | brutalism
38 | classical realism
39 | cloisonnism
40 | cobra
41 | color field
42 | computer art
43 | conceptual art
44 | concrete art
45 | constructivism
46 | context art
47 | crayon art
48 | crystal cubism
49 | cubism
50 | cubo-futurism
51 | cynical realism
52 | dada
53 | danube school
54 | dau-al-set
55 | de stijl
56 | deconstructivism
57 | digital art
58 | ecological art
59 | environmental art
60 | excessivism
61 | expressionism
62 | fantastic realism
63 | fantasy art
64 | fauvism
65 | feminist art
66 | figuration libre
67 | figurative art
68 | figurativism
69 | fine art
70 | fluxus
71 | folk art
72 | funk art
73 | furry art
74 | futurism
75 | generative art
76 | geometric abstract art
77 | german romanticism
78 | gothic art
79 | graffiti
80 | gutai group
81 | happening
82 | harlem renaissance
83 | heidelberg school
84 | holography
85 | hudson river school
86 | hurufiyya
87 | hypermodernism
88 | hyperrealism
89 | impressionism
90 | incoherents
91 | institutional critique
92 | interactive art
93 | international gothic
94 | international typographic style
95 | kinetic art
96 | kinetic pointillism
97 | kitsch movement
98 | land art
99 | les automatistes
100 | les nabis
101 | letterism
102 | light and space
103 | lowbrow
104 | lyco art
105 | lyrical abstraction
106 | magic realism
107 | magical realism
108 | mail art
109 | mannerism
110 | massurrealism
111 | maximalism
112 | metaphysical painting
113 | mingei
114 | minimalism
115 | modern european ink painting
116 | modernism
117 | modular constructivism
118 | naive art
119 | naturalism
120 | neo-dada
121 | neo-expressionism
122 | neo-fauvism
123 | neo-figurative
124 | neo-primitivism
125 | neo-romanticism
126 | neoclassicism
127 | neogeo
128 | neoism
129 | neoplasticism
130 | net art
131 | new objectivity
132 | new sculpture
133 | northwest school
134 | nuclear art
135 | objective abstraction
136 | op art
137 | optical illusion
138 | orphism
139 | panfuturism
140 | paris school
141 | photorealism
142 | pixel art
143 | plasticien
144 | plein air
145 | pointillism
146 | pop art
147 | pop surrealism
148 | post-impressionism
149 | postminimalism
150 | pre-raphaelitism
151 | precisionism
152 | primitivism
153 | private press
154 | process art
155 | psychedelic art
156 | purism
157 | qajar art
158 | quito school
159 | rasquache
160 | rayonism
161 | realism
162 | regionalism
163 | remodernism
164 | renaissance
165 | retrofuturism
166 | rococo
167 | romanesque
168 | romanticism
169 | samikshavad
170 | serial art
171 | shin hanga
172 | shock art
173 | socialist realism
174 | sots art
175 | space art
176 | street art
177 | stuckism
178 | sumatraism
179 | superflat
180 | suprematism
181 | surrealism
182 | symbolism
183 | synchromism
184 | synthetism
185 | sōsaku hanga
186 | tachisme
187 | temporary art
188 | tonalism
189 | toyism
190 | transgressive art
191 | ukiyo-e
192 | underground comix
193 | unilalianism
194 | vancouver school
195 | vanitas
196 | verdadism
197 | video art
198 | viennese actionism
199 | visual art
200 | vorticism
--------------------------------------------------------------------------------
/interrogators/idefics2_interrogator.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | from datetime import datetime
4 |
5 | import torch
6 | from PIL import Image
7 | from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
8 |
9 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
10 | from extensions.sd_smartprocess.model_download import fetch_model
11 | from extensions.sd_smartprocess.mplug_owl2.mm_utils import get_model_name_from_path
12 | from extensions.sd_smartprocess.process_params import ProcessParams
13 |
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | NO_CAP_PROMPT = """
18 | Generate a concise caption describing the key elements and context of the image in one sentence,
19 | focusing on the medium, subject, style, clothing, pose, action, and location. Ensure the sentence is accurate and devoid
20 | of assumptions, prose, etc. Keep it direct and relevant to the image.
21 |
22 | Follow the caption with a list of specific tags (keywords) detailing smaller key elements like clothing, poses,
23 | actions, and other notable features.
24 | """
25 |
26 | EX_CAP_PROMPT = """
27 | Here is a caption consisting of a sentence and a list of tags (keywords) that describe the image.
28 |
29 | Refine the provided caption to more accurately and vividly capture the essence and key details visible in the image,
30 | focusing on encapsulating the medium, subject, style, clothing, pose, action, and location in one sentence.
31 |
32 | Update the accompanying tags to reflect only the elements present in the image, ensuring precision and relevance.
33 | Avoid adding new information not supported by the existing caption or the image.
34 | """
35 |
36 |
37 | class Idefics2Interrogator(Interrogator):
38 | model = None
39 | processor = None
40 | params = {"max_tokens": 75, "load_in_4bit": True, "replace_blip_caption": True, "idefics2_prompt": "Describe this image in one detailed sentence, include the subject, location, style, and type of image."}
41 |
42 | def __init__(self, params: ProcessParams):
43 | super().__init__(params)
44 | logger.debug("Initializing LLM model...")
45 | model_path = fetch_model("HuggingFaceM4/idefics2-8b", "llm")
46 | model_name = get_model_name_from_path(model_path)
47 | self.load_4bit = params.load_in_4bit
48 | self.prompt = params.interrogation_prompt
49 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
50 | self.model = None
51 | self.tokenizer = None
52 | self.image_processor = None
53 | self.context_len = None
54 | self.current_device = None
55 | logger.debug("Initialized LLM model.")
56 |
57 | def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str:
58 | self.load()
59 | self.prompt = NO_CAP_PROMPT
60 | if params is not None:
61 | if params.new_caption != "":
62 | existing_caption_txt = params.new_caption
63 | self.prompt = f"{EX_CAP_PROMPT}: {existing_caption_txt}"
64 | logger.debug(f"Existing caption query: {self.prompt}")
65 |
66 | # Create inputs
67 | messages = [
68 | {
69 | "role": "user",
70 | "content": [
71 | {"type": "image"},
72 | {"type": "text", "text": self.prompt},
73 | ]
74 | }
75 | ]
76 |
77 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
78 | inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
79 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
80 |
81 | # Generate
82 | generated_ids = self.model.generate(**inputs, max_new_tokens=500)
83 | caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
84 | caption = caption[0].strip()
85 | # Find "Assistant: " in the string and return everything after it
86 | caption = caption[caption.find("Assistant: ") + len("Assistant: "):].strip()
87 | if params.txt_action != "include":
88 | caption = caption.replace(",", "").replace(".", "").replace("?", "").replace("!", "").strip()
89 | if unload:
90 | self.unload()
91 | return caption
92 |
93 | def _to_cpu(self):
94 | if self.load_4bit:
95 | print("Model is loaded in 4bit, can't move to CPU.")
96 | return
97 | from extensions.sd_smartprocess.smartprocess import vram_usage
98 | used, total = vram_usage()
99 | print(f"VRAM: {used}/{total}")
100 | free = total - used
101 | # If we have over 16GB of VRAM, we can use the GPU
102 | if free > 16:
103 | print("VRAM is over 16GB, moving to GPU")
104 | self._to_gpu()
105 | return
106 | print("Moving to CPU")
107 | if self.current_device != "cpu":
108 | self.model.to('cpu')
109 | self.current_device = "cpu"
110 | if torch.cuda.is_available():
111 | torch.cuda.empty_cache()
112 | gc.collect()
113 |
114 | def _to_gpu(self):
115 | if self.model is None:
116 | self.load()
117 | return
118 | if self.load_4bit:
119 | print("Model is loaded in 4bit, can't move to GPU.")
120 | return
121 | if self.current_device != "cuda" and torch.cuda.is_available():
122 | print("Moving to GPU")
123 | time = datetime.now()
124 | self.model.to(self.device)
125 | print(f"Model to GPU: {datetime.now() - time}")
126 | self.current_device = "cuda"
127 |
128 | def unload(self):
129 | if self.model is not None:
130 | print("Unloading Idefics2 model")
131 | self._to_cpu()
132 |
133 | def load(self):
134 | if self.model is None:
135 | model_path = fetch_model("HuggingFaceM4/idefics2-8b", "llm")
136 | # model_name = get_model_name_from_path(model_path)
137 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
138 | self.current_device = self.device
139 | print(f"Loading processor on {self.device}")
140 | self.processor = AutoProcessor.from_pretrained(model_path)
141 | if self.load_4bit:
142 | print(f"Loading model.")
143 | quantization_config = BitsAndBytesConfig(
144 | load_in_4bit=True,
145 | bnb_4bit_quant_type="nf4",
146 | bnb_4bit_use_double_quant=True,
147 | bnb_4bit_compute_dtype=torch.float16
148 | )
149 | self.model = AutoModelForVision2Seq.from_pretrained(
150 | model_path,
151 | torch_dtype=torch.float16,
152 | quantization_config=quantization_config
153 | )
154 | else:
155 | self.model = AutoModelForVision2Seq.from_pretrained(
156 | model_path,
157 | torch_dtype=torch.float16
158 | ).to(self.device)
159 | print(f"Model loaded on {self.device}")
160 | self._to_gpu()
161 |
--------------------------------------------------------------------------------
/interrogators/interrogator.py:
--------------------------------------------------------------------------------
1 | # Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/interrogator.py
2 | import importlib
3 | import pkgutil
4 | import re
5 | import sys
6 | from abc import ABC
7 |
8 | import torch
9 | from PIL import Image
10 |
11 | import extensions.sd_smartprocess.interrogators as interrogators
12 | from extensions.sd_smartprocess.process_params import ProcessParams
13 |
14 |
15 | # Base class for all interrogators; subclasses must implement interrogate().
16 | class Interrogator(ABC):
17 | def __init__(self, params: ProcessParams) -> None:
18 | self.params = params
19 | self.model = None
20 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
21 | registry = InterrogatorRegistry()
22 | registry.register(self)
23 |
24 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
25 | raise NotImplementedError
26 |
27 | def unload(self):
28 | if self.model:
29 | try:
30 | self.model = self.model.to("cpu")
31 | except:
32 | pass
33 |
34 | def load(self):
35 | if self.model:
36 | try:
37 | self.model = self.model.to(self.device)
38 | except:
39 | pass
40 |
41 |
42 | re_special = re.compile(r'([\\()])')
43 |
44 |
45 | class InterrogatorRegistry:
46 | _instance = None # Class variable to hold the singleton instance
47 |
48 | def __new__(cls):
49 | if cls._instance is None:
50 | cls._instance = super(InterrogatorRegistry, cls).__new__(cls)
51 | cls._instance._init()
52 | return cls._instance
53 |
54 | def _init(self):
55 | self.interrogators = {}
56 |
57 | def register(self, interrogator: Interrogator):
58 | self.interrogators[interrogator.__class__.__name__] = interrogator
59 |
60 | def get_interrogator(self, interrogator_name: str) -> Interrogator:
61 | return self.interrogators[interrogator_name]
62 |
63 | def get_interrogators(self):
64 | return self.interrogators
65 |
66 | def unload(self):
67 | for interrogator in self.interrogators.values():
68 | interrogator.unload()
69 |
70 | def load(self):
71 | for interrogator in self.interrogators.values():
72 | interrogator.load()
73 |
74 | @staticmethod
75 | def list_interrogators(return_cls=False):
76 | # First, dynamically import all modules in the package to ensure all subclasses are loaded
77 | package = interrogators
78 | for importer, modname, ispkg in pkgutil.iter_modules(package.__path__, package.__name__ + '.'):
79 | print(f"Importing {modname}")
80 | try:
81 | importlib.import_module(modname)
82 | except Exception as e:
83 | print(f"Error importing {modname}: {e}")
84 |
85 | # Now, enumerate subclasses of Interrogator to access their class-level 'params'
86 | interrogator_dict = {}
87 | subclass_names = [cls.__name__ for cls in Interrogator.__subclasses__()]
88 | print(f"Subclass names: {subclass_names}")
89 | for cls in Interrogator.__subclasses__():
90 | if return_cls:
91 | interrogator_dict[cls.__name__] = cls
92 | continue
93 | params = getattr(cls, "params", None)
94 | if params is not None: # Ensure 'params' is defined at the class level
95 | interrogator_dict[cls.__name__] = params
96 | else:
97 | interrogator_dict[cls.__name__] = {}
98 |
99 | return interrogator_dict
100 |
101 |
--------------------------------------------------------------------------------
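The registry above is what the UI uses to discover captioners. Below is a minimal sketch of querying it, assuming the extension is installed under a running WebUI so that `extensions.sd_smartprocess` is importable.

```python
# Sketch of enumerating the available interrogators and their class-level default params.
from extensions.sd_smartprocess.interrogators.interrogator import InterrogatorRegistry

# Each entry maps an Interrogator subclass name to its default parameter dict
defaults = InterrogatorRegistry.list_interrogators()
for name, params in defaults.items():
    print(name, params)

# With return_cls=True the classes themselves are returned, e.g. to instantiate one later
classes = InterrogatorRegistry.list_interrogators(return_cls=True)
```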
/interrogators/llava2_interrogator.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | from datetime import datetime
4 |
5 | import torch
6 | from PIL import Image
7 | from transformers import TextStreamer
8 |
9 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
10 | from extensions.sd_smartprocess.model_download import fetch_model
11 | from extensions.sd_smartprocess.mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
12 | from extensions.sd_smartprocess.mplug_owl2.conversation import conv_templates
13 | from extensions.sd_smartprocess.mplug_owl2.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token, \
14 | process_images, \
15 | get_model_name_from_path
16 | from extensions.sd_smartprocess.mplug_owl2.model.builder import load_pretrained_model
17 | from extensions.sd_smartprocess.process_params import ProcessParams
18 |
19 | # This is basically broken until we can update transformers in AUTO past the current version supported
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | NO_CAP_PROMPT = """
24 | Generate a concise caption describing the key elements and context of the image in one sentence,
25 | focusing on the medium, subject, style, clothing, pose, action, and location. Ensure the sentence is accurate and devoid
26 | of assumptions, prose, etc. Keep it direct and relevant to the image.
27 |
28 | Follow the caption with a list of specific tags (keywords) detailing smaller key elements like clothing, poses,
29 | actions, and other notable features.
30 | """
31 |
32 | EX_CAP_PROMPT = """
33 | Here is a caption consisting of a sentence and a list of tags (keywords) that describe the image.
34 |
35 | Refine the provided caption to more accurately and vividly capture the essence and key details visible in the image,
36 | focusing on encapsulating the medium, subject, style, clothing, pose, action, and location in one sentence.
37 |
38 | Update the accompanying tags to reflect only the elements present in the image, ensuring precision and relevance.
39 | Avoid adding new information not supported by the existing caption or the image.
40 | """
41 |
42 |
43 | class LLAVA2Interrogator(Interrogator):
44 | model = None
45 | processor = None
46 | params = {"max_tokens": 75, "load_in_8bit": False, "replace_blip_caption": True}
47 |
48 | def __init__(self, params: ProcessParams):
49 | super().__init__(params)
50 | logger.debug("Initializing LLM model...")
51 | model_path = fetch_model('MAGAer13/mplug-owl2-llama2-7b', "llm")
52 | model_name = get_model_name_from_path(model_path)
53 | self.load_8bit = params.load_in_8bit
54 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
55 | self.model = None
56 | self.tokenizer = None
57 | self.image_processor = None
58 | self.context_len = None
59 | self.current_device = None
60 | logger.debug("Initialized LLM model.")
61 |
62 | def interrogate(self, image: Image.Image, params: ProcessParams = None, unload: bool = False) -> str:
63 | self.load()
64 |
65 | query = NO_CAP_PROMPT
66 | if params is not None:
67 | if params.new_caption != "":
68 | existing_caption_txt = params.new_caption
69 | query = f"{EX_CAP_PROMPT}: {existing_caption_txt}"
70 | logger.debug(f"Existing caption query: {query}")
71 |
72 | conv = conv_templates["mplug_owl2"].copy()
73 |
74 | max_edge = max(image.size)
75 | image = image.resize((max_edge, max_edge))  # stretch to a square on the longest edge, matching the mPLUG-Owl2 eval preprocessing
76 | image_tensor = process_images([image], self.image_processor)
77 | image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
78 |
79 | inp = query + " " + DEFAULT_IMAGE_TOKEN
80 | conv.append_message(conv.roles[0], inp)
81 | conv.append_message(conv.roles[1], None)
82 | prompt = conv.get_prompt()
83 |
84 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
85 | 0).to(
86 | self.model.device)
87 | stop_str = conv.sep2
88 | keywords = [stop_str]
89 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
90 | streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
91 |
92 | temperature = 0.7
93 | max_new_tokens = 512
94 |
95 | with torch.inference_mode():
96 | output_ids = self.model.generate(
97 | input_ids,
98 | images=image_tensor,
99 | do_sample=True,
100 | temperature=temperature,
101 | max_new_tokens=max_new_tokens,
102 | streamer=streamer,
103 | use_cache=True,
104 | stopping_criteria=[stopping_criteria])
105 |
106 | caption = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
107 | if params is not None and params.txt_action != "include":
108 | caption = caption.replace(",", "").replace(".", "").replace("?", "").replace("!", "").strip()
109 | if unload:
110 | self.unload()
111 | return caption
112 |
113 | def _to_cpu(self):
114 | if self.load_8bit:
115 | print("Model is loaded in 8bit, can't move to CPU.")
116 | return
117 | from extensions.sd_smartprocess.smartprocess import vram_usage
118 | used, total = vram_usage()
119 | print(f"VRAM: {used}/{total}")
120 | free = total - used
121 | # If more than 16GB of VRAM is free, we can keep the model on the GPU
122 | if free > 16:
123 | print("More than 16GB of VRAM is free, moving to GPU")
124 | self._to_gpu()
125 | return
126 | print("Moving to CPU")
127 | if self.current_device != "cpu":
128 | self.model.to('cpu')
129 | self.current_device = "cpu"
130 | if torch.cuda.is_available():
131 | torch.cuda.empty_cache()
132 | gc.collect()
133 |
134 | def _to_gpu(self):
135 | if self.model is None:
136 | self.load()
137 | return
138 | if self.current_device != "cuda" and torch.cuda.is_available():
139 | print("Moving to GPU")
140 | time = datetime.now()
141 | self.model.to(self.device)
142 | print(f"Model to GPU: {datetime.now() - time}")
143 | self.current_device = "cuda"
144 | # self.image_processor.to(self.device)
145 | # self.tokenizer.to(self.device)
146 |
147 | def unload(self):
148 | if self.model is not None:
149 | print("Unloading LLAVA2 model")
150 | self._to_cpu()
151 |
152 | def load(self):
153 | if self.model is None:
154 | model_path = fetch_model('MAGAer13/mplug-owl2-llama2-7b', "llm")
155 | model_name = get_model_name_from_path(model_path)
156 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
157 | self.current_device = self.device
158 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None,
159 | model_name,
160 | load_8bit=self.load_8bit,
161 | load_4bit=False,
162 | device="cuda")
163 |
164 | self._to_gpu()
165 |
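166 | # --- Hypothetical usage sketch (not part of the original file) ---
167 | # The ProcessParams fields below are inferred from the attributes this class reads
168 | # (load_in_8bit, new_caption, txt_action); treat the no-arg constructor and the
169 | # field names as assumptions rather than a documented API.
170 | #
171 | # from PIL import Image
172 | # p = ProcessParams()
173 | # p.load_in_8bit = False
174 | # p.new_caption = ""          # empty -> uses NO_CAP_PROMPT
175 | # p.txt_action = "include"    # keep punctuation in the generated caption
176 | # interrogator = LLAVA2Interrogator(p)
177 | # print(interrogator.interrogate(Image.open("example.jpg"), params=p, unload=True))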
--------------------------------------------------------------------------------
/interrogators/moondream_interrogator.py:
--------------------------------------------------------------------------------
1 | from PIL.Image import Image
2 |
3 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
4 | from extensions.sd_smartprocess.model_download import fetch_model
5 | from extensions.sd_smartprocess.process_params import ProcessParams
6 | from torch import float16
7 | from transformers import AutoModelForCausalLM, AutoTokenizer
8 |
9 |
10 | class MoonDreamInterrogator(Interrogator):
11 | params = {"interrogation_prompt": "Describe this image in one detailed sentence."} # Class-level attribute
12 |
13 | def __init__(self, params: ProcessParams):
14 | super().__init__(params)
15 | self.interrogation_prompt = params.interrogation_prompt
16 | self.model = None
17 | self.tokenizer = None
18 | print("Initializing Moondream...")
19 |
20 | def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str:
21 | self.load()
22 | enc_image = self.model.encode_image(image)
23 | caption = self.model.answer_question(enc_image, self.interrogation_prompt, self.tokenizer)
24 | caption = caption.replace(",", "").replace(".", "").replace("?", "").replace("!", "").strip()
25 | if unload:
26 | self.unload()
27 | return caption
28 |
29 | def _to_gpu(self):
30 | self.model = self.model.to("cuda")
31 |
32 | def _to_cpu(self):
33 | self.model = self.model.to("cpu")
34 |
35 | def load(self):
36 | if self.model is None:
37 | model_id = "vikhyatk/moondream2"
38 | model_path = fetch_model(model_id, "llm")
39 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
40 |
41 | self.model = AutoModelForCausalLM.from_pretrained(
42 | model_path, trust_remote_code=True,
43 | torch_dtype=float16).to("cuda")
44 |
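45 | # --- Hypothetical stand-alone sketch (not part of the original file) ---
46 | # Mirrors the calls used above (encode_image / answer_question on vikhyatk/moondream2);
47 | # the exact remote-code API can change between model revisions, so treat this as an
48 | # illustration rather than a guaranteed interface.
49 | #
50 | # from PIL import Image as PILImage
51 | # tok = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
52 | # mdl = AutoModelForCausalLM.from_pretrained("vikhyatk/moondream2", trust_remote_code=True)
53 | # enc = mdl.encode_image(PILImage.open("example.jpg"))
54 | # print(mdl.answer_question(enc, "Describe this image in one detailed sentence.", tok))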
--------------------------------------------------------------------------------
/interrogators/wolf_interrogator.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import PIL
4 | import cv2
5 | import huggingface_hub
6 | import numpy as np
7 | import pandas as pd
8 | from PIL.Image import Image
9 | from onnxruntime import InferenceSession
10 |
11 | from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
12 | from extensions.sd_smartprocess.model_download import fetch_model
13 | from extensions.sd_smartprocess.process_params import ProcessParams
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
18 | MODEL_FILENAME = "model.onnx"
19 | LABEL_FILENAME = "selected_tags.csv"
20 | WOLF_PARAMS = {
21 | "group": "WOLF",
22 | "threshold": 0.75,
23 | "char_threshold": 0.75
24 | }
25 |
26 |
27 | class MoatInterrogator(Interrogator):
28 | params = WOLF_PARAMS
29 |
30 | def __init__(self, params: ProcessParams) -> None:
31 | super().__init__(params)
32 | self._setup()
33 |
34 | def _setup(self):
35 | model_path = "SmilingWolf/wd-v1-4-moat-tagger-v2"
36 | self.model = load_model(model_path, self.device)
37 |
38 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
39 | threshold = params.threshold
40 | char_threshold = params.char_threshold
41 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
42 | return a
43 |
44 |
45 | class SwinInterrogator(Interrogator):
46 | params = WOLF_PARAMS
47 |
48 | def __init__(self, params: ProcessParams) -> None:
49 | super().__init__(params)
50 | self._setup()
51 |
52 | def _setup(self):
53 | model_path = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
54 | self.model = load_model(model_path, self.device)
55 |
56 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
57 | threshold = params.threshold
58 | char_threshold = params.char_threshold
59 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
60 | return a
61 |
62 |
63 | class ConvInterrogator(Interrogator):
64 | params = WOLF_PARAMS
65 |
66 | def __init__(self, params: ProcessParams) -> None:
67 | super().__init__(params)
68 | self._setup()
69 |
70 | def _setup(self):
71 | model_path = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
72 | self.model = load_model(model_path, self.device)
73 |
74 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
75 | threshold = params.threshold
76 | char_threshold = params.char_threshold
77 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
78 | return a
79 |
80 |
81 | class Conv2Interrogator(Interrogator):
82 | params = WOLF_PARAMS
83 |
84 | def __init__(self, params: ProcessParams) -> None:
85 | super().__init__(params)
86 | self._setup()
87 |
88 | def _setup(self):
89 | model_path = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
90 | self.model = load_model(model_path, self.device)
91 |
92 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
93 | threshold = params.threshold
94 | char_threshold = params.char_threshold
95 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
96 | return a
97 |
98 |
99 | class VitInterrogator(Interrogator):
100 | params = WOLF_PARAMS
101 |
102 | def __init__(self, params: ProcessParams) -> None:
103 | super().__init__(params)
104 | self._setup()
105 |
106 | def _setup(self):
107 | model_path = "SmilingWolf/wd-v1-4-vit-tagger-v2"
108 | self.model = load_model(model_path, self.device)
109 |
110 | def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
111 | threshold = params.threshold
112 | char_threshold = params.char_threshold
113 | a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
114 | return a
115 |
116 |
117 | def predict(
118 | image: Image,
119 | model,
120 | general_threshold: float = 0.5,
121 | character_threshold: float = 0.5
122 | ):
123 | tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
124 |
125 | _, height, width, _ = model.get_inputs()[0].shape
126 |
127 | # Alpha to white
128 | image = image.convert("RGBA")
129 | new_image = PIL.Image.new("RGBA", image.size, "WHITE")
130 | new_image.paste(image, mask=image)
131 | image = new_image.convert("RGB")
132 | image = np.asarray(image)
133 |
134 | # PIL RGB to OpenCV BGR
135 | image = image[:, :, ::-1]
136 |
137 | image = make_square(image, height)
138 | image = smart_resize(image, height)
139 | image = image.astype(np.float32)
140 | image = np.expand_dims(image, 0)
141 |
142 | input_name = model.get_inputs()[0].name
143 | label_name = model.get_outputs()[0].name
144 | probs = model.run([label_name], {input_name: image})[0]
145 |
146 | labels = list(zip(tag_names, probs[0].astype(float)))
147 |
148 | # Rating labels (category 9): keep each with its score; a caller can take the argmax
149 | ratings_names = [labels[i] for i in rating_indexes]
150 | rating = dict(ratings_names)
151 |
152 | # Then we have general tags: pick any where prediction confidence > threshold
153 | general_names = [labels[i] for i in general_indexes]
154 | general_res = [x for x in general_names if x[1] > general_threshold]
155 | general_res = dict(general_res)
156 |
157 | # Character tags (category 4): pick any where prediction confidence > threshold
158 | character_names = [labels[i] for i in character_indexes]
159 | character_res = [x for x in character_names if x[1] > character_threshold]
160 | character_res = dict(character_res)
161 |
162 | b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
163 | a = (
164 | ", ".join(list(b.keys()))
165 | .replace("_", " ")
166 | .replace("(", "\(")
167 | .replace(")", "\)")
168 | )
169 | c = ", ".join(list(b.keys()))
170 |
171 | return a, c, rating, character_res, general_res
172 |
173 |
174 | def load_model(model_path, device):
175 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
176 | if device == "cpu":
177 | providers.pop(0)
178 | path = fetch_model(model_path, "wolf", True)
179 | model = InferenceSession(path, providers=providers)
180 | return model
181 |
182 |
183 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
184 | if img.endswith(".gif"):
185 | img = PIL.Image.open(img)
186 | img = img.convert("RGB")
187 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
188 | else:
189 | img = cv2.imread(img, flag)
190 | return img
191 |
192 |
193 | def smart_24bit(img):
194 | if img.dtype == np.dtype(np.uint16):
195 | img = (img / 257).astype(np.uint8)
196 |
197 | if len(img.shape) == 2:
198 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
199 | elif img.shape[2] == 4:
200 | trans_mask = img[:, :, 3] == 0
201 | img[trans_mask] = [255, 255, 255, 255]
202 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
203 | return img
204 |
205 |
206 | def make_square(img, target_size):
207 | old_size = img.shape[:2]
208 | desired_size = max(old_size)
209 | desired_size = max(desired_size, target_size)
210 |
211 | delta_w = desired_size - old_size[1]
212 | delta_h = desired_size - old_size[0]
213 | top, bottom = delta_h // 2, delta_h - (delta_h // 2)
214 | left, right = delta_w // 2, delta_w - (delta_w // 2)
215 |
216 | color = [255, 255, 255]
217 | new_im = cv2.copyMakeBorder(
218 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
219 | )
220 | return new_im
221 |
222 |
223 | def smart_resize(img, size):
224 | # Assumes the image has already gone through make_square
225 | if img.shape[0] > size:
226 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
227 | elif img.shape[0] < size:
228 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
229 | return img
230 |
231 |
232 | def load_labels() -> tuple[list[str], list[int], list[int], list[int]]:
233 | path = huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME)
234 | df = pd.read_csv(path)
235 | tag_names = df["name"].tolist()
236 | rating_indexes = list(np.where(df["category"] == 9)[0])
237 | general_indexes = list(np.where(df["category"] == 0)[0])
238 | character_indexes = list(np.where(df["category"] == 4)[0])
239 | return tag_names, rating_indexes, general_indexes, character_indexes
240 |
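241 | # --- Hypothetical end-to-end sketch (not part of the original file) ---
242 | # Uses only the helpers defined above; the thresholds are illustrative, not tuned values.
243 | #
244 | # import PIL.Image
245 | # session = load_model(CONV2_MODEL_REPO, "cuda")   # or "cpu" to drop the CUDA provider
246 | # tags, raw_tags, rating, characters, general = predict(
247 | #     PIL.Image.open("example.png"), session,
248 | #     general_threshold=0.35, character_threshold=0.85,
249 | # )
250 | # print(tags)   # escaped, comma-separated tag string, as returned by the interrogators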
--------------------------------------------------------------------------------
/javascript/smart_process.js:
--------------------------------------------------------------------------------
1 | function start_smart_process() {
2 | let progress = gradioApp().getElementById("sp_progress");
3 | let gallery = gradioApp().getElementById("sp_gallery");
4 | console.log("Requesting progress:", progress, gallery);
5 | requestProgress('sp', progress, gallery, atEnd, function(progress){});
6 | gradioApp().querySelector('#sp_error').innerHTML = '';
7 | return args_to_array(arguments);
8 | }
9 |
10 | function atEnd() {
11 |
12 | }
13 |
--------------------------------------------------------------------------------
/model_download.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from huggingface_hub import snapshot_download
5 |
6 | from modules.paths_internal import models_path
7 | from modules.safe import unsafe_torch_load, load
8 |
9 |
10 | def fetch_model(model_repo, model_type, single_file=False):
11 | model_dir = os.path.join(models_path, model_type)
12 | dest_dir = os.path.join(model_dir, model_repo.split("/")[1])
13 | if not os.path.exists(dest_dir):
14 | os.makedirs(dest_dir, exist_ok=True)
15 | dest_dir = snapshot_download(model_repo, repo_type="model", local_dir=dest_dir, local_dir_use_symlinks=False)
16 | if single_file:
17 | dest_dir = os.path.join(dest_dir, "model.onnx")
18 | return dest_dir
19 |
20 |
21 | disable_safe_unpickle_count = 0
22 |
23 |
24 | def disable_safe_unpickle():
25 | global disable_safe_unpickle_count
26 | try:
27 | from modules import shared as auto_shared
28 | if not auto_shared.cmd_opts.disable_safe_unpickle:
29 | auto_shared.cmd_opts.disable_safe_unpickle = True
30 | torch.load = unsafe_torch_load
31 | disable_safe_unpickle_count += 1
32 | except Exception:
33 | pass  # shared module unavailable (not running inside AUTOMATIC1111); nothing to toggle
34 |
35 |
36 | def enable_safe_unpickle():
37 | global disable_safe_unpickle_count
38 | try:
39 | from modules import shared as auto_shared
40 | if disable_safe_unpickle_count > 0:
41 | disable_safe_unpickle_count -= 1
42 | if disable_safe_unpickle_count == 0 and auto_shared.cmd_opts.disable_safe_unpickle:
43 | auto_shared.cmd_opts.disable_safe_unpickle = False
44 | torch.load = load
45 | except Exception:
46 | pass  # shared module unavailable (not running inside AUTOMATIC1111); nothing to toggle
47 |
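48 | # --- Hypothetical usage sketch (not part of the original file) ---
49 | # fetch_model() downloads a Hugging Face snapshot (once) into <models_path>/<model_type>/<repo name>
50 | # and returns that directory, or the path to model.onnx when single_file=True.
51 | # The repo ids below are the ones used elsewhere in this extension.
52 | #
53 | # llm_dir = fetch_model("vikhyatk/moondream2", "llm")
54 | # onnx_path = fetch_model("SmilingWolf/wd-v1-4-moat-tagger-v2", "wolf", single_file=True)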
--------------------------------------------------------------------------------
/mplug_owl2/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import MPLUGOwl2LlamaForCausalLM
--------------------------------------------------------------------------------
/mplug_owl2/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "./demo_logs"
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<|image|>"
10 |
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/EVALUATION.md:
--------------------------------------------------------------------------------
1 | # Evaluation
2 | ## Important
3 | We use `batch_size = 1` for evaluation because of a batch-inference issue in LLaMA: left padding produces incorrect positional embeddings. Use a larger batch size only if you have verified there is no performance drop.
4 |
5 | ## Dependencies
6 |
7 | ```bash
8 | pip install pycocoevalcap tqdm
9 | ```
10 |
11 | ## Image Caption
12 |
13 | ### [Flickr30K](https://bryanplummer.com/Flickr30kEntities/)
14 |
15 |
16 | **Data Preparation**
17 |
18 | ```bash
19 | mkdir -p data/flickr30k && cd data/flickr30k
20 |
21 | # download images from https://bryanplummer.com/Flickr30kEntities/
22 |
23 | # karpathy split annotations can be downloaded from https://cs.stanford.edu/people/karpathy/deepimagesent/
24 |
25 | # download converted files
26 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/flickr30k/flickr30k_karpathy_test.json
27 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/flickr30k/flickr30k_karpathy_train.json
28 |
29 | cd ../..
30 | ```
31 |
32 |
33 |
34 |
35 | **Evaluate**
36 |
37 | ```bash
38 | ds="flickr"
39 | checkpoint=/PATH/TO/CHECKPOINT
40 | python -m torch.distributed.launch --use_env \
41 | --nproc_per_node ${NPROC_PER_NODE:-8} \
42 | --nnodes ${WORLD_SIZE:-1} \
43 | --node_rank ${RANK:-0} \
44 | --master_addr ${MASTER_ADDR:-127.0.0.1} \
45 | --master_port ${MASTER_PORT:-12345} \
46 | evaluate_caption.py \
47 | --checkpoint $checkpoint \
48 | --dataset $ds \
49 | --batch-size 8 \
50 | --num-workers 2
51 | ```
52 |
53 |
54 |
55 |
56 | **Evaluate** (NoCaps)
57 |
58 | ```bash
59 | ds="nocaps"
60 | checkpoint=/PATH/TO/CHECKPOINT
61 | python -m torch.distributed.launch --use_env \
62 | --nproc_per_node ${NPROC_PER_NODE:-8} \
63 | --nnodes ${WORLD_SIZE:-1} \
64 | --node_rank ${RANK:-0} \
65 | --master_addr ${MASTER_ADDR:-127.0.0.1} \
66 | --master_port ${MASTER_PORT:-12345} \
67 | evaluate_caption.py \
68 | --checkpoint $checkpoint \
69 | --dataset $ds \
70 | --batch-size 8 \
71 | --num-workers 2
72 | ```
73 |
74 |
75 |
76 | ## [COCO](https://cocodataset.org/)
77 |
78 | > COCO images are used in VQAv2/OK-VQA/RefCOCO/RefCOCO+/RefCOCOg; make sure you have already downloaded the COCO images before evaluating on these benchmarks.
79 |
80 |
81 | **Data Preparation**
82 |
83 | ```bash
84 | mkdir -p data/coco && cd data/coco
85 |
86 | # download coco2014 images
87 | wget http://images.cocodataset.org/zips/train2014.zip && unzip train2014.zip
88 | wget http://images.cocodataset.org/zips/val2014.zip && unzip val2014.zip
89 | wget http://images.cocodataset.org/zips/test2015.zip && unzip test2015.zip
90 |
91 | cd ../..
92 | ```
93 |
94 |
95 |
96 | ## General VQA
97 |
98 | ### [VQAv2](https://visualqa.org/)
99 |
100 |
101 | **Data Preparation**
102 |
103 | ```bash
104 | mkdir -p data/vqav2 && cd data/vqav2
105 |
106 | # make sure you have downloaded COCO images
107 |
108 | # download questions and annotations
109 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip && unzip v2_Annotations_Train_mscoco.zip
110 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip && unzip v2_Questions_Train_mscoco.zip
111 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip && unzip v2_Annotations_Val_mscoco.zip
112 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip && unzip v2_Questions_Val_mscoco.zip
113 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip && unzip v2_Questions_Test_mscoco.zip
114 |
115 | # download converted files
116 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/vqav2/vqav2_train.jsonl
117 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/vqav2/vqav2_val.jsonl
118 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/vqav2/vqav2_testdev.jsonl
119 | ```
120 |
121 |
122 |
123 |
124 | **Evaluate**
125 |
126 | ```bash
127 | checkpoint=/PATH/TO/CHECKPOINT
128 | for ds in "vqav2_val" "vqav2_testdev"
129 | python -m torch.distributed.launch --use-env \
130 | --nproc_per_node ${NPROC_PER_NODE:-8} \
131 | --nnodes ${WORLD_SIZE:-1} \
132 | --node_rank ${RANK:-0} \
133 | --master_addr ${MASTER_ADDR:-127.0.0.1} \
134 | --master_port ${MASTER_PORT:-12345} \
135 | evaluate_vqa.py \
136 | --checkpoint $checkpoint \
137 | --dataset $ds \
138 | --batch-size 8 \
139 | --num-workers 2; done
140 | ```
141 |
142 |
143 |
144 | ### [OKVQA](https://okvqa.allenai.org/)
145 |
146 |
147 | **Data Preparation**
148 |
149 | ```bash
150 | mkdir -p data/okvqa && cd data/okvqa
151 |
152 | # download annotations and questions
153 | wget https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip && unzip mscoco_train2014_annotations.json.zip
154 | wget https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip && unzip OpenEnded_mscoco_train2014_questions.json.zip
155 | wget https://okvqa.allenai.org/static/data/mscoco_val2014_annotations.json.zip && unzip mscoco_val2014_annotations.json.zip
156 | wget https://okvqa.allenai.org/static/data/OpenEnded_mscoco_val2014_questions.json.zip && unzip OpenEnded_mscoco_val2014_questions.json.zip
157 |
158 | # download converted files
159 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/okvqa/okvqa_train.jsonl
160 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/okvqa/okvqa_val.jsonl
161 |
162 | cd ../..
163 | ```
164 |
165 |
166 |
167 |
168 | **Evaluate**
169 |
170 | ```bash
171 | ds="okvqa_val"
172 | checkpoint=/PATH/TO/CHECKPOINT
173 | python -m torch.distributed.launch --use_env \
174 | --nproc_per_node ${NPROC_PER_NODE:-8} \
175 | --nnodes ${WORLD_SIZE:-1} \
176 | --node_rank ${RANK:-0} \
177 | --master_addr ${MASTER_ADDR:-127.0.0.1} \
178 | --master_port ${MASTER_PORT:-12345} \
179 | evaluate_vqa.py \
180 | --checkpoint $checkpoint \
181 | --dataset $ds \
182 | --batch-size 8 \
183 | --num-workers 2
184 | ```
185 |
186 |
187 |
188 | ### [TextVQA](https://textvqa.org/)
189 |
190 |
191 | **Data Preparation**
192 |
193 | ```bash
194 | mkdir -p data/textvqa && cd data/textvqa
195 |
196 | # download images
197 | wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip && unzip train_val_images.zip
198 |
199 | # download annotations and questions
200 | wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_train.json
201 | wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
202 |
203 | # download converted files
204 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_train_annotations.json
205 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_train_questions.json
206 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_train.jsonl
207 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_val_annotations.json
208 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_val_questions.json
209 | wget https://ofasys-wlcb.oss-cn-wulanchabu.aliyuncs.com/Qwen-VL/evaluation/textvqa/textvqa_val.jsonl
210 |
211 | cd ../..
212 | ```
213 |
214 |
215 |
216 | **Evaluate**
217 |
218 | ```bash
219 | ds="textvqa_val"
220 | checkpoint=/PATH/TO/CHECKPOINT
221 | python -m torch.distributed.launch --use_env \
222 | --nproc_per_node ${NPROC_PER_NODE:-8} \
223 | --nnodes ${WORLD_SIZE:-1} \
224 | --node_rank ${RANK:-0} \
225 | --master_addr ${MASTER_ADDR:-127.0.0.1} \
226 | --master_port ${MASTER_PORT:-12345} \
227 | evaluate_vqa.py \
228 | --checkpoint $checkpoint \
229 | --dataset $ds \
230 | --batch-size 8 \
231 | --num-workers 2
232 | ```
233 |
234 |
235 |
236 |
237 | ## Multi-modal Benchmark
238 | ### [MMBench](https://github.com/open-compass/MMBench)
239 |
240 |
241 | **Data Preparation**
242 |
243 | ```bash
244 | python mmbench_converter.py  # convert the raw .tsv annotation file to .jsonl and extract the images
245 | ```
246 |
247 |
248 |
249 | **Evaluate**
250 |
251 | ```bash
252 | ds="mmbench_dev_20230712"
253 | checkpoint=/PATH/TO/CHECKPOINT
254 | python -m torch.distributed.launch --use_env \
255 | --nproc_per_node ${NPROC_PER_NODE:-8} \
256 | --nnodes ${WORLD_SIZE:-1} \
257 | --node_rank ${RANK:-0} \
258 | --master_addr ${MASTER_ADDR:-127.0.0.1} \
259 | --master_port ${MASTER_PORT:-12345} \
260 | evaluate_mmbench.py \
261 | --checkpoint $checkpoint \
262 | --dataset $ds \
263 | --batch-size 1 \
264 | --num-workers 2
265 | ```
266 |
267 |
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/evaluate/__init__.py
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/evaluate_caption.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import torch
10 | from pycocoevalcap.eval import COCOEvalCap
11 | from pycocotools.coco import COCO
12 | from tqdm import tqdm
13 | from PIL import Image
14 |
15 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
16 | from mplug_owl2.conversation import conv_templates, SeparatorStyle
17 | from mplug_owl2.model.builder import load_pretrained_model
18 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
19 |
20 | ds_collections = {
21 | 'flickr': {
22 | 'train': 'data/flickr30k/flickr30k_karpathy_train.json',
23 | 'test': 'data/flickr30k/flickr30k_karpathy_test.json',
24 | },
25 | }
26 |
27 | class CaptionDataset(torch.utils.data.Dataset):
28 |
29 | def __init__(self, train, test, prompt, image_processor, few_shot=0):
30 | self.images = json.load(open(test))['images']
31 | self.prompt = prompt
32 | self.image_processor = image_processor
33 |
34 | self.few_shot = few_shot
35 | if few_shot > 0:
36 | self.train = json.load(open(train))['annotations']
37 |
38 | def __len__(self):
39 | return len(self.images)
40 |
41 | def __getitem__(self, idx):
42 | image_id, image_path = self.images[idx]['id'], self.images[idx]['image']
43 |
44 | image = Image.open(image_path).convert('RGB')
45 | max_edge = max(image.size)
46 | image = image.resize((max_edge, max_edge))
47 | image_tensor = process_images([image], self.image_processor)
48 |
49 | return {
50 | 'image_id': image_id,
51 | 'image_tensor': image_tensor,
52 | 'input_text': self.prompt.format(image_path)
53 | }
54 |
55 |
56 | def collate_fn(inputs, tokenizer):
57 | image_ids = [_['image_id'] for _ in inputs]
58 | image_tensor = [_['image_tensor'] for _ in inputs]
59 | input_texts = [_['input_text'] for _ in inputs]
60 |
61 | input_ids = []
62 | for input_text in input_texts:
63 | input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist())
64 | input_tokens_max_length = max([len(x) for x in input_ids])
65 | pad_token_id = tokenizer.pad_token_id
66 |
67 | input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left
68 | input_ids = torch.LongTensor(input_ids)
69 | attention_mask = 1 - input_ids.eq(pad_token_id).long()
70 |
71 | image_tensor = torch.cat(image_tensor, dim=0)
72 | return image_ids, image_tensor, input_ids, attention_mask
73 |
74 |
75 | class InferenceSampler(torch.utils.data.sampler.Sampler):
76 |
77 | def __init__(self, size):
78 | self._size = int(size)
79 | assert size > 0
80 | self._rank = torch.distributed.get_rank()
81 | self._world_size = torch.distributed.get_world_size()
82 | self._local_indices = self._get_local_indices(size, self._world_size,
83 | self._rank)
84 |
85 | @staticmethod
86 | def _get_local_indices(total_size, world_size, rank):
87 | shard_size = total_size // world_size
88 | left = total_size % world_size
89 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
90 |
91 | begin = sum(shard_sizes[:rank])
92 | end = min(sum(shard_sizes[:rank + 1]), total_size)
93 | return range(begin, end)
94 |
95 | def __iter__(self):
96 | yield from self._local_indices
97 |
98 | def __len__(self):
99 | return len(self._local_indices)
100 |
101 |
102 | if __name__ == '__main__':
103 |
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument('--checkpoint', type=str, default='')
106 | parser.add_argument('--dataset', type=str, default='flickr')
107 | parser.add_argument('--batch-size', type=int, default=1)
108 | parser.add_argument('--num-workers', type=int, default=1)
109 | parser.add_argument('--few-shot', type=int, default=0)
110 | parser.add_argument('--seed', type=int, default=0)
111 | args = parser.parse_args()
112 |
113 | torch.distributed.init_process_group(
114 | backend='nccl',
115 | world_size=int(os.getenv('WORLD_SIZE', '1')),
116 | rank=int(os.getenv('RANK', '0')),
117 | )
118 |
119 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
120 |
121 | prompt = 'USER: <|image|>Provide a one-sentence caption for the provided image. ASSISTANT: '
122 |
123 | model_path = args.checkpoint
124 | model_name = get_model_name_from_path(model_path)
125 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
126 | tokenizer.padding_side = 'left'
127 | if tokenizer.pad_token_id is None:
128 | tokenizer.pad_token_id = tokenizer.eos_token_id
129 |
130 | random.seed(args.seed)
131 | dataset = CaptionDataset(
132 | train=ds_collections[args.dataset]['train'],
133 | test=ds_collections[args.dataset]['test'],
134 | prompt=prompt,
135 | image_processor=image_processor,
136 | few_shot=args.few_shot,
137 | )
138 | coco_karpathy_test_loader = torch.utils.data.DataLoader(
139 | dataset=dataset,
140 | sampler=InferenceSampler(len(dataset)),
141 | batch_size=args.batch_size,
142 | num_workers=args.num_workers,
143 | pin_memory=True,
144 | drop_last=False,
145 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
146 | )
147 |
148 | image_ids = []
149 | captions = []
150 | for _, (ids, image_tensor, input_ids, attention_mask) in enumerate(tqdm(coco_karpathy_test_loader)):
151 | pred = model.generate(
152 | input_ids=input_ids.cuda(),
153 | attention_mask=attention_mask.cuda(),
154 | images=image_tensor.to(dtype=model.dtype).cuda(),
155 | do_sample=False,
156 | num_beams=1,
157 | max_new_tokens=60,
158 | min_new_tokens=8,
159 | length_penalty=0,
160 | num_return_sequences=1,
161 | use_cache=True,
162 | )
163 | image_ids.extend(ids)
164 | captions.extend([
165 | tokenizer.decode(_[input_ids.size(1):].cpu(),
166 | skip_special_tokens=True).strip() for _ in pred
167 | ])
168 | print(captions[-len(pred):])
169 |
170 | torch.distributed.barrier()
171 |
172 | world_size = torch.distributed.get_world_size()
173 | merged_ids = [None for _ in range(world_size)]
174 | merged_captions = [None for _ in range(world_size)]
175 | torch.distributed.all_gather_object(merged_ids, image_ids)
176 | torch.distributed.all_gather_object(merged_captions, captions)
177 |
178 | merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
179 | merged_captions = [
180 | _ for _ in itertools.chain.from_iterable(merged_captions)
181 | ]
182 |
183 | if torch.distributed.get_rank() == 0:
184 | print(f"Evaluating {args.dataset} ...")
185 |
186 | results = []
187 | for image_id, caption in zip(merged_ids, merged_captions):
188 | results.append({
189 | 'image_id': int(image_id),
190 | 'caption': caption,
191 | })
192 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
193 | results_file = f'{args.dataset}_{time_prefix}.json'
194 | json.dump(results, open(results_file, 'w'))
195 |
196 | coco = COCO(ds_collections[args.dataset]['test'])
197 | coco_result = coco.loadRes(results_file)
198 | coco_eval = COCOEvalCap(coco, coco_result)
199 | coco_eval.evaluate()
200 |
201 | print(coco_eval.eval.items())
202 | torch.distributed.barrier()
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/evaluate_mmbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 | from typing import Optional
9 |
10 | import torch
11 | from tqdm import tqdm
12 | from PIL import Image
13 | import pandas as pd
14 |
15 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
16 | from mplug_owl2.conversation import conv_templates, SeparatorStyle
17 | from mplug_owl2.model.builder import load_pretrained_model
18 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
19 |
20 |
21 | ds_collections = {
22 | 'mmbench_dev_20230712': {
23 | 'raw_file': 'mmbench_dev_20230712.tsv',
24 | 'annotation': 'mmbench_dev_20230712.jsonl',
25 | 'max_new_tokens': 10,
26 | },
27 | 'mmbench_test_20230712': {
28 | 'raw_file': 'mmbench_test_20230712.tsv',
29 | 'annotation': 'mmbench_test_20230712.jsonl',
30 | 'max_new_tokens': 10,
31 | },
32 | 'mmbench_test_en_20231003': {
33 | 'raw_file': 'mmbench_test_en_20231003.tsv',
34 | 'annotation': 'mmbench_test_en_20231003.jsonl',
35 | 'max_new_tokens': 10,
36 | },
37 | 'mmbench_test_cn_20231003': {
38 | 'raw_file': 'mmbench_test_cn_20231003.tsv',
39 | 'annotation': 'mmbench_test_cn_20231003.jsonl',
40 | 'max_new_tokens': 10,
41 | },
42 | 'ccbench_1003': {
43 | 'raw_file': 'ccbench_1003.tsv',
44 | 'annotation': 'ccbench_1003.jsonl',
45 | 'max_new_tokens': 10,
46 | },
47 | }
48 |
49 | multiple_choices = ['A', 'B', 'C', 'D', 'E']
50 |
51 | def mapping_to_annotation(results, raw_annotation):
52 | outputs = []
53 | for result in results:
54 | index, prediction = result['index'], result['prediction']
55 | row_df = raw_annotation[raw_annotation['index'] == index].squeeze().to_dict()
56 | output = {
57 | "index": index,
58 | "image": row_df['image'],
59 | "question": row_df['question'],
60 | "answer": row_df.get('answer', None),
61 | "options": [y for y in [row_df.get(x, None) for x in 'ABCD'] if isinstance(y, str)],
62 | "prediction": prediction,
63 | "l2-category": row_df['l2-category'] if 'l2-category' in row_df else None
64 | }
65 | outputs.append(output)
66 | return outputs
67 |
68 |
69 | def generate_submission_file(results, raw_annotation):
70 | outputs = []
71 | for result in results:
72 | index, prediction = result['index'], result['prediction']
73 | row_df = raw_annotation[raw_annotation['index'] == index].squeeze().to_dict()
74 | output = {
75 | "index": index,
76 | "question": row_df['question'],
77 | "prediction": prediction,
78 | "A": row_df.get('A', None),
79 | "B": row_df.get('B', None),
80 | "C": row_df.get('C', None),
81 | "D": row_df.get('D', None),
82 | }
83 | outputs.append(output)
84 | return outputs
85 |
86 | def collate_fn(batches, tokenizer):
87 |
88 | questions = [_['question'] for _ in batches]
89 | indices = [_['index'] for _ in batches]
90 |
91 | image_tensor = [_['image_tensor'] for _ in batches]
92 |
93 | input_ids = []
94 | for input_text in questions:
95 | input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist())
96 | input_tokens_max_length = max([len(x) for x in input_ids])
97 | pad_token_id = tokenizer.pad_token_id
98 |
99 | input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left
100 | input_ids = torch.LongTensor(input_ids)
101 | attention_mask = 1 - input_ids.eq(pad_token_id).long()
102 |
103 | image_tensor = torch.cat(image_tensor, dim=0)
104 | return image_tensor, input_ids, attention_mask, indices
105 |
106 |
107 | class VQADataset(torch.utils.data.Dataset):
108 |
109 | def __init__(self, test, prompt, image_processor):
110 |
111 | self.prompt = prompt
112 | self.image_processor = image_processor
113 |
114 | self.data = open(test).readlines()
115 |
116 | def __len__(self):
117 | return len(self.data)
118 |
119 | def __getitem__(self, idx):
120 | data = json.loads(self.data[idx].strip())
121 | index = data['index']
122 | image = data['image']
123 | hint = data['hint'] if data['hint'] else 'N/A'
124 | question = data['question']
125 |
126 | choices = data['choices']
127 | choice_list = []
128 | for i, c in enumerate(choices):
129 | choice_list.append('{}. {}'.format(multiple_choices[i], c))
130 | choice_txt = '\n'.join(choice_list)
131 |
132 | image = Image.open(image).convert('RGB')
133 | max_edge = max(image.size)
134 | image = image.resize((max_edge, max_edge)) # Resize here for best performance
135 | image_tensor = process_images([image], self.image_processor)
136 |
137 | return {
138 | 'index': index,
139 | 'image_tensor': image_tensor,
140 | 'question': self.prompt.format(hint, question, choice_txt),
141 | }
142 |
143 |
144 | class InferenceSampler(torch.utils.data.sampler.Sampler):
145 |
146 | def __init__(self, size):
147 | self._size = int(size)
148 | assert size > 0
149 | self._rank = torch.distributed.get_rank()
150 | self._world_size = torch.distributed.get_world_size()
151 | self._local_indices = self._get_local_indices(size, self._world_size,
152 | self._rank)
153 |
154 | @staticmethod
155 | def _get_local_indices(total_size, world_size, rank):
156 | shard_size = total_size // world_size
157 | left = total_size % world_size
158 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
159 |
160 | begin = sum(shard_sizes[:rank])
161 | end = min(sum(shard_sizes[:rank + 1]), total_size)
162 | return range(begin, end)
163 |
164 | def __iter__(self):
165 | yield from self._local_indices
166 |
167 | def __len__(self):
168 | return len(self._local_indices)
169 |
170 |
171 | if __name__ == '__main__':
172 |
173 | parser = argparse.ArgumentParser()
174 | parser.add_argument('--checkpoint', type=str, default='')
175 | parser.add_argument('--dataset', type=str, default='mmbench_dev_20230712')
176 | parser.add_argument('--batch-size', type=int, default=1)
177 | parser.add_argument('--num-workers', type=int, default=1)
178 | parser.add_argument('--seed', type=int, default=0)
179 | args = parser.parse_args()
180 |
181 |
182 | torch.distributed.init_process_group(
183 | backend='nccl',
184 | world_size=int(os.getenv('WORLD_SIZE', '1')),
185 | rank=int(os.getenv('RANK', '0')),
186 | )
187 |
188 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
189 | os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('LOCAL_RANK', "0")
190 |
191 | model_path = args.checkpoint
192 | model_name = get_model_name_from_path(model_path)
193 |
194 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
195 | tokenizer.padding_side = 'left'
196 | if tokenizer.pad_token_id is None:
197 | tokenizer.pad_token_id = tokenizer.eos_token_id
198 |
199 | prompt = "USER: <|image|>{}\n{}\n{}\nAnswer with the option’s letter from the given choices directly. ASSISTANT: "
200 |
201 | random.seed(args.seed)
202 | dataset = VQADataset(
203 | test=ds_collections[args.dataset]['annotation'],
204 | prompt=prompt,
205 | image_processor=image_processor,
206 | )
207 |
208 | dataloader = torch.utils.data.DataLoader(
209 | dataset=dataset,
210 | sampler=InferenceSampler(len(dataset)),
211 | batch_size=args.batch_size,
212 | num_workers=args.num_workers,
213 | pin_memory=True,
214 | drop_last=False,
215 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
216 | )
217 |
218 | outputs = []
219 | for _, (image_tensor, input_ids, attention_mask, indices) in enumerate(tqdm(dataloader)):
220 | pred = model.generate(
221 | input_ids=input_ids.cuda(),
222 | attention_mask=attention_mask.cuda(),
223 | images=image_tensor.to(dtype=model.dtype).cuda(),
224 | do_sample=False,
225 | num_beams=1,
226 | max_new_tokens=ds_collections[args.dataset]['max_new_tokens'],
227 | min_new_tokens=1,
228 | length_penalty=1,
229 | num_return_sequences=1,
230 | output_hidden_states=True,
231 | use_cache=True,
232 | )
233 | answers = [
234 | tokenizer.decode(_[input_ids.size(1):].cpu(),
235 | skip_special_tokens=True).strip() for _ in pred
236 | ]
237 |
238 | for index, answer in zip(indices, answers):
239 | outputs.append({
240 | 'index': index,
241 | 'prediction': answer,
242 | })
243 |
244 | torch.distributed.barrier()
245 |
246 | world_size = torch.distributed.get_world_size()
247 | merged_outputs = [None for _ in range(world_size)]
248 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
249 |
250 | merged_outputs = [json.loads(_) for _ in merged_outputs]
251 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
252 |
253 | if torch.distributed.get_rank() == 0:
254 | print(f"Evaluating {args.dataset} ...")
255 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
256 | results_file = f'{args.dataset}_{time_prefix}_s{args.seed}.json'
257 | json.dump(merged_outputs, open(results_file, 'w'), ensure_ascii=False)
258 |
259 | mapped_result = mapping_to_annotation(merged_outputs, pd.read_csv(ds_collections[args.dataset]['raw_file'], sep='\t'))
260 |
261 | submission_res = generate_submission_file(merged_outputs, pd.read_csv(ds_collections[args.dataset]['raw_file'], sep='\t'))
262 | res_df = pd.DataFrame(submission_res)
263 | metrics_file = f'{args.dataset}_{time_prefix}_s{args.seed}_submission.xlsx'
264 | res_df.to_excel(metrics_file, index=False)
265 |
266 | torch.distributed.barrier()
267 |
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/evaluate_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 | from typing import Optional
9 |
10 | import torch
11 | from tqdm import tqdm
12 | from PIL import Image
13 |
14 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
15 | from mplug_owl2.conversation import conv_templates, SeparatorStyle
16 | from mplug_owl2.model.builder import load_pretrained_model
17 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
18 |
19 | from vqa import VQA
20 | from vqa_eval import VQAEval
21 |
22 | ds_collections = {
23 | 'vqav2_val': {
24 | 'train': 'data/vqav2/vqav2_train.jsonl',
25 | 'test': 'data/vqav2/vqav2_val.jsonl',
26 | 'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json',
27 | 'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json',
28 | 'metric': 'vqa_score',
29 | 'max_new_tokens': 10,
30 | },
31 | 'vqav2_testdev': {
32 | 'train': 'data/vqav2/vqav2_train.jsonl',
33 | 'test': 'data/vqav2/vqav2_testdev.jsonl',
34 | 'metric': None,
35 | 'max_new_tokens': 10,
36 | },
37 | 'okvqa_val': {
38 | 'train': 'data/okvqa/okvqa_train.jsonl',
39 | 'test': 'data/okvqa/okvqa_val.jsonl',
40 | 'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json',
41 | 'annotation': 'data/okvqa/mscoco_val2014_annotations.json',
42 | 'metric': 'vqa_score',
43 | 'max_new_tokens': 10,
44 | },
45 | 'textvqa_val': {
46 | 'train': 'data/textvqa/textvqa_train.jsonl',
47 | 'test': 'data/textvqa/textvqa_val.jsonl',
48 | 'question': 'data/textvqa/textvqa_val_questions.json',
49 | 'annotation': 'data/textvqa/textvqa_val_annotations.json',
50 | 'metric': 'vqa_score',
51 | 'max_new_tokens': 10,
52 | },
53 | }
54 |
55 | def collate_fn(batches, tokenizer):
56 |
57 | questions = [_['question'] for _ in batches]
58 | question_ids = [_['question_id'] for _ in batches]
59 | annotations = [_['annotation'] for _ in batches]
60 |
61 | image_tensor = [_['image_tensor'] for _ in batches]
62 |
63 | input_ids = []
64 | for input_text in questions:
65 | input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist())
66 | input_tokens_max_length = max([len(x) for x in input_ids])
67 | pad_token_id = tokenizer.pad_token_id
68 |
69 | input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left
70 | input_ids = torch.LongTensor(input_ids)
71 | attention_mask = 1 - input_ids.eq(pad_token_id).long()
72 |
73 | image_tensor = torch.cat(image_tensor, dim=0)
74 | return question_ids, image_tensor, input_ids, attention_mask, annotations
75 |
76 |
77 | class VQADataset(torch.utils.data.Dataset):
78 |
79 | def __init__(self, train, test, prompt, image_processor, few_shot):
80 | self.test = json.load(open(test))
81 | self.prompt = prompt
82 | self.image_processor = image_processor
83 |
84 | self.few_shot = few_shot
85 | if few_shot > 0:
86 | self.train = open(train).readlines()
87 |
88 | def __len__(self):
89 | return len(self.test)
90 |
91 | def __getitem__(self, idx):
92 | data = self.test[idx]
93 | image, question, question_id, annotation = data['image'], data[
94 | 'question'], data['question_id'], data.get('answer', None)
95 | image = Image.open(image).convert('RGB')
96 | max_edge = max(image.size)
97 | image = image.resize((max_edge, max_edge)) # Resize here for best performance
98 | image_tensor = process_images([image], self.image_processor)
99 |
100 | return {
101 | 'image_tensor': image_tensor,
102 | 'question': self.prompt.format(question),
103 | 'question_id': question_id,
104 | 'annotation': annotation
105 | }
106 |
107 |
108 | class InferenceSampler(torch.utils.data.sampler.Sampler):
109 |
110 | def __init__(self, size):
111 | self._size = int(size)
112 | assert size > 0
113 | self._rank = torch.distributed.get_rank()
114 | self._world_size = torch.distributed.get_world_size()
115 | self._local_indices = self._get_local_indices(size, self._world_size,
116 | self._rank)
117 |
118 | @staticmethod
119 | def _get_local_indices(total_size, world_size, rank):
120 | shard_size = total_size // world_size
121 | left = total_size % world_size
122 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
123 |
124 | begin = sum(shard_sizes[:rank])
125 | end = min(sum(shard_sizes[:rank + 1]), total_size)
126 | return range(begin, end)
127 |
128 | def __iter__(self):
129 | yield from self._local_indices
130 |
131 | def __len__(self):
132 | return len(self._local_indices)
133 |
134 |
135 | if __name__ == '__main__':
136 |
137 | parser = argparse.ArgumentParser()
138 | parser.add_argument('--checkpoint', type=str, default='')
139 | parser.add_argument('--dataset', type=str, default='textvqa_val')
140 | parser.add_argument('--batch-size', type=int, default=1)
141 | parser.add_argument('--num-workers', type=int, default=1)
142 | parser.add_argument('--few-shot', type=int, default=0)
143 | parser.add_argument('--seed', type=int, default=0)
144 | args = parser.parse_args()
145 |
146 | torch.distributed.init_process_group(
147 | backend='nccl',
148 | world_size=int(os.getenv('WORLD_SIZE', '1')),
149 | rank=int(os.getenv('RANK', '0')),
150 | )
151 |
152 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
153 |
154 | model_path = args.checkpoint
155 | model_name = get_model_name_from_path(model_path)
156 |
157 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
158 | tokenizer.padding_side = 'left'
159 | if tokenizer.pad_token_id is None:
160 | tokenizer.pad_token_id = tokenizer.eos_token_id
161 |
162 | prompt = 'USER: <|image|>{} Answer the question using a single word or phrase. ASSISTANT: '
163 | answer_processor = EvalAIAnswerProcessor()
164 | random.seed(args.seed)
165 | dataset = VQADataset(
166 | train=ds_collections[args.dataset]['train'],
167 | test=ds_collections[args.dataset]['test'],
168 | prompt=prompt,
169 | image_processor=image_processor,
170 | few_shot=args.few_shot,
171 | )
172 |
173 | dataloader = torch.utils.data.DataLoader(
174 | dataset=dataset,
175 | sampler=InferenceSampler(len(dataset)),
176 | batch_size=args.batch_size,
177 | num_workers=args.num_workers,
178 | pin_memory=True,
179 | drop_last=False,
180 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
181 | )
182 |
183 | outputs = []
184 | for _, (question_ids, image_tensor, input_ids, attention_mask,
185 | annotations) in enumerate(tqdm(dataloader)):
186 | pred = model.generate(
187 | input_ids=input_ids.cuda(),
188 | attention_mask=attention_mask.cuda(),
189 | images=image_tensor.to(dtype=model.dtype).cuda(),
190 | do_sample=False,
191 | num_beams=1,
192 | max_new_tokens=ds_collections[args.dataset]['max_new_tokens'],
193 | min_new_tokens=1,
194 | length_penalty=1,
195 | num_return_sequences=1,
196 | output_hidden_states=True,
197 | use_cache=True,
198 | )
199 | answers = [
200 | tokenizer.decode(_[input_ids.size(1):].cpu(),
201 | skip_special_tokens=True).strip() for _ in pred
202 | ]
203 |
204 | for question_id, answer, annotation in zip(question_ids, answers,
205 | annotations):
206 | if args.dataset in ['vqav2_val', 'okvqa_val', 'textvqa_val']:
207 | outputs.append({
208 | 'question_id': question_id,
209 | 'answer': answer,
210 | })
211 | elif args.dataset == 'vqav2_testdev':
212 | outputs.append({
213 | 'question_id': question_id,
214 | 'answer': answer_processor(answer),
215 | })
216 | else:
217 | raise NotImplementedError
218 |
219 | torch.distributed.barrier()
220 |
221 | world_size = torch.distributed.get_world_size()
222 | merged_outputs = [None for _ in range(world_size)]
223 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
224 |
225 | merged_outputs = [json.loads(_) for _ in merged_outputs]
226 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
227 |
228 | if torch.distributed.get_rank() == 0:
229 | print(f"Evaluating {args.dataset} ...")
230 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
231 | results_file = f'{args.dataset}_{time_prefix}_fs{args.few_shot}_s{args.seed}.json'
232 | json.dump(merged_outputs, open(results_file, 'w', encoding='utf-8'), ensure_ascii=False)
233 |
234 | if ds_collections[args.dataset]['metric'] == 'vqa_score':
235 | vqa = VQA(ds_collections[args.dataset]['annotation'],
236 | ds_collections[args.dataset]['question'])
237 | results = vqa.loadRes(
238 | resFile=results_file,
239 | quesFile=ds_collections[args.dataset]['question'])
240 | vqa_scorer = VQAEval(vqa, results, n=2)
241 | vqa_scorer.evaluate()
242 |
243 | print(vqa_scorer.accuracy)
244 | torch.distributed.barrier()
245 |
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/mmbench_converter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import io
3 | import base64
4 | import json
5 | from PIL import Image
6 |
7 | '''
8 | This script converts the mmbench_dev tsv file to jsonl.
9 | '''
10 |
11 | datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t')
12 |
13 | global_choices = ['A', 'B', 'C', 'D']
14 |
15 | def decode_base64_to_image(base64_string):
16 | image_data = base64.b64decode(base64_string)
17 | image = Image.open(io.BytesIO(image_data))
18 | return image
19 |
20 |
21 | with open('./data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.jsonl', 'w') as f:
22 | for idx in range(len(datas)):
23 | data = datas.iloc[idx]
24 |
25 | index = int(data['index'])
26 | question = data['question']
27 | hint = data['hint'] if not pd.isna(data['hint']) else 'N/A'
28 |
29 | choices = []
30 | for opt in global_choices:
31 | if pd.isna(data[opt]):
32 | continue
33 | choices.append(data[opt])
34 |
35 | answer = global_choices.index(data['answer'])
36 |
37 | image = decode_base64_to_image(data['image'])
38 | image.save("data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index)
39 |
40 | f.write(json.dumps({
41 | "index": index,
42 | "image": "data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index,
43 | "hint": hint,
44 | "question": question,
45 | "choices": choices,
46 | "answer": answer,
47 | }) + "\n")
--------------------------------------------------------------------------------
/mplug_owl2/evaluate/vqa.py:
--------------------------------------------------------------------------------
1 | """Copyright (c) 2022, salesforce.com, inc.
2 |
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | __author__ = 'aagrawal'
9 | __version__ = '0.9'
10 |
11 | # Interface for accessing the VQA dataset.
12 |
13 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
14 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
15 |
16 | # The following functions are defined:
17 | # VQA - VQA class that loads VQA annotation file and prepares data structures.
18 | # getQuesIds - Get question ids that satisfy given filter conditions.
19 | # getImgIds - Get image ids that satisfy given filter conditions.
20 | # loadQA - Load questions and answers with the specified question ids.
21 | # showQA - Display the specified questions and answers.
22 | # loadRes - Load result file and create result object.
23 |
24 | # Help on each function can be accessed by: "help(COCO.function)"
25 |
26 | import copy
27 | import datetime
28 | import json
29 |
30 |
31 | class VQA:
32 |
33 | def __init__(self, annotation_file=None, question_file=None):
34 | """Constructor of VQA helper class for reading and visualizing
35 | questions and answers.
36 |
37 | :param annotation_file (str): location of VQA annotation file
38 | :return:
39 | """
40 | # load dataset
41 | self.dataset = {}
42 | self.questions = {}
43 | self.qa = {}
44 | self.qqa = {}
45 | self.imgToQA = {}
46 | if annotation_file is not None and question_file is not None:
47 | print('loading VQA annotations and questions into memory...')
48 | time_t = datetime.datetime.utcnow()
49 | dataset = json.load(open(annotation_file, 'r'))
50 | questions = json.load(open(question_file, 'r'))
51 | self.dataset = dataset
52 | self.questions = questions
53 | self.createIndex()
54 |
55 | def createIndex(self):
56 | # create index
57 | print('creating index...')
58 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
59 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
60 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
61 | for ann in self.dataset['annotations']:
62 | imgToQA[ann['image_id']] += [ann]
63 | qa[ann['question_id']] = ann
64 | for ques in self.questions['questions']:
65 | qqa[ques['question_id']] = ques
66 | print('index created!')
67 |
68 | # create class members
69 | self.qa = qa
70 | self.qqa = qqa
71 | self.imgToQA = imgToQA
72 |
73 | def info(self):
74 | """Print information about the VQA annotation file.
75 |
76 | :return:
77 | """
78 | for key, value in self.dataset['info'].items():
79 | print('%s: %s' % (key, value))
80 |
81 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
82 |         """Get question ids that satisfy the given filter conditions.
83 |         Filters left at their defaults are skipped.
84 |
85 | :param imgIds (int array) : get question ids for given imgs
86 | quesTypes (str array) : get question ids for given question types
87 | ansTypes (str array) : get question ids for given answer types
88 | :return: ids (int array) : integer array of question ids
89 | """
90 | imgIds = imgIds if type(imgIds) == list else [imgIds]
91 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
92 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
93 |
94 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
95 | anns = self.dataset['annotations']
96 | else:
97 | if not len(imgIds) == 0:
98 | anns = sum(
99 | [
100 | self.imgToQA[imgId]
101 | for imgId in imgIds if imgId in self.imgToQA
102 | ],
103 | [],
104 | )
105 | else:
106 | anns = self.dataset['annotations']
107 | anns = (anns if len(quesTypes) == 0 else
108 | [ann for ann in anns if ann['question_type'] in quesTypes])
109 | anns = (anns if len(ansTypes) == 0 else
110 | [ann for ann in anns if ann['answer_type'] in ansTypes])
111 | ids = [ann['question_id'] for ann in anns]
112 | return ids
113 |
114 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
115 |         """Get image ids that satisfy the given filter conditions.
116 |         Filters left at their defaults are skipped.
117 |
118 | :param quesIds (int array) : get image ids for given question ids
119 | quesTypes (str array) : get image ids for given question types
120 | ansTypes (str array) : get image ids for given answer types
121 | :return: ids (int array) : integer array of image ids
122 | """
123 | quesIds = quesIds if type(quesIds) == list else [quesIds]
124 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
125 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
126 |
127 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
128 | anns = self.dataset['annotations']
129 | else:
130 | if not len(quesIds) == 0:
131 | anns = sum([
132 | self.qa[quesId] for quesId in quesIds if quesId in self.qa
133 | ], [])
134 | else:
135 | anns = self.dataset['annotations']
136 | anns = (anns if len(quesTypes) == 0 else
137 | [ann for ann in anns if ann['question_type'] in quesTypes])
138 | anns = (anns if len(ansTypes) == 0 else
139 | [ann for ann in anns if ann['answer_type'] in ansTypes])
140 | ids = [ann['image_id'] for ann in anns]
141 | return ids
142 |
143 | def loadQA(self, ids=[]):
144 | """Load questions and answers with the specified question ids.
145 |
146 | :param ids (int array) : integer ids specifying question ids
147 | :return: qa (object array) : loaded qa objects
148 | """
149 | if type(ids) == list:
150 | return [self.qa[id] for id in ids]
151 | elif type(ids) == int:
152 | return [self.qa[ids]]
153 |
154 | def showQA(self, anns):
155 | """Display the specified annotations.
156 |
157 | :param anns (array of object): annotations to display
158 | :return: None
159 | """
160 | if len(anns) == 0:
161 | return 0
162 | for ann in anns:
163 | quesId = ann['question_id']
164 | print('Question: %s' % (self.qqa[quesId]['question']))
165 | for ans in ann['answers']:
166 | print('Answer %d: %s' % (ans['answer_id'], ans['answer']))
167 |
168 | def loadRes(self, resFile, quesFile):
169 | """Load result file and return a result object.
170 |
171 | :param resFile (str) : file name of result file
172 | :return: res (obj) : result api object
173 | """
174 | res = VQA()
175 | res.questions = json.load(open(quesFile))
176 | res.dataset['info'] = copy.deepcopy(self.questions['info'])
177 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
178 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
179 | res.dataset['data_subtype'] = copy.deepcopy(
180 | self.questions['data_subtype'])
181 | res.dataset['license'] = copy.deepcopy(self.questions['license'])
182 |
183 | print('Loading and preparing results... ')
184 | time_t = datetime.datetime.utcnow()
185 | anns = json.load(open(resFile))
186 | assert type(anns) == list, 'results is not an array of objects'
187 | annsQuesIds = [ann['question_id'] for ann in anns]
188 | assert set(annsQuesIds) == set(
189 | self.getQuesIds()
190 |         ), 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
191 | for ann in anns:
192 | quesId = ann['question_id']
193 | if res.dataset['task_type'] == 'Multiple Choice':
194 | assert (
195 | ann['answer'] in self.qqa[quesId]['multiple_choices']
196 | ), 'predicted answer is not one of the multiple choices'
197 | qaAnn = self.qa[quesId]
198 | ann['image_id'] = qaAnn['image_id']
199 | ann['question_type'] = qaAnn['question_type']
200 | ann['answer_type'] = qaAnn['answer_type']
201 | print('DONE (t=%0.2fs)' %
202 | ((datetime.datetime.utcnow() - time_t).total_seconds()))
203 |
204 | res.dataset['annotations'] = anns
205 | res.createIndex()
206 | return res
--------------------------------------------------------------------------------
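
A minimal, hypothetical usage sketch of the VQA helper above (file paths are placeholders for standard VQA-style annotation, question, and result JSON files; they are not shipped with this repository):

from mplug_owl2.evaluate.vqa import VQA

# Load annotations and questions, then filter question ids by answer type.
vqa = VQA(annotation_file='annotations.json', question_file='questions.json')
yes_no_ids = vqa.getQuesIds(ansTypes=['yes/no'])
vqa.showQA(vqa.loadQA(yes_no_ids[:3]))

# Wrap model predictions (a JSON list of {"question_id": ..., "answer": ...}) for evaluation.
vqa_res = vqa.loadRes('results.json', 'questions.json')
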
/mplug_owl2/local_serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/local_serve/__init__.py
--------------------------------------------------------------------------------
/mplug_owl2/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg
--------------------------------------------------------------------------------
/mplug_owl2/local_serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/local_serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/mplug_owl2/local_serve/model_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A model worker executes the model.
3 | """
4 | import argparse
5 | import asyncio
6 | import json
7 | import time
8 | import threading
9 | import uuid
10 |
11 | import requests
12 | import torch
13 | from functools import partial
14 |
15 | from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
16 | from mplug_owl2.utils import (build_logger, server_error_msg,
17 | pretty_print_semaphore)
18 | from mplug_owl2.model.builder import load_pretrained_model
19 | from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
20 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
21 | from transformers import TextIteratorStreamer
22 | from threading import Thread
23 |
24 | GB = 1 << 30
25 |
26 | worker_id = str(uuid.uuid4())[:6]
27 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
28 |
29 | class ModelWorker:
30 | def __init__(self, model_path, model_base, model_name, load_8bit, load_4bit, device):
31 | self.worker_id = worker_id
32 | if model_path.endswith("/"):
33 | model_path = model_path[:-1]
34 | if model_name is None:
35 | model_paths = model_path.split("/")
36 | if model_paths[-1].startswith('checkpoint-'):
37 | self.model_name = model_paths[-2] + "_" + model_paths[-1]
38 | else:
39 | self.model_name = model_paths[-1]
40 | else:
41 | self.model_name = model_name
42 |
43 | self.device = device
44 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
45 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
46 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
47 | self.is_multimodal = True
48 |
49 | @torch.inference_mode()
50 | def generate_stream(self, params):
51 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
52 |
53 | prompt = params["prompt"]
54 | ori_prompt = prompt
55 | images = params.get("images", None)
56 | num_image_tokens = 0
57 | if images is not None and len(images) > 0 and self.is_multimodal:
58 | if len(images) > 0:
59 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
60 | raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
61 |
62 | images = [load_image_from_base64(image) for image in images]
63 | images = process_images(images, image_processor, model.config)
64 |
65 | if type(images) is list:
66 | images = [image.to(self.model.device, dtype=torch.float16) for image in images]
67 | else:
68 | images = images.to(self.model.device, dtype=torch.float16)
69 |
70 | replace_token = DEFAULT_IMAGE_TOKEN
71 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
72 |
73 | num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
74 | else:
75 | images = None
76 | image_args = {"images": images}
77 | else:
78 | images = None
79 | image_args = {}
80 |
81 | temperature = float(params.get("temperature", 1.0))
82 | top_p = float(params.get("top_p", 1.0))
83 | max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
84 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
85 | stop_str = params.get("stop", None)
86 | do_sample = True if temperature > 0.001 else False
87 |
88 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
89 | keywords = [stop_str]
90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
92 |
93 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
94 |
95 | if max_new_tokens < 1:
96 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
97 | return
98 |
99 | thread = Thread(target=model.generate, kwargs=dict(
100 | inputs=input_ids,
101 | do_sample=do_sample,
102 | temperature=temperature,
103 | top_p=top_p,
104 | max_new_tokens=max_new_tokens,
105 | streamer=streamer,
106 | stopping_criteria=[stopping_criteria],
107 | use_cache=True,
108 | **image_args
109 | ))
110 | thread.start()
111 |
112 | generated_text = ori_prompt
113 | for new_text in streamer:
114 | generated_text += new_text
115 | if generated_text.endswith(stop_str):
116 | generated_text = generated_text[:-len(stop_str)]
117 |             yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
118 |
119 | def generate_stream_gate(self, params):
120 | try:
121 | for x in self.generate_stream(params):
122 | yield x
123 | except ValueError as e:
124 | print("Caught ValueError:", e)
125 | ret = {
126 | "text": server_error_msg,
127 | "error_code": 1,
128 | }
129 |             yield json.dumps(ret).encode() + b"\0"
130 | except torch.cuda.CudaError as e:
131 | print("Caught torch.cuda.CudaError:", e)
132 | ret = {
133 | "text": server_error_msg,
134 | "error_code": 1,
135 | }
136 |             yield json.dumps(ret).encode() + b"\0"
137 | except Exception as e:
138 | print("Caught Unknown Error", e)
139 | ret = {
140 | "text": server_error_msg,
141 | "error_code": 1,
142 | }
143 |             yield json.dumps(ret).encode() + b"\0"
--------------------------------------------------------------------------------
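
generate_stream_gate yields null-delimited JSON byte chunks, matching the b"\0" framing the serve-side controller splits on. A hedged sketch of driving the worker directly; the checkpoint path, prompt, and stop string are placeholder assumptions:

import json
from mplug_owl2.local_serve.model_worker import ModelWorker

worker = ModelWorker(model_path="/path/to/mplug_owl2_checkpoint", model_base=None,
                     model_name=None, load_8bit=False, load_4bit=False, device="cuda")

params = {"prompt": "USER: Write one sentence about autumn. ASSISTANT:",
          "stop": "</s>", "temperature": 0.7, "max_new_tokens": 64}

for chunk in worker.generate_stream_gate(params):
    data = json.loads(chunk.rstrip(b"\0").decode())  # strip the b"\0" chunk delimiter
    print(data["text"], data["error_code"])
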
/mplug_owl2/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
8 | from icecream import ic
9 |
10 |
11 | def load_image_from_base64(image):
12 | return Image.open(BytesIO(base64.b64decode(image)))
13 |
14 |
15 | def expand2square(pil_img, background_color):
16 | width, height = pil_img.size
17 | if width == height:
18 | return pil_img
19 | elif width > height:
20 | result = Image.new(pil_img.mode, (width, width), background_color)
21 | result.paste(pil_img, (0, (width - height) // 2))
22 | return result
23 | else:
24 | result = Image.new(pil_img.mode, (height, height), background_color)
25 | result.paste(pil_img, ((height - width) // 2, 0))
26 | return result
27 |
28 |
29 | def process_images(images, image_processor, model_cfg=None):
30 | if model_cfg is not None:
31 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
32 | else:
33 | image_aspect_ratio = 'resize'
34 | new_images = []
35 | if image_aspect_ratio == 'pad':
36 | for image in images:
37 | image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
38 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
39 | new_images.append(image)
40 | elif image_aspect_ratio == 'resize':
41 | for image in images:
42 | max_edge = max(image.size)
43 | image = image.resize((max_edge, max_edge))
44 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
45 | new_images.append(image)
46 | else:
47 | return image_processor(images, return_tensors='pt')['pixel_values']
48 | if all(x.shape == new_images[0].shape for x in new_images):
49 | new_images = torch.stack(new_images, dim=0)
50 | return new_images
51 |
52 |
53 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
54 | prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in
55 | prompt.split(DEFAULT_IMAGE_TOKEN)]
56 |
57 | def insert_separator(X, sep):
58 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
59 |
60 | input_ids = []
61 | offset = 0
62 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
63 | offset = 1
64 | input_ids.append(prompt_chunks[0][0])
65 |
66 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
67 | input_ids.extend(x[offset:])
68 |
69 | if return_tensors is not None:
70 | if return_tensors == 'pt':
71 | return torch.tensor(input_ids, dtype=torch.long)
72 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
73 | return input_ids
74 |
75 |
76 | def get_model_name_from_path(model_path):
77 | model_path = model_path.strip("/")
78 | model_paths = model_path.split("/")
79 | if model_paths[-1].startswith('checkpoint-'):
80 | return model_paths[-2] + "_" + model_paths[-1]
81 | else:
82 | return model_paths[-1]
83 |
84 |
85 | class KeywordsStoppingCriteria(StoppingCriteria):
86 | def __init__(self, keywords, tokenizer, input_ids):
87 | self.keywords = keywords
88 | self.keyword_ids = []
89 | self.max_keyword_len = 0
90 | for keyword in keywords:
91 | cur_keyword_ids = tokenizer(keyword).input_ids
92 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
93 | cur_keyword_ids = cur_keyword_ids[1:]
94 | if len(cur_keyword_ids) > self.max_keyword_len:
95 | self.max_keyword_len = len(cur_keyword_ids)
96 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
97 | self.tokenizer = tokenizer
98 | self.start_len = input_ids.shape[1]
99 |
100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
101 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
102 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
103 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
104 | for keyword_id in self.keyword_ids:
105 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
106 | return True
107 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
108 | for keyword in self.keywords:
109 | if keyword in outputs:
110 | return True
111 | return False
112 |
--------------------------------------------------------------------------------
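
A short sketch of how the helpers above combine into model inputs; tokenizer and image_processor are assumed to be the objects returned by load_pretrained_model (see model/builder.py below), and example.jpg is a placeholder:

from PIL import Image
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria

prompt = "USER: " + DEFAULT_IMAGE_TOKEN + "What is shown in this picture? ASSISTANT:"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors='pt').unsqueeze(0)

image = Image.open("example.jpg").convert("RGB")
image_tensor = process_images([image], image_processor)  # model_cfg=None -> square-resize branch

stopping = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
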
/mplug_owl2/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM, MPLUGOwl2QWenForCausalLM
2 | from .configuration_mplug_owl2 import MPLUGOwl2Config, MPLUGOwl2QwenConfig
3 |
--------------------------------------------------------------------------------
/mplug_owl2/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 |
20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21 | from transformers.models.clip.image_processing_clip import CLIPImageProcessor
22 | import torch
23 | from extensions.sd_smartprocess.mplug_owl2.model import *
24 | from icecream import ic
25 |
26 | from extensions.sd_smartprocess.mplug_owl2 import MPLUGOwl2LlamaForCausalLM
27 | from extensions.sd_smartprocess.mplug_owl2.model import MPLUGOwl2QWenForCausalLM
28 |
29 |
30 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto",
31 | device="cuda", **kwargs):
32 | kwargs = {"device_map": device_map, "ignore_mismatched_sizes": False, **kwargs}
33 |
34 | if device != "cuda":
35 | kwargs['device_map'] = {"": device}
36 |
37 | if load_8bit:
38 | kwargs['load_in_8bit'] = True
39 | elif load_4bit:
40 | kwargs['load_in_4bit'] = True
41 | kwargs['quantization_config'] = BitsAndBytesConfig(
42 | load_in_4bit=True,
43 | bnb_4bit_compute_dtype=torch.float16,
44 | bnb_4bit_use_double_quant=True,
45 | bnb_4bit_quant_type='nf4'
46 | )
47 | else:
48 | kwargs['torch_dtype'] = torch.float16
49 | if 'mplug_owl2' in model_name.lower():
50 |         # Load mPLUG-Owl2 model
51 | if 'lora' in model_name.lower() and model_base is None:
52 | warnings.warn(
53 | 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
54 | if 'lora' in model_name.lower() and model_base is not None:
55 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
56 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
57 | print('Loading mPLUG-Owl2 from base model...')
58 | if 'mplug_owl2_1' in model_name.lower():
59 | model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
60 | config=lora_cfg_pretrained, **kwargs)
61 | else:
62 | model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
63 | config=lora_cfg_pretrained, **kwargs)
64 |             token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
65 |             if model.lm_head.weight.shape[0] != token_num:
66 |                 model.lm_head.weight = torch.nn.Parameter(
67 |                     torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
68 |                 model.model.embed_tokens.weight = torch.nn.Parameter(
69 |                     torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
70 |
71 | print('Loading additional mPLUG-Owl2 weights...')
72 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
73 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'),
74 | map_location='cpu')
75 | else:
76 | # this is probably from HF Hub
77 | from huggingface_hub import hf_hub_download
78 | def load_from_hf(repo_id, filename, subfolder=None):
79 | cache_file = hf_hub_download(
80 | repo_id=repo_id,
81 | filename=filename,
82 | subfolder=subfolder)
83 | return torch.load(cache_file, map_location='cpu')
84 |
85 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
86 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
87 | non_lora_trainables.items()}
88 | if any(k.startswith('model.model.') for k in non_lora_trainables):
89 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
90 | non_lora_trainables.items()}
91 | model.load_state_dict(non_lora_trainables, strict=False)
92 |
93 | from peft import PeftModel
94 | print('Loading LoRA weights...')
95 | model = PeftModel.from_pretrained(model, model_path)
96 | print('Merging LoRA weights...')
97 | model = model.merge_and_unload()
98 | print('Model is loaded...')
99 | elif model_base is not None:
100 | # this may be mm projector only
101 | print('Loading mPLUG-Owl2 from base model...')
102 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
103 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
104 | if 'mplug_owl2_1' in model_name.lower():
105 | model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
106 | config=cfg_pretrained, **kwargs)
107 | else:
108 | model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
109 | config=cfg_pretrained, **kwargs)
110 | else:
111 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
112 | if 'mplug_owl2_1' in model_name.lower():
113 | model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
114 | else:
115 | model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
116 | else:
117 | # Load language model
118 | if model_base is not None:
119 | # PEFT model
120 | from peft import PeftModel
121 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True)
122 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
123 | print(f"Loading LoRA weights from {model_path}")
124 | model = PeftModel.from_pretrained(model, model_path)
125 | print(f"Merging weights")
126 | model = model.merge_and_unload()
127 | print('Convert to FP16...')
128 | model.to(torch.float16)
129 | else:
130 | use_fast = False
131 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
132 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
133 |
134 | vision_tower = model.get_model().vision_model
135 | # vision_tower.to(device=device, dtype=torch.float16)
136 | image_processor = CLIPImageProcessor.from_pretrained(model_path)
137 |
138 | if hasattr(model.config, "max_sequence_length"):
139 | context_len = model.config.max_sequence_length
140 | else:
141 | context_len = 2048
142 |
143 | return tokenizer, model, image_processor, context_len
144 |
--------------------------------------------------------------------------------
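
A hedged sketch of loading a full (non-LoRA) checkpoint with the builder above; the path is a placeholder for a local directory or Hugging Face model id whose name contains "mplug_owl2":

from mplug_owl2.mm_utils import get_model_name_from_path
from mplug_owl2.model.builder import load_pretrained_model

model_path = "/path/to/mplug_owl2_checkpoint"  # placeholder
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=model_name,
    load_8bit=False, load_4bit=True, device="cuda")
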
/mplug_owl2/model/configuration_qwen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from transformers import PretrainedConfig
7 |
8 |
9 | class QWenConfig(PretrainedConfig):
10 | model_type = "qwen"
11 | keys_to_ignore_at_inference = ["past_key_values"]
12 |
13 | def __init__(
14 | self,
15 | multiway=False,
16 | vocab_size=151936,
17 | hidden_size=4096,
18 | num_hidden_layers=32,
19 | num_attention_heads=32,
20 | emb_dropout_prob=0.0,
21 | attn_dropout_prob=0.0,
22 | layer_norm_epsilon=1e-6,
23 | initializer_range=0.02,
24 | max_position_embeddings=8192,
25 | scale_attn_weights=True,
26 | use_cache=True,
27 | bf16=False,
28 | fp16=False,
29 | fp32=False,
30 | kv_channels=128,
31 | rotary_pct=1.0,
32 | rotary_emb_base=10000,
33 | use_dynamic_ntk=True,
34 | use_logn_attn=True,
35 | use_flash_attn="auto",
36 | intermediate_size=22016,
37 | no_bias=True,
38 | tie_word_embeddings=False,
39 | use_cache_quantization=False,
40 | use_cache_kernel=False,
41 | softmax_in_fp32=False,
42 | **kwargs,
43 | ):
44 | self.multiway = multiway
45 | self.vocab_size = vocab_size
46 | self.hidden_size = hidden_size
47 | self.intermediate_size = intermediate_size
48 | self.num_hidden_layers = num_hidden_layers
49 | self.num_attention_heads = num_attention_heads
50 | self.emb_dropout_prob = emb_dropout_prob
51 | self.attn_dropout_prob = attn_dropout_prob
52 | self.layer_norm_epsilon = layer_norm_epsilon
53 | self.initializer_range = initializer_range
54 | self.scale_attn_weights = scale_attn_weights
55 | self.use_cache = use_cache
56 | self.max_position_embeddings = max_position_embeddings
57 | self.bf16 = bf16
58 | self.fp16 = fp16
59 | self.fp32 = fp32
60 | self.kv_channels = kv_channels
61 | self.rotary_pct = rotary_pct
62 | self.rotary_emb_base = rotary_emb_base
63 | self.use_dynamic_ntk = use_dynamic_ntk
64 | self.use_logn_attn = use_logn_attn
65 | self.use_flash_attn = use_flash_attn
66 | self.no_bias = no_bias
67 | self.use_cache_quantization = use_cache_quantization
68 | self.use_cache_kernel = use_cache_kernel
69 | self.softmax_in_fp32 = softmax_in_fp32
70 | super().__init__(
71 | tie_word_embeddings=tie_word_embeddings,
72 | **kwargs
73 | )
--------------------------------------------------------------------------------
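
The config mirrors Qwen's defaults and adds the multiway flag consumed by the mPLUG-Owl2 Qwen variant. A small illustrative sketch:

from mplug_owl2.model.configuration_qwen import QWenConfig

cfg = QWenConfig(multiway=True, use_flash_attn=False)
print(cfg.multiway, cfg.hidden_size, cfg.max_position_embeddings)  # True 4096 8192
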
/mplug_owl2/model/modeling_attn_mask_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import List, Optional, Tuple, Union
15 |
16 | import torch
17 |
18 |
19 | class AttentionMaskConverter:
20 | """
21 | A utility attention mask class that allows one to:
22 | - Create a causal 4d mask
23 |         - Create a causal 4d mask with a sliding window
24 | - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
25 | key_value_length) that can be multiplied with attention scores
26 |
27 | Parameters:
28 | is_causal (`bool`):
29 | Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
30 |
31 | sliding_window (`int`, *optional*):
32 | Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
33 | """
34 |
35 | def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
36 | self.is_causal = is_causal
37 | self.sliding_window = sliding_window
38 |
39 | if self.sliding_window is not None and self.sliding_window <= 0:
40 | raise ValueError(
41 | f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
42 | )
43 |
44 | def to_causal_4d(
45 | self,
46 | batch_size: int,
47 | query_length: int,
48 | key_value_length: int,
49 | dtype: torch.dtype = torch.float32,
50 | device: Union[torch.device, "str"] = "cpu",
51 | ) -> torch.Tensor:
52 | """
53 | Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
54 | bias to upper right hand triangular matrix (causal mask).
55 | """
56 | if not self.is_causal:
57 | raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
58 |
59 | # If shape is not cached, create a new causal mask and cache it
60 | input_shape = (batch_size, query_length)
61 | past_key_values_length = key_value_length - query_length
62 |
63 | # create causal mask
64 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
65 | causal_4d_mask = None
66 | if input_shape[-1] > 1 or self.sliding_window is not None:
67 | causal_4d_mask = self._make_causal_mask(
68 | input_shape,
69 | dtype,
70 | device=device,
71 | past_key_values_length=past_key_values_length,
72 | sliding_window=self.sliding_window,
73 | )
74 |
75 | return causal_4d_mask
76 |
77 | def to_4d(
78 | self,
79 | attention_mask_2d: torch.Tensor,
80 | query_length: int,
81 | key_value_length: Optional[int] = None,
82 | dtype: torch.dtype = torch.float32,
83 | ) -> torch.Tensor:
84 | """
85 | Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
86 | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
87 | causal, a causal mask will be added.
88 | """
89 | input_shape = (attention_mask_2d.shape[0], query_length)
90 |
91 | # create causal mask
92 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
93 | causal_4d_mask = None
94 | if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
95 | if key_value_length is None:
96 | raise ValueError(
97 | "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
98 | )
99 |
100 | past_key_values_length = key_value_length - query_length
101 | causal_4d_mask = self._make_causal_mask(
102 | input_shape,
103 | dtype,
104 | device=attention_mask_2d.device,
105 | past_key_values_length=past_key_values_length,
106 | sliding_window=self.sliding_window,
107 | )
108 | elif self.sliding_window is not None:
109 | raise NotImplementedError("Sliding window is currently only implemented for causal masking")
110 |
111 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
112 | expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
113 | attention_mask_2d.device
114 | )
115 | expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
116 |
117 | return expanded_4d_mask
118 |
119 | @staticmethod
120 | def _make_causal_mask(
121 | input_ids_shape: torch.Size,
122 | dtype: torch.dtype,
123 | device: torch.device,
124 | past_key_values_length: int = 0,
125 | sliding_window: Optional[int] = None,
126 | ):
127 | """
128 |         Make causal mask used for uni-directional (causal) self-attention.
129 | """
130 | bsz, tgt_len = input_ids_shape
131 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
132 | mask_cond = torch.arange(mask.size(-1), device=device)
133 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
134 |
135 | mask = mask.to(dtype)
136 |
137 | if past_key_values_length > 0:
138 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
139 |
140 | # add lower triangular sliding window mask if necessary
141 | if sliding_window is not None:
142 | diagonal = past_key_values_length - sliding_window + 1
143 |
144 | context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
145 | mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
146 |
147 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
148 |
149 | @staticmethod
150 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
151 | """
152 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
153 | """
154 | bsz, src_len = mask.size()
155 | tgt_len = tgt_len if tgt_len is not None else src_len
156 |
157 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
158 |
159 | inverted_mask = 1.0 - expanded_mask
160 |
161 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
162 |
163 |
164 | def _prepare_4d_causal_attention_mask(
165 | attention_mask: Optional[torch.Tensor],
166 | input_shape: Union[torch.Size, Tuple, List],
167 | inputs_embeds: torch.Tensor,
168 | past_key_values_length: int,
169 | sliding_window: Optional[int] = None,
170 | ):
171 | """
172 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
173 | `(batch_size, key_value_length)`
174 |
175 | Args:
176 | attention_mask (`torch.Tensor` or `None`):
177 | A 2D attention mask of shape `(batch_size, key_value_length)`
178 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
179 | The input shape should be a tuple that defines `(batch_size, query_length)`.
180 | inputs_embeds (`torch.Tensor`):
181 | The embedded inputs as a torch Tensor.
182 | past_key_values_length (`int`):
183 | The length of the key value cache.
184 | sliding_window (`int`, *optional*):
185 | If the model uses windowed attention, a sliding window should be passed.
186 | """
187 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
188 |
189 | key_value_length = input_shape[-1] + past_key_values_length
190 |
191 | # 4d mask is passed through the layers
192 | if attention_mask is not None:
193 | attention_mask = attn_mask_converter.to_4d(
194 | attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
195 | )
196 | else:
197 | attention_mask = attn_mask_converter.to_causal_4d(
198 | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
199 | )
200 |
201 | return attention_mask
202 |
203 |
204 | def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
205 | """
206 | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
207 | `(batch_size, key_value_length)`
208 |
209 | Args:
210 | mask (`torch.Tensor` or `None`):
211 | A 2D attention mask of shape `(batch_size, key_value_length)`
212 | dtype (`torch.dtype`):
213 | The torch dtype the created mask shall have.
214 | tgt_len (`int`):
215 | The target length or query length the created mask shall have.
216 | """
217 | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
218 |
219 |
220 | def _create_4d_causal_attention_mask(
221 | input_shape: Union[torch.Size, Tuple, List],
222 | dtype: torch.dtype,
223 | device: torch.device,
224 | past_key_values_length: int = 0,
225 | sliding_window: Optional[int] = None,
226 | ):
227 | """
228 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
229 |
230 | Args:
231 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
232 | The input shape should be a tuple that defines `(batch_size, query_length)`.
233 | dtype (`torch.dtype`):
234 | The torch dtype the created mask shall have.
235 |         device (`torch.device`):
236 | The torch device the created mask shall have.
237 | sliding_window (`int`, *optional*):
238 | If the model uses windowed attention, a sliding window should be passed.
239 | """
240 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
241 |
242 | key_value_length = past_key_values_length + input_shape[-1]
243 | attention_mask = attn_mask_converter.to_causal_4d(
244 | input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
245 | )
246 |
247 | return attention_mask
--------------------------------------------------------------------------------
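
A quick sketch of how the helpers above expand a 2D padding mask into the 4D additive mask consumed by the attention layers (shapes and values are illustrative):

import torch
from mplug_owl2.model.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, query_length, past_len, hidden = 2, 5, 0, 8
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 0, 0, 0]])  # [bsz, seq_len]; 0 marks padding
inputs_embeds = torch.zeros(batch_size, query_length, hidden)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, query_length), inputs_embeds, past_len)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5]); masked positions hold a large negative bias
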
/mplug_owl2/model/multiway.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.utils.checkpoint
4 | from torch import nn
5 |
6 |
7 | class MultiwayNetwork(nn.Module):
8 |
9 | def __init__(self, module_provider, num_multiway=2, out_features=None):
10 | super(MultiwayNetwork, self).__init__()
11 |
12 | self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
13 |         self.out_features = out_features
14 | def forward(self, hidden_states, multiway_indices):
15 |
16 | if len(self.multiway) == 1:
17 | return self.multiway[0](hidden_states)
18 | if self.out_features:
19 | output_hidden_states = torch.empty(
20 | hidden_states.size(0), hidden_states.size(1), self.out_features,
21 | dtype=hidden_states.dtype
22 | ).to(hidden_states.device)
23 | else:
24 | output_hidden_states = torch.empty_like(hidden_states)
25 | for idx, subway in enumerate(self.multiway):
26 | local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
27 | hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
28 | if hidden.numel():
29 | output = subway(hidden)
30 | if isinstance(output, tuple):
31 | output = output[0]
32 | output = output.squeeze(1)
33 | output_hidden_states[local_indices] = output
34 |
35 | return output_hidden_states.contiguous()
--------------------------------------------------------------------------------
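
MultiwayNetwork routes each position to one of several parallel sub-modules according to an index tensor (for example, text vs. visual tokens). A minimal sketch with two linear "ways":

import torch
from torch import nn
from mplug_owl2.model.multiway import MultiwayNetwork

hidden = 16
net = MultiwayNetwork(module_provider=lambda: nn.Linear(hidden, hidden), num_multiway=2)

hidden_states = torch.randn(2, 7, hidden)       # [batch, seq, hidden]
multiway_indices = torch.randint(0, 2, (2, 7))  # 0 = text token, 1 = visual token
out = net(hidden_states, multiway_indices)
print(out.shape)  # torch.Size([2, 7, 16])
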
/mplug_owl2/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type:
7 | assert cfg.model_type == 'mplug_owl2'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "mplug_owl2")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
--------------------------------------------------------------------------------
/mplug_owl2/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/serve/__init__.py
--------------------------------------------------------------------------------
/mplug_owl2/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
5 | from mplug_owl2.conversation import conv_templates, SeparatorStyle
6 | from mplug_owl2.model.builder import load_pretrained_model
7 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
8 |
9 | from PIL import Image
10 |
11 | import requests
12 | from PIL import Image
13 | from io import BytesIO
14 | from transformers import TextStreamer
15 |
16 |
17 | def disable_torch_init():
18 | """
19 | Disable the redundant torch default initialization to accelerate model creation.
20 | """
21 | import torch
22 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
23 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
24 |
25 |
26 | def load_image(image_file):
27 | if image_file.startswith('http://') or image_file.startswith('https://'):
28 | response = requests.get(image_file)
29 | image = Image.open(BytesIO(response.content)).convert('RGB')
30 | else:
31 | image = Image.open(image_file).convert('RGB')
32 | return image
33 |
34 |
35 | def main(args):
36 | # Model
37 | disable_torch_init()
38 |
39 | model_name = get_model_name_from_path(args.model_path)
40 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
41 |
42 | conv_mode = "mplug_owl2"
43 |
44 | if args.conv_mode is not None and conv_mode != args.conv_mode:
45 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
46 | else:
47 | args.conv_mode = conv_mode
48 |
49 | conv = conv_templates[args.conv_mode].copy()
50 | roles = conv.roles
51 |
52 | image = load_image(args.image_file)
53 | # Similar operation in model_worker.py
54 | image_tensor = process_images([image], image_processor, args)
55 | if type(image_tensor) is list:
56 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
57 | else:
58 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
59 |
60 | while True:
61 | try:
62 | inp = input(f"{roles[0]}: ")
63 | except EOFError:
64 | inp = ""
65 | if not inp:
66 | print("exit...")
67 | break
68 |
69 | print(f"{roles[1]}: ", end="")
70 |
71 | if image is not None:
72 | # first message
73 | inp = DEFAULT_IMAGE_TOKEN + inp
74 | conv.append_message(conv.roles[0], inp)
75 | image = None
76 | else:
77 | # later messages
78 | conv.append_message(conv.roles[0], inp)
79 | conv.append_message(conv.roles[1], None)
80 | prompt = conv.get_prompt()
81 |
82 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
83 | stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
84 | keywords = [stop_str]
85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
87 |
88 | with torch.inference_mode():
89 | output_ids = model.generate(
90 | input_ids,
91 | images=image_tensor,
92 | do_sample=True,
93 | temperature=args.temperature,
94 | max_new_tokens=args.max_new_tokens,
95 | streamer=streamer,
96 | use_cache=True,
97 | stopping_criteria=[stopping_criteria])
98 |
99 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
100 | conv.messages[-1][-1] = outputs
101 |
102 | if args.debug:
103 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
104 |
105 |
106 | if __name__ == "__main__":
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
109 | parser.add_argument("--model-base", type=str, default=None)
110 | parser.add_argument("--image-file", type=str, required=True)
111 | parser.add_argument("--device", type=str, default="cuda")
112 | parser.add_argument("--conv-mode", type=str, default=None)
113 | parser.add_argument("--temperature", type=float, default=0.2)
114 | parser.add_argument("--max-new-tokens", type=int, default=512)
115 | parser.add_argument("--load-8bit", action="store_true")
116 | parser.add_argument("--load-4bit", action="store_true")
117 | parser.add_argument("--debug", action="store_true")
118 | parser.add_argument("--image-aspect-ratio", type=str, default='pad')
119 | args = parser.parse_args()
120 | main(args)
--------------------------------------------------------------------------------
/mplug_owl2/serve/controller.py:
--------------------------------------------------------------------------------
1 | """
2 | A controller manages distributed workers.
3 | It sends worker addresses to clients.
4 | """
5 | import argparse
6 | import asyncio
7 | import dataclasses
8 | from enum import Enum, auto
9 | import json
10 | import logging
11 | import time
12 | from typing import List, Union
13 | import threading
14 |
15 | from fastapi import FastAPI, Request
16 | from fastapi.responses import StreamingResponse
17 | import numpy as np
18 | import requests
19 | import uvicorn
20 |
21 | from mplug_owl2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22 | from mplug_owl2.utils import build_logger, server_error_msg
23 |
24 |
25 | logger = build_logger("controller", "controller.log")
26 |
27 |
28 | class DispatchMethod(Enum):
29 | LOTTERY = auto()
30 | SHORTEST_QUEUE = auto()
31 |
32 | @classmethod
33 | def from_str(cls, name):
34 | if name == "lottery":
35 | return cls.LOTTERY
36 | elif name == "shortest_queue":
37 | return cls.SHORTEST_QUEUE
38 | else:
39 |             raise ValueError(f"Invalid dispatch method: {name}")
40 |
41 |
42 | @dataclasses.dataclass
43 | class WorkerInfo:
44 | model_names: List[str]
45 | speed: int
46 | queue_length: int
47 | check_heart_beat: bool
48 |     last_heart_beat: float
49 |
50 |
51 | def heart_beat_controller(controller):
52 | while True:
53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54 |         controller.remove_stale_workers_by_expiration()
55 |
56 |
57 | class Controller:
58 | def __init__(self, dispatch_method: str):
59 | # Dict[str -> WorkerInfo]
60 | self.worker_info = {}
61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62 |
63 | self.heart_beat_thread = threading.Thread(
64 | target=heart_beat_controller, args=(self,))
65 | self.heart_beat_thread.start()
66 |
67 | logger.info("Init controller")
68 |
69 | def register_worker(self, worker_name: str, check_heart_beat: bool,
70 | worker_status: dict):
71 | if worker_name not in self.worker_info:
72 | logger.info(f"Register a new worker: {worker_name}")
73 | else:
74 | logger.info(f"Register an existing worker: {worker_name}")
75 |
76 | if not worker_status:
77 | worker_status = self.get_worker_status(worker_name)
78 | if not worker_status:
79 | return False
80 |
81 | self.worker_info[worker_name] = WorkerInfo(
82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83 | check_heart_beat, time.time())
84 |
85 | logger.info(f"Register done: {worker_name}, {worker_status}")
86 | return True
87 |
88 | def get_worker_status(self, worker_name: str):
89 | try:
90 | r = requests.post(worker_name + "/worker_get_status", timeout=5)
91 | except requests.exceptions.RequestException as e:
92 | logger.error(f"Get status fails: {worker_name}, {e}")
93 | return None
94 |
95 | if r.status_code != 200:
96 | logger.error(f"Get status fails: {worker_name}, {r}")
97 | return None
98 |
99 | return r.json()
100 |
101 | def remove_worker(self, worker_name: str):
102 | del self.worker_info[worker_name]
103 |
104 | def refresh_all_workers(self):
105 | old_info = dict(self.worker_info)
106 | self.worker_info = {}
107 |
108 | for w_name, w_info in old_info.items():
109 | if not self.register_worker(w_name, w_info.check_heart_beat, None):
110 | logger.info(f"Remove stale worker: {w_name}")
111 |
112 | def list_models(self):
113 | model_names = set()
114 |
115 | for w_name, w_info in self.worker_info.items():
116 | model_names.update(w_info.model_names)
117 |
118 | return list(model_names)
119 |
120 | def get_worker_address(self, model_name: str):
121 | if self.dispatch_method == DispatchMethod.LOTTERY:
122 | worker_names = []
123 | worker_speeds = []
124 | for w_name, w_info in self.worker_info.items():
125 | if model_name in w_info.model_names:
126 | worker_names.append(w_name)
127 | worker_speeds.append(w_info.speed)
128 | worker_speeds = np.array(worker_speeds, dtype=np.float32)
129 | norm = np.sum(worker_speeds)
130 | if norm < 1e-4:
131 | return ""
132 | worker_speeds = worker_speeds / norm
133 | if True: # Directly return address
134 | pt = np.random.choice(np.arange(len(worker_names)),
135 | p=worker_speeds)
136 | worker_name = worker_names[pt]
137 | return worker_name
138 |
139 | # Check status before returning
140 | while True:
141 | pt = np.random.choice(np.arange(len(worker_names)),
142 | p=worker_speeds)
143 | worker_name = worker_names[pt]
144 |
145 | if self.get_worker_status(worker_name):
146 | break
147 | else:
148 | self.remove_worker(worker_name)
149 | worker_speeds[pt] = 0
150 | norm = np.sum(worker_speeds)
151 | if norm < 1e-4:
152 | return ""
153 | worker_speeds = worker_speeds / norm
154 | continue
155 | return worker_name
156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157 | worker_names = []
158 | worker_qlen = []
159 | for w_name, w_info in self.worker_info.items():
160 | if model_name in w_info.model_names:
161 | worker_names.append(w_name)
162 | worker_qlen.append(w_info.queue_length / w_info.speed)
163 | if len(worker_names) == 0:
164 | return ""
165 | min_index = np.argmin(worker_qlen)
166 | w_name = worker_names[min_index]
167 | self.worker_info[w_name].queue_length += 1
168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169 | return w_name
170 | else:
171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172 |
173 | def receive_heart_beat(self, worker_name: str, queue_length: int):
174 | if worker_name not in self.worker_info:
175 | logger.info(f"Receive unknown heart beat. {worker_name}")
176 | return False
177 |
178 | self.worker_info[worker_name].queue_length = queue_length
179 | self.worker_info[worker_name].last_heart_beat = time.time()
180 | logger.info(f"Receive heart beat. {worker_name}")
181 | return True
182 |
183 |     def remove_stale_workers_by_expiration(self):
184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185 | to_delete = []
186 | for worker_name, w_info in self.worker_info.items():
187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188 | to_delete.append(worker_name)
189 |
190 | for worker_name in to_delete:
191 | self.remove_worker(worker_name)
192 |
193 | def worker_api_generate_stream(self, params):
194 | worker_addr = self.get_worker_address(params["model"])
195 | if not worker_addr:
196 | logger.info(f"no worker: {params['model']}")
197 | ret = {
198 | "text": server_error_msg,
199 | "error_code": 2,
200 | }
201 | yield json.dumps(ret).encode() + b"\0"
202 |
203 | try:
204 | response = requests.post(worker_addr + "/worker_generate_stream",
205 | json=params, stream=True, timeout=5)
206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207 | if chunk:
208 | yield chunk + b"\0"
209 | except requests.exceptions.RequestException as e:
210 | logger.info(f"worker timeout: {worker_addr}")
211 | ret = {
212 | "text": server_error_msg,
213 | "error_code": 3,
214 | }
215 | yield json.dumps(ret).encode() + b"\0"
216 |
217 |
218 | # Let the controller act as a worker to achieve hierarchical
219 | # management. This can be used to connect isolated sub networks.
220 | def worker_api_get_status(self):
221 | model_names = set()
222 | speed = 0
223 | queue_length = 0
224 |
225 | for w_name in self.worker_info:
226 | worker_status = self.get_worker_status(w_name)
227 | if worker_status is not None:
228 | model_names.update(worker_status["model_names"])
229 | speed += worker_status["speed"]
230 | queue_length += worker_status["queue_length"]
231 |
232 | return {
233 | "model_names": list(model_names),
234 | "speed": speed,
235 | "queue_length": queue_length,
236 | }
237 |
238 |
239 | app = FastAPI()
240 |
241 |
242 | @app.post("/register_worker")
243 | async def register_worker(request: Request):
244 | data = await request.json()
245 | controller.register_worker(
246 | data["worker_name"], data["check_heart_beat"],
247 | data.get("worker_status", None))
248 |
249 |
250 | @app.post("/refresh_all_workers")
251 | async def refresh_all_workers():
252 | models = controller.refresh_all_workers()
253 |
254 |
255 | @app.post("/list_models")
256 | async def list_models():
257 | models = controller.list_models()
258 | return {"models": models}
259 |
260 |
261 | @app.post("/get_worker_address")
262 | async def get_worker_address(request: Request):
263 | data = await request.json()
264 | addr = controller.get_worker_address(data["model"])
265 | return {"address": addr}
266 |
267 |
268 | @app.post("/receive_heart_beat")
269 | async def receive_heart_beat(request: Request):
270 | data = await request.json()
271 | exist = controller.receive_heart_beat(
272 | data["worker_name"], data["queue_length"])
273 | return {"exist": exist}
274 |
275 |
276 | @app.post("/worker_generate_stream")
277 | async def worker_api_generate_stream(request: Request):
278 | params = await request.json()
279 | generator = controller.worker_api_generate_stream(params)
280 | return StreamingResponse(generator)
281 |
282 |
283 | @app.post("/worker_get_status")
284 | async def worker_api_get_status(request: Request):
285 | return controller.worker_api_get_status()
286 |
287 |
288 | if __name__ == "__main__":
289 | parser = argparse.ArgumentParser()
290 | parser.add_argument("--host", type=str, default="localhost")
291 | parser.add_argument("--port", type=int, default=21001)
292 | parser.add_argument("--dispatch-method", type=str, choices=[
293 | "lottery", "shortest_queue"], default="shortest_queue")
294 | args = parser.parse_args()
295 | logger.info(f"args: {args}")
296 |
297 | controller = Controller(args.dispatch_method)
298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
--------------------------------------------------------------------------------
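
A sketch of talking to a running controller over its HTTP API; it assumes the controller is up on the default host/port above and at least one worker has registered:

import requests

controller_url = "http://localhost:21001"

# Ask the controller which models its registered workers serve.
models = requests.post(controller_url + "/list_models").json()["models"]

# Resolve a worker address for the first one via the dispatch method.
if models:
    addr = requests.post(controller_url + "/get_worker_address",
                         json={"model": models[0]}).json()["address"]
    print(models[0], "->", addr)
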
/mplug_owl2/serve/examples/Rebecca_(1939_poster)_Small.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/serve/examples/Rebecca_(1939_poster)_Small.jpeg
--------------------------------------------------------------------------------
/mplug_owl2/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d8ahazard/sd_smartprocess/61b0e747f11ac518b97571bdb482ea0d55cb4862/mplug_owl2/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/mplug_owl2/serve/register_workers.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m mplug_owl2.serve.register_workers --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
--------------------------------------------------------------------------------
/mplug_owl2/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | modality_indicators: torch.Tensor,
20 | attention_mask: Optional[torch.Tensor] = None,
21 | position_ids: Optional[torch.Tensor] = None,
22 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
23 | output_attentions: bool = False,
24 | use_cache: bool = False,
25 | padding_mask: bool = None,
26 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
27 | if output_attentions:
28 | warnings.warn(
29 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
30 | )
31 |
32 | bsz, q_len, _ = hidden_states.size()
33 |
34 | query_states = (
35 | self.q_proj(hidden_states)
36 | .view(bsz, q_len, self.num_heads, self.head_dim)
37 | .transpose(1, 2)
38 | )
39 | key_states = (
40 | self.k_proj(hidden_states, modality_indicators)
41 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
42 | .transpose(1, 2)
43 | )
44 | value_states = (
45 | self.v_proj(hidden_states, modality_indicators)
46 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
47 | .transpose(1, 2)
48 | ) # shape: (b, num_heads, s, head_dim)
49 |
50 | kv_seq_len = key_states.shape[-2]
51 | if past_key_value is not None:
52 | kv_seq_len += past_key_value[0].shape[-2]
53 |
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | query_states, key_states = apply_rotary_pos_emb(
56 | query_states, key_states, cos, sin, position_ids
57 | )
58 |
59 | if past_key_value is not None:
60 | # reuse k, v
61 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
62 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
63 |
64 | past_key_value = (key_states, value_states) if use_cache else None
65 |
66 | # repeat k/v heads if n_kv_heads < n_heads
67 | key_states = repeat_kv(key_states, self.num_key_value_groups)
68 | value_states = repeat_kv(value_states, self.num_key_value_groups)
69 |
70 | # Transform the data into the format required by flash attention
71 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
72 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
73 | key_padding_mask = attention_mask
74 |
75 | if key_padding_mask is None:
76 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
77 | cu_q_lens = torch.arange(
78 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
79 | )
80 | max_s = q_len
81 | output = flash_attn_unpadded_qkvpacked_func(
82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
83 | )
84 | output = output.view(bsz, q_len, -1)
85 | else:
86 | qkv = qkv.reshape(bsz, q_len, -1)
87 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
88 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
89 | output_unpad = flash_attn_unpadded_qkvpacked_func(
90 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
91 | )
92 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
93 | output = pad_input(output_unpad, indices, bsz, q_len)
94 |
95 | return self.o_proj(output), None, past_key_value
96 |
97 |
98 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
99 | # requires the attention mask to be the same as the key_padding_mask
100 | def _prepare_decoder_attention_mask(
101 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
102 | ):
103 | # [bsz, seq_len]
104 | return attention_mask
105 |
106 |
107 | def replace_llama_attn_with_flash_attn():
108 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
109 | if cuda_major < 8:
110 | warnings.warn(
111 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
112 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
113 | )
114 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
115 | _prepare_decoder_attention_mask
116 | )
117 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
--------------------------------------------------------------------------------
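
A minimal sketch of applying this patch defensively. The capability threshold mirrors the warning emitted by replace_llama_attn_with_flash_attn above; guarding on it is an assumption about intent, not part of the original code.

import torch

from mplug_owl2.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Only patch on Ampere (SM 8.x) or newer GPUs; must run before the LLaMA model is built.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    replace_llama_attn_with_flash_attn()
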
/mplug_owl2/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from mplug_owl2.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from mplug_owl2.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
--------------------------------------------------------------------------------
/mplug_owl2/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from mplug_owl2.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 | data = "{" + '"input": ' + f'"{text}"' + "}"
111 | data = data.encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
--------------------------------------------------------------------------------
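
A minimal usage sketch of build_logger (the logger and file names below are illustrative): after the first call, plain print() output is routed through the "stdout" logger, and all loggers also write to a rotating file under LOGDIR.

from mplug_owl2.utils import build_logger

logger = build_logger("controller", "controller.log")  # illustrative names
logger.info("controller started")  # goes to the console and the rotating log file
print("hello")                     # redirected via StreamToLogger to the "stdout" logger
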
/process_params.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Dict
3 |
4 | from extensions.sd_smartprocess.file_manager import ImageData
5 |
6 |
7 | @dataclass
8 | class ProcessParams:
9 | auto_save: bool = False
10 | blip_initial_prompt = "a caption for this image is: "
11 | booru_min_score: float = 0.75
12 | caption: bool = False
13 |     captioners: Dict[str, bool] = field(default_factory=dict)
14 | char_threshold: float = 0.5
15 | clip_append_artist: bool = False
16 | clip_append_flavor: bool = False
17 | clip_append_medium: bool = False
18 | clip_append_movement: bool = False
19 | clip_append_trending: bool = False
20 | clip_max_flavors: int = 3
21 | clip_use_v2: bool = False
22 | crop: bool = False
23 | crop_mode: str = "smart"
24 | do_backup: bool = False
25 | do_rename: bool = False
26 | dst: str = ""
27 | face_model: str = "Codeformers"
28 | flip: bool = False
29 | idefics2_prompt: str = "Describe this image in one detailed sentence, include the subject, location, style, and type of image."
30 | interrogation_prompt: str = "Describe this image in one detailed sentence."
31 | insert_subject: bool = False
32 | load_in_4bit: bool = False
33 | load_in_8bit: bool = False
34 | max_clip_tokens: float = 1.0
35 | max_size: int = 1024
36 | max_tokens: int = 75
37 | min_clip_tokens: float = 0.0
38 | new_caption: str = ""
39 |     nl_captioners: Dict[str, bool] = field(default_factory=dict)
40 | num_beams: int = 5
41 | pad: bool = False
42 | replace_blip_caption: bool = True
43 | replace_class: bool = False
44 | restore_faces: bool = False
45 | save_caption: bool = False
46 | save_image: bool = False
47 |     src_files: List[ImageData] = field(default_factory=list)
48 | subject: str = ""
49 | subject_class: str = ""
50 |     tags_to_ignore: List[str] = field(default_factory=list)
51 | threshold: float = 0.5
52 | txt_action: str = "ignore"
53 | upscale: bool = False
54 | upscale_max: int = 4096
55 | upscale_mode: str = "ratio"
56 | upscale_ratio: float = 2.0
57 | upscaler_1 = None
58 | upscaler_2 = None
59 | wd14_min_score: float = 0.75
60 | image_path = None
61 |
62 | def clip_params(self):
63 | return {
64 | "min_clip_tokens": self.min_clip_tokens,
65 | "max_clip_tokens": self.max_clip_tokens,
66 | "use_v2": self.clip_use_v2,
67 | "append_flavor": self.clip_append_flavor,
68 | "max_flavors": self.clip_max_flavors,
69 | "append_medium": self.clip_append_medium,
70 | "append_movement": self.clip_append_movement,
71 | "append_artist": self.clip_append_artist,
72 | "append_trending": self.clip_append_trending,
73 | "num_beams": self.num_beams,
74 | "clip_max_flavors": self.clip_max_flavors,
75 | "blip_initial_prompt": self.blip_initial_prompt
76 | }
77 |
78 | def pre_only(self):
79 | self.caption = False
80 | self.upscale = False
81 | self.restore_faces = False
82 |
83 | def cap_only(self):
84 | self.upscale = False
85 | self.restore_faces = False
86 | self.crop = False
87 | self.pad = False
88 |
89 | def post_only(self):
90 | self.caption = False
91 | self.crop = False
92 | self.pad = False
93 |
94 | @classmethod
95 | def from_dict(cls, d):
96 |         instance = cls()  # Create a new instance with default values
97 | for k, v in d.items():
98 | k = k.replace("sp_", "") # Adjust the attribute name
99 | if k == "class":
100 | k = "subject_class"
101 | if hasattr(instance, k):
102 | setattr(instance, k, v)
103 | return instance
104 |
--------------------------------------------------------------------------------
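
A short sketch of how ProcessParams.from_dict normalizes UI keys (the values are illustrative): the "sp_" prefix is stripped, "class" is remapped to "subject_class", and keys without a matching attribute are ignored.

from extensions.sd_smartprocess.process_params import ProcessParams

params = ProcessParams.from_dict({
    "sp_caption": True,        # -> params.caption
    "sp_max_size": 2048,       # -> params.max_size
    "sp_class": "person",      # -> params.subject_class
    "sp_not_a_field": 1,       # ignored: no matching attribute
})
assert params.caption is True and params.max_size == 2048
assert params.subject_class == "person"
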
/processors.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import random
3 | from typing import List, Any
4 |
5 | import torch
6 | from PIL import Image
7 | from tqdm import tqdm
8 |
9 | from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm
10 | from extensions.sd_smartprocess import super_resolution
11 | from extensions.sd_smartprocess.clipcrop import CropClip
12 | from extensions.sd_smartprocess.interrogators.clip_interrogator import CLIPInterrogator
13 | from extensions.sd_smartprocess.interrogators.booru_interrogator import BooruInterrogator
14 | from modules import shared
15 |
16 | # Base processor
17 | class Processor:
18 | def __init__(self):
19 | printm("Model loaded.")
20 |
21 | # Unload models
22 | def unload(self):
23 |         if torch.cuda.is_available():
24 | torch.cuda.empty_cache()
25 | gc.collect()
26 | printm("Model unloaded.")
27 |
28 | # Process images
29 | def process(self, images: List[Image.Image]) -> List[Any]:
30 | raise Exception("Not Implemented")
31 |
32 | # CLIP Processing
33 | class ClipProcessor(Processor):
34 | def __init__(
35 | self,
36 | clip_use_v2,
37 | clip_append_artist,
38 | clip_append_medium,
39 | clip_append_movement,
40 | clip_append_flavor,
41 | clip_append_trending,
42 | num_beams,
43 | min_clip_tokens,
44 | max_clip_tokens,
45 | max_flavors
46 | ):
47 | self.description = "Processing CLIP"
48 | if shared.interrogator is not None:
49 | shared.interrogator.unload()
50 |
51 | self.max_flavors = max_flavors
52 | shared.state.textinfo = "Loading CLIP Model..."
53 | self.model = CLIPInterrogator(
54 | clip_use_v2,
55 | clip_append_artist,
56 | clip_append_medium,
57 | clip_append_movement,
58 | clip_append_flavor,
59 | clip_append_trending,
60 | num_beams,
61 | min_clip_tokens,
62 | max_clip_tokens
63 | )
64 | super().__init__()
65 |
66 | def process(self, images: List[Image.Image], short:bool=False) -> List[str]:
67 | output = []
68 | shared.state.job_count = len(images)
69 | shared.state.textinfo = f"{self.description}..."
70 | for img in tqdm(images, desc=self.description):
71 | short_caption = self.model.interrogate(img, short=short, max_flavors=self.max_flavors)
72 | output.append(short_caption)
73 | shared.state.current_image = img
74 | shared.state.job_no += 1
75 | return output
76 |
77 | def unload(self):
78 | if self.model.clip_model:
79 | del self.model.clip_model
80 | if self.model.blip_model:
81 | del self.model.blip_model
82 | super().unload()
83 |
84 | # Danbooru Processing
85 | class BooruProcessor(Processor):
86 | def __init__(self, min_score: float):
87 | self.description = "Processing Danbooru"
88 | shared.state.textinfo = "Loading DeepDanbooru Model..."
89 | self.model = BooruInterrogator()
90 | self.min_score = min_score
91 | super().__init__()
92 |
93 | def process(self, images: List[Image.Image]) -> List[List[str]]:
94 | output = []
95 | shared.state.job_count = len(images)
96 | shared.state.textinfo = f"{self.description}..."
97 | for img in tqdm(images, desc=self.description):
98 | out_tags = []
99 | tags = self.model.interrogate(img)
100 | for tag in sorted(tags, key=tags.get, reverse=True):
101 | if tags[tag] >= self.min_score:
102 | out_tags.append(tag)
103 | output.append(out_tags)
104 |             shared.state.job_no += 1
105 |         return output
106 | def unload(self):
107 | self.model.unload()
108 | super().unload()
109 |
110 | # WD14 Processing
111 |
112 | # Crop Processing
113 | class CropProcessor(Processor):
114 | def __init__(self, subject_class: str, pad: bool, crop: bool):
115 | self.description = "Cropping"
116 | if crop:
117 | shared.state.textinfo = "Loading CROP Model..."
118 | self.model = CropClip() if crop else None
119 | self.subject_class = subject_class
120 | self.pad = pad
121 | self.crop = crop
122 | super().__init__()
123 |
124 | def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]:
125 | output = []
126 | shared.state.job_count = len(images)
127 | shared.state.textinfo = f"{self.description}..."
128 |         for img, caption in tqdm(zip(images, captions or [""] * len(images)), desc=self.description):
129 | cropped = self._process_img(img, caption)
130 | output.append(cropped)
131 | shared.state.job_no += 1
132 | return output
133 |
134 |
135 | def _process_img(self, img, short_caption):
136 | if self.subject_class is not None and self.subject_class != "":
137 | short_caption = self.subject_class
138 |
139 | src_ratio = img.width / img.height
140 |
141 | # Pad image before cropping?
142 | if src_ratio != 1 and self.pad:
143 | if img.width > img.height:
144 | pad_width = img.width
145 | pad_height = img.width
146 | else:
147 | pad_width = img.height
148 | pad_height = img.height
149 | res = Image.new("RGB", (pad_width, pad_height))
150 | res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
151 | img = res
152 |
153 | if self.crop:
154 | # Do the actual crop clip
155 | im_data = self.model.get_center(img, prompt=short_caption)
156 | crop_width = im_data[1] - im_data[0]
157 | center_x = im_data[0] + (crop_width / 2)
158 | crop_height = im_data[3] - im_data[2]
159 | center_y = im_data[2] + (crop_height / 2)
160 | crop_ratio = crop_width / crop_height
161 | dest_ratio = 1
162 | tgt_width = crop_width
163 | tgt_height = crop_height
164 |
165 | if crop_ratio != dest_ratio:
166 | if crop_width > crop_height:
167 | tgt_height = crop_width / dest_ratio
168 | tgt_width = crop_width
169 | else:
170 | tgt_width = crop_height / dest_ratio
171 | tgt_height = crop_height
172 |
173 | # Reverse the above if dest is too big
174 | if tgt_width > img.width or tgt_height > img.height:
175 | if tgt_width > img.width:
176 | tgt_width = img.width
177 | tgt_height = tgt_width / dest_ratio
178 | else:
179 | tgt_height = img.height
180 | tgt_width = tgt_height / dest_ratio
181 |
182 | tgt_height = int(tgt_height)
183 | tgt_width = int(tgt_width)
184 | left = max(center_x - (tgt_width / 2), 0)
185 | right = min(center_x + (tgt_width / 2), img.width)
186 | top = max(center_y - (tgt_height / 2), 0)
187 | bottom = min(center_y + (tgt_height / 2), img.height)
188 | img = img.crop((left, top, right, bottom))
189 | return img
190 |
191 | def unload(self):
192 | if self.model is not None:
193 | self.model.unload()
194 | super().unload()
195 |
196 | # Upscale Processing
197 | class UpscaleProcessor(Processor):
198 | def __init__(self):
199 | self.description = "Upscaling"
200 | shared.state.textinfo = "Loading Stable-Diffusion Upscaling Model..."
201 | self.sampler, self.model = super_resolution.initialize_model()
202 | super().__init__()
203 |
204 | def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]:
205 | output = []
206 | shared.state.job_count = len(images)
207 | shared.state.textinfo = f"{self.description}..."
208 |         for img, caption in tqdm(zip(images, captions or [""] * len(images)), desc=self.description):
209 | seed = int(random.randrange(2147483647))
210 | img = super_resolution.predict(self.sampler, img, caption, 75, 1, 10, seed, 0, 20)
211 | output.append(img)
212 | shared.state.job_no += 1
213 | return output
214 |
215 | def unload(self):
216 | del self.sampler
217 | del self.model
218 | super().unload()
219 |
220 |
--------------------------------------------------------------------------------
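
For orientation, a minimal sketch of the intended call pattern for these processors, assuming the WebUI runtime (shared state, interrogator models) is available; the paths and parameter values are illustrative:

from PIL import Image

images = [Image.open(p) for p in ["a.png", "b.png"]]  # hypothetical paths

# Caption first, then crop using the captions as CLIP-crop prompts.
clip = ClipProcessor(False, False, False, False, False, False,
                     num_beams=5, min_clip_tokens=0.0, max_clip_tokens=1.0, max_flavors=3)
captions = clip.process(images)
clip.unload()

cropper = CropProcessor(subject_class="", pad=True, crop=True)
cropped = cropper.process(images, captions)
cropper.unload()
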
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython==8.18.0
2 | spandrel>=0.3.1
3 | seaborn==0.13.2
4 | tensorflow
5 | open_clip_torch
6 | icecream==2.1.3
7 | chardet
8 | cchardet
9 | peft
10 | flask
11 | ruamel.yaml
12 | markdown2
13 | sconf
14 | tensorboardX
15 | transformers>=4.40.0
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | #sp_src {
2 | width: calc(100% - 84px);
3 | min-width: calc(100% - 84px);
4 | }
5 |
6 | #sp_clear_src, #sp_load_src {
7 | margin-top: 20px;
8 | height: 36px;
9 | }
10 |
11 | #sp_top_btn_row {
12 | padding-top: 20px;
13 | }
14 |
15 | #sp_clear_src, #sp_load_src {
16 | flex: none;
17 | max-width: 42px;
18 | min-width: 42px !important;
19 | height: 42px;
20 | line-height: 1em;
21 | border-radius: 0.5em;
22 | }
23 |
24 | #sp_backup_files {
25 | max-width: 120px !important;
26 | padding-top: 5px !important;
27 | margin-left: 15px;
28 | margin-right: 15px;
29 | min-width: 120px !important;
30 | }
31 |
32 | #sp_file_options {
33 | margin-left: 25%;
34 | }
35 |
36 | .inc_exc_btn {
37 | margin-top: 20px !important;
38 | }
39 |
40 | #sp_progressbar.hide-container, #sp_progress.hide-container, #sp_error.hide-container {
41 | min-height: 0;
42 | display: none;
43 | }
44 |
45 | .short_check {
46 | min-width: 120px !important;
47 | max-width: 120px !important;
48 | }
--------------------------------------------------------------------------------
/super_resolution.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os.path
3 |
4 | import safetensors.torch
5 | import torch
6 | import numpy as np
7 | from PIL import Image
8 | from omegaconf import OmegaConf
9 | from einops import repeat, rearrange
10 | from pytorch_lightning import seed_everything
11 |
12 | from ldm.models.diffusion.ddim import DDIMSampler
13 | from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
14 | from ldm.util import exists, instantiate_from_config
15 |
16 | from modules import shared, modelloader, sd_hijack
17 |
18 | torch.set_grad_enabled(False)
19 |
20 |
21 | def initialize_model():
22 | model_dir = os.path.join(shared.models_path, "Upscale")
23 | if not os.path.exists(model_dir):
24 | os.makedirs(model_dir)
25 |
26 | model_name = os.path.join(model_dir, "x4-upscaler-ema.ckpt")
27 |
28 | model_url = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.ckpt"
29 |
30 |     model_dir = os.path.join(shared.models_path, "Upscale")
31 | model_file = modelloader.load_models(model_dir, model_url, None, '.ckpt', model_name)
32 | model_file = model_file[0]
33 | model_config = os.path.join(shared.script_path, "extensions", "sd_smartprocess", "configs", "upscaler.yaml")
34 |
35 | config = OmegaConf.load(model_config)
36 | model = instantiate_from_config(config.model)
37 |     checkpoint = torch.load(model_file)
38 |     model.load_state_dict(checkpoint["state_dict"], strict=False)
39 | model.half()
40 |
41 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
42 | model = model.to(device)
43 | sd_hijack.model_hijack.hijack(model) # apply optimization
44 |
45 | if shared.cmd_opts.opt_channelslast:
46 | model = model.to(memory_format=torch.channels_last)
47 |
48 | model.eval()
49 | sampler = DDIMSampler(model)
50 | return sampler, model
51 |
52 |
53 | def make_batch_sd(
54 | image,
55 | txt,
56 | device,
57 | num_samples=1,
58 | ):
59 | image = np.array(image.convert("RGB"))
60 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
61 | batch = {
62 | "lr": rearrange(image, 'h w c -> 1 c h w'),
63 | "txt": num_samples * [txt],
64 | }
65 | batch["lr"] = repeat(batch["lr"].to(device=device),
66 | "1 ... -> n ...", n=num_samples)
67 | return batch
68 |
69 |
70 | def make_noise_augmentation(model, batch, noise_level=None):
71 | x_low = batch[model.low_scale_key]
72 | x_low = x_low.to(memory_format=torch.contiguous_format).float()
73 | x_aug, noise_level = model.low_scale_model(x_low, noise_level)
74 | return x_aug, noise_level
75 |
76 |
77 | def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
78 | device = torch.device(
79 | "cuda") if torch.cuda.is_available() else torch.device("cpu")
80 | model = sampler.model
81 | seed_everything(seed)
82 | prng = np.random.RandomState(seed)
83 | start_code = prng.randn(num_samples, model.channels, h, w)
84 | start_code = torch.from_numpy(start_code).to(
85 | device=device, dtype=torch.float32)
86 |
87 | with torch.no_grad(),\
88 | torch.autocast("cuda"):
89 | batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
90 | print(f"Batch: {batch}")
91 | c = model.cond_stage_model.encode(batch["txt"])
92 | c_cat = list()
93 | if isinstance(model, LatentUpscaleFinetuneDiffusion):
94 | for ck in model.concat_keys:
95 | cc = batch[ck]
96 | if exists(model.reshuffle_patch_size):
97 | assert isinstance(model.reshuffle_patch_size, int)
98 | cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
99 | p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
100 | c_cat.append(cc)
101 | c_cat = torch.cat(c_cat, dim=1)
102 | # cond
103 | cond = {"c_concat": [c_cat], "c_crossattn": [c]}
104 | # uncond cond
105 | uc_cross = model.get_unconditional_conditioning(num_samples, "")
106 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
107 | elif isinstance(model, LatentUpscaleDiffusion):
108 | x_augment, noise_level = make_noise_augmentation(
109 | model, batch, noise_level)
110 | cond = {"c_concat": [x_augment],
111 | "c_crossattn": [c], "c_adm": noise_level}
112 | # uncond cond
113 | uc_cross = model.get_unconditional_conditioning(num_samples, "")
114 | uc_full = {"c_concat": [x_augment], "c_crossattn": [
115 | uc_cross], "c_adm": noise_level}
116 | else:
117 | raise NotImplementedError()
118 |
119 | shape = [model.channels, h, w]
120 | samples, intermediates = sampler.sample(
121 | steps,
122 | num_samples,
123 | shape,
124 | cond,
125 | verbose=False,
126 | eta=eta,
127 | unconditional_guidance_scale=scale,
128 | unconditional_conditioning=uc_full,
129 | x_T=start_code,
130 | callback=callback
131 | )
132 | with torch.no_grad():
133 | x_samples_ddim = model.decode_first_stage(samples)
134 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
135 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
136 | return [Image.fromarray(img.astype(np.uint8)) for img in result]
137 |
138 |
139 | def pad_image(input_image):
140 | pad_w, pad_h = np.max(((2, 2), np.ceil(
141 | np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
142 | im_padded = Image.fromarray(
143 | np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
144 | return im_padded
145 |
146 |
147 | def predict(sampler, input_image, prompt, steps, num_samples, scale, seed, eta, noise_level):
148 | init_image = input_image.convert("RGB")
149 |     image = pad_image(init_image)  # pad to an integer multiple of 64 with edge values
150 | width, height = image.size
151 |
152 | noise_level = torch.Tensor(
153 | num_samples * [noise_level]).to(sampler.model.device).long()
154 | sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
155 | result = paint(
156 | sampler=sampler,
157 | image=image,
158 | prompt=prompt,
159 | seed=seed,
160 | scale=scale,
161 | h=height, w=width, steps=steps,
162 | num_samples=num_samples,
163 | callback=None,
164 | noise_level=noise_level
165 | )
166 | return result
167 |
168 | def super_resolution(self, images, steps=50, target_scale=2, half_attention=False):  # expects an object providing load_model_from_config() and run(); not used by initialize_model()/predict() above
169 | model = self.load_model_from_config(half_attention)
170 | gc.collect()
171 |     if torch.cuda.is_available():
172 | torch.cuda.empty_cache()
173 |
174 | # Run settings
175 | diffusion_steps = int(steps)
176 | eta = 1.0
177 | output = []
178 | for image in images:
179 | im_og = image
180 | width_og, height_og = im_og.size
181 | # If we can adjust the max upscale size, then the 4 below should be our variable
182 | down_sample_rate = target_scale / 4
183 | wd = width_og * down_sample_rate
184 | hd = height_og * down_sample_rate
185 | width_downsampled_pre = int(np.ceil(wd))
186 | height_downsampled_pre = int(np.ceil(hd))
187 |
188 | if down_sample_rate != 1:
189 | print(
190 | f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
191 | im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
192 | else:
193 | print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
194 |
195 | # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
196 | pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
197 | im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
198 |
199 | logs = self.run(model["model"], im_padded, diffusion_steps, eta)
200 |
201 | sample = logs["sample"]
202 | sample = sample.detach().cpu()
203 | sample = torch.clamp(sample, -1., 1.)
204 | sample = (sample + 1.) / 2. * 255
205 | sample = sample.numpy().astype(np.uint8)
206 | sample = np.transpose(sample, (0, 2, 3, 1))
207 | a = Image.fromarray(sample[0])
208 |
209 | # remove padding
210 | a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
211 | output.append(a)
212 |
213 | del model
214 | gc.collect()
215 |     if torch.cuda.is_available():
216 | torch.cuda.empty_cache()
217 |
218 | return output
219 |
220 |
221 |
222 |
--------------------------------------------------------------------------------
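
A minimal usage sketch matching how processors.UpscaleProcessor drives this module (it passes 75 steps, 1 sample, guidance scale 10, eta 0 and noise level 20); the input path is illustrative and the WebUI runtime is assumed:

import random
from PIL import Image

from extensions.sd_smartprocess import super_resolution

sampler, model = super_resolution.initialize_model()
img = Image.open("input.png")  # hypothetical path
seed = int(random.randrange(2147483647))
results = super_resolution.predict(sampler, img, "a photo", steps=75, num_samples=1,
                                    scale=10, seed=seed, eta=0, noise_level=20)
results[0].save("output_x4.png")
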
/upscalers/spandrel/spandrel_srformer_model.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | from extensions.sd_smartprocess.upscalers.spandrel.spandrel_upscaler_base import SpandrelUpscaler
4 | from modules import modelloader
5 | from modules.upscaler import UpscalerData
6 |
7 | model_urls = [
8 | "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth",
9 |
10 | ]
11 |
12 | models = {
13 | "SRformer 4xFF": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4xFrankendata_FullDegradation_g.pth",
14 | "SRFormer 4xNomos8kSC": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4xNomos8kSC_SRFormer.pth",
15 | "SRFormer FP": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/FrankendataPretrainer_SRFormer_g.pth",
16 | "SwinIR x8": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth",
17 | "4xLDSIR": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4xLSDIR_DAT.pth",
18 | "ClearRealityV1 4x": "https://github.com/d8ahazard/sd_smartprocess/releases/download/1.0.0/4x-ClearRealityV1_SPAN.pth"
19 | }
20 |
21 |
22 | class SpandrelSRFormerModel(SpandrelUpscaler):
23 | scale = 4
24 | name = "SRFormer"
25 |
26 | def __init__(self, create_dirs=False):
27 | super().__init__(create_dirs)
28 | self.name = "SRFormer"
29 | self.scale = 4
30 | user_models = self.find_models(ext_filter=[".pth"])
31 | self.scalers = []
32 | added_models = []
33 | for file in user_models:
34 | model_name = os.path.basename(file)
35 | display_name = modelloader.friendly_name(file)
36 | for pre_name, model_url in models.items():
37 | if model_name in model_url:
38 | display_name = pre_name
39 | self.scalers.append(UpscalerData(display_name, file, self))
40 | added_models.append(display_name)
41 | break
42 | if display_name not in added_models:
43 | self.scalers.append(UpscalerData(display_name, file, self))
44 | added_models.append(display_name)
45 | for model_name, model_url in models.items():
46 | if model_name not in added_models:
47 | file_name = os.path.basename(model_url)
48 | model_path = modelloader.load_file_from_url(model_url, model_dir=self.model_path, file_name=file_name)
49 | self.scalers.append(UpscalerData(model_name, model_path, self))
50 |
--------------------------------------------------------------------------------
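
A short sketch of how a locally present .pth file is mapped to a display name via the models table above; files whose names do not appear in any URL fall back to modelloader.friendly_name (the path here is hypothetical):

import os

from extensions.sd_smartprocess.upscalers.spandrel.spandrel_srformer_model import models

file = "/models/SRFormer/4xNomos8kSC_SRFormer.pth"  # hypothetical path
model_name = os.path.basename(file)
display_name = next((name for name, url in models.items() if model_name in url), None)
print(display_name)  # -> "SRFormer 4xNomos8kSC"
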
/upscalers/spandrel/spandrel_upscaler_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import os
5 | import re
6 | import time
7 | import traceback
8 | import zipfile
9 | from abc import abstractmethod
10 | from pathlib import Path
11 | from urllib.request import urlretrieve
12 |
13 | import PIL
14 | import cv2
15 | import numpy as np
16 | import torch
17 | from PIL import Image
18 | from spandrel import (
19 | ImageModelDescriptor,
20 | ModelDescriptor, ModelLoader,
21 | )
22 |
23 | from modules import shared
24 | from modules.upscaler import Upscaler, UpscalerData, NEAREST
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | def convert_google_drive_link(url: str) -> str:
30 | pattern = re.compile(
31 | r"^https://drive.google.com/file/d/([a-zA-Z0-9_\-]+)/view(?:\?.*)?$"
32 | )
33 | m = pattern.match(url)
34 | if not m:
35 | return url
36 | file_id = m.group(1)
37 | return "https://drive.google.com/uc?export=download&confirm=1&id=" + file_id
38 |
39 |
40 | def download_file(url: str, filename: Path | str) -> None:
41 | filename = Path(filename)
42 | filename.parent.mkdir(exist_ok=True)
43 |
44 | url = convert_google_drive_link(url)
45 |
46 | temp_filename = filename.with_suffix(f".part-{int(time.time())}")
47 |
48 | try:
49 | logger.info("Downloading %s to %s", url, filename)
50 | path, _ = urlretrieve(url, filename=temp_filename)
51 | temp_filename.rename(filename)
52 | finally:
53 | try:
54 | temp_filename.unlink()
55 | except FileNotFoundError:
56 | pass
57 |
58 |
59 | def extract_file_from_zip(
60 | zip_path: Path | str,
61 | rel_model_path: str,
62 | filename: Path | str,
63 | ):
64 | filename = Path(filename)
65 | filename.parent.mkdir(exist_ok=True)
66 |
67 | with zipfile.ZipFile(zip_path, "r") as zip_ref:
68 | with open(filename, "wb") as f:
69 | f.write(zip_ref.read(rel_model_path))
70 |
71 |
72 | def image_to_tensor(img: np.ndarray, device: str, half) -> torch.Tensor:
73 | img = img.astype(np.float32) / 255.0
74 | if img.ndim == 2:
75 | img = np.expand_dims(img, axis=2)
76 | if img.shape[2] == 1:
77 | pass
78 | else:
79 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
80 | img = np.transpose(img, (2, 0, 1))
81 | tensor = torch.from_numpy(img).to(device)
82 | if half is not None:
83 | tensor = tensor.to(half)
84 | return tensor.unsqueeze(0)
85 |
86 |
87 | def tensor_to_image(tensor: torch.Tensor) -> np.ndarray:
88 | image = tensor.cpu().squeeze().numpy()
89 | image = np.transpose(image, (1, 2, 0))
90 | image = np.clip((image * 255.0).round(), 0, 255)
91 | image = image.astype(np.uint8)
92 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
93 | return image
94 |
95 |
96 | def image_inference_tensor(
97 | model: ImageModelDescriptor, tensor: torch.Tensor
98 | ) -> torch.Tensor:
99 | model.eval()
100 | with torch.no_grad():
101 | return model(tensor)
102 |
103 |
104 | def image_inference(model: ImageModelDescriptor, image: np.ndarray, device: str, half: str) -> np.ndarray:
105 | return tensor_to_image(image_inference_tensor(model, image_to_tensor(image, device, half)))
106 |
107 |
108 | def get_h_w_c(image: np.ndarray) -> tuple[int, int, int]:
109 | if len(image.shape) == 2:
110 | return image.shape[0], image.shape[1], 1
111 | return image.shape[0], image.shape[1], image.shape[2]
112 |
113 |
114 | class SpandrelUpscaler(Upscaler):
115 | model_url = ""
116 | model_type = ""
117 | model_file = ""
118 | scale = 4
119 |
120 | def __init__(self, create_dirs=False):
121 | super().__init__(create_dirs)
122 | self.name = "Spandrel"
123 | self.scale = 1
124 | self.scalers = []
125 |
126 | def do_upscale(self, img: PIL.Image, selected_model: str):
127 | self.load_model(selected_model)
128 | return self.internal_upscale(img)
129 |
130 | def load_model(self, path: str):
131 | self.model = ModelLoader().load_from_file(path)
132 | print(f"Model size reqs: {self.model.size_requirements}")
133 | device = "cuda" if torch.cuda.is_available() else "cpu"
134 | self.model.to(device)
135 | if self.model.supports_half:
136 | self.model.to(torch.half)
137 | self.model.eval()
138 |
139 | def preprocess(self, image: Image) -> Image:
140 | square = self.model.size_requirements.square
141 | minimum = self.model.size_requirements.minimum
142 | multiple_of = self.model.size_requirements.multiple_of
143 | if square:
144 | # Pad the shorter side to make the image square
145 | size = max(image.width, image.height)
146 | new_image = Image.new("RGB", (size, size))
147 | new_image.paste(image, ((size - image.width) // 2, (size - image.height) // 2))
148 | image = new_image
149 | if minimum > 1:
150 |             size = min(image.width, image.height)
151 | if size < minimum:
152 | new_width = int(minimum * image.width / size)
153 | new_height = int(minimum * image.height / size)
154 | image = image.resize((new_width, new_height), resample=NEAREST)
155 | if multiple_of > 1:
156 |             new_width = (image.width // multiple_of) * multiple_of
157 |             new_height = (image.height // multiple_of) * multiple_of
158 | image = image.resize((new_width, new_height), resample=NEAREST)
159 | return image
160 |
161 | def postprocess(self, image: Image, original_width: int, original_height: int) -> Image:
162 | square = self.model.size_requirements.square
163 | if square:
164 | original_aspect_ratio = original_width / original_height
165 | current_aspect_ratio = image.width / image.height
166 |
167 | if current_aspect_ratio > original_aspect_ratio:
168 | # Image is wider than the original, crop width
169 | new_width = int(original_aspect_ratio * image.height)
170 | left = (image.width - new_width) // 2
171 | image = image.crop((left, 0, left + new_width, image.height))
172 | elif current_aspect_ratio < original_aspect_ratio:
173 | # Image is taller than the original, crop height
174 | new_height = int(image.width / original_aspect_ratio)
175 | top = (image.height - new_height) // 2
176 | image = image.crop((0, top, image.width, top + new_height))
177 |
178 | return image
179 |
180 | def internal_upscale(self, image: Image):
181 | original_width = image.width
182 | original_height = image.height
183 |         needs_preprocess = not self.model.size_requirements.check(image.width, image.height)
184 |         if needs_preprocess:
185 | image = self.preprocess(image)
186 | # Convert image to cv2 format
187 | image = np.array(image)
188 | image_h, image_w, image_c = get_h_w_c(image)
189 |
190 | if self.model.input_channels == 1 and image_c == 3:
191 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
192 |
193 | device = "cuda" if torch.cuda.is_available() else "cpu"
194 |
195 | # If the image size is already greater than 2048, we'll likely OOM on GPU, so do it on CPU
196 | if image_h > 2048 or image_w > 2048:
197 | device = "cpu"
198 |
199 | try:
200 | self.model.to(device)
201 | half = None
202 | if self.model.supports_half:
203 | half = torch.half
204 | output = image_inference(self.model, image, device, half)
205 | # Convert output to PIL format
206 | output = Image.fromarray(output)
207 | if needs_preprocess:
208 | output = self.postprocess(output, original_width, original_height)
209 | return output
210 | except Exception as e:
211 | print(f"Failed to upscale image: {e}")
212 | traceback.print_exc()
213 | return Image.fromarray(image)
214 |
215 | def unload(self):
216 | try:
217 | del self.model
218 | except:
219 | pass
220 | self.model = None
221 |
222 | def load(self):
223 | pass
224 |
--------------------------------------------------------------------------------
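
Finally, a minimal sketch of driving this wrapper directly inside the WebUI runtime (the subclass, image path and output path are illustrative; do_upscale loads the checkpoint through spandrel's ModelLoader, and internal_upscale falls back to CPU for images larger than 2048px on a side):

from PIL import Image

from extensions.sd_smartprocess.upscalers.spandrel.spandrel_srformer_model import SpandrelSRFormerModel

upscaler = SpandrelSRFormerModel(create_dirs=True)
img = Image.open("input.png")                    # hypothetical path
model_path = upscaler.scalers[0].data_path       # first discovered or downloaded .pth
out = upscaler.do_upscale(img, model_path)
out.save("output.png")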