├── .idea
│   ├── .gitignore
│   ├── .name
│   ├── SUPIR.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── webResources.xml
├── CKPT_PTH.py
├── LICENSE
├── README.md
├── SUPIR
│   ├── __init__.py
│   ├── models
│   │   ├── SUPIR_model.py
│   │   └── __init__.py
│   ├── modules
│   │   ├── SUPIR_v0.py
│   │   └── __init__.py
│   ├── util.py
│   └── utils
│       ├── __init__.py
│       └── colorfix.py
├── assets
│   ├── DemoGuide.png
│   ├── framework.png
│   └── teaser.png
├── cog.yaml
├── gradio_demo.py
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── eval
│   │   ├── eval_gpt_review.py
│   │   ├── eval_gpt_review_bench.py
│   │   ├── eval_gpt_review_visual.py
│   │   ├── eval_pope.py
│   │   ├── eval_science_qa.py
│   │   ├── eval_science_qa_gpt4.py
│   │   ├── eval_science_qa_gpt4_requery.py
│   │   ├── eval_textvqa.py
│   │   ├── generate_webpage_data_from_table.py
│   │   ├── m4c_evaluator.py
│   │   ├── model_qa.py
│   │   ├── model_vqa.py
│   │   ├── model_vqa_loader.py
│   │   ├── model_vqa_mmbench.py
│   │   ├── model_vqa_science.py
│   │   ├── qa_baseline_gpt35.py
│   │   ├── run_llava.py
│   │   ├── summarize_gpt_review.py
│   │   ├── table
│   │   │   ├── answer
│   │   │   │   ├── answer_alpaca-13b.jsonl
│   │   │   │   ├── answer_bard.jsonl
│   │   │   │   ├── answer_gpt35.jsonl
│   │   │   │   ├── answer_llama-13b.jsonl
│   │   │   │   └── answer_vicuna-13b.jsonl
│   │   │   ├── caps_boxes_coco2014_val_80.jsonl
│   │   │   ├── model.jsonl
│   │   │   ├── prompt.jsonl
│   │   │   ├── question.jsonl
│   │   │   ├── results
│   │   │   │   ├── test_sqa_llava_13b_v0.json
│   │   │   │   └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│   │   │   ├── review
│   │   │   │   ├── review_alpaca-13b_vicuna-13b.jsonl
│   │   │   │   ├── review_bard_vicuna-13b.jsonl
│   │   │   │   ├── review_gpt35_vicuna-13b.jsonl
│   │   │   │   └── review_llama-13b_vicuna-13b.jsonl
│   │   │   ├── reviewer.jsonl
│   │   │   └── rule.json
│   │   └── webpage
│   │       ├── figures
│   │       │   ├── alpaca.png
│   │       │   ├── bard.jpg
│   │       │   ├── chatgpt.svg
│   │       │   ├── llama.jpg
│   │       │   ├── swords_FILL0_wght300_GRAD0_opsz48.svg
│   │       │   └── vicuna.jpeg
│   │       ├── index.html
│   │       ├── script.js
│   │       └── styles.css
│   ├── llava_agent.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   ├── llava_llama.py
│   │   │   ├── llava_mpt.py
│   │   │   └── mpt
│   │   │       ├── adapt_tokenizer.py
│   │   │       ├── attention.py
│   │   │       ├── blocks.py
│   │   │       ├── configuration_mpt.py
│   │   │       ├── custom_embedding.py
│   │   │       ├── flash_attn_triton.py
│   │   │       ├── hf_prefixlm_converter.py
│   │   │       ├── meta_init_context.py
│   │   │       ├── modeling_mpt.py
│   │   │       ├── norm.py
│   │   │       └── param_init_fns.py
│   │   ├── llava_arch.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   └── clip_encoder.py
│   │   ├── multimodal_projector
│   │   │   └── builder.py
│   │   └── utils.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── extreme_ironing.jpg
│   │   │   └── waterview.jpg
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   ├── register_worker.py
│   │   └── test_message.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llava_trainer.py
│   │   ├── train.py
│   │   └── train_mem.py
│   └── utils.py
├── options
│   └── SUPIR_v0.yaml
├── predict.py
├── requirements.txt
├── sgm
│   ├── __init__.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   └── diffusion.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── autoencoding
│   │   │   ├── __init__.py
│   │   │   ├── losses
│   │   │   │   └── __init__.py
│   │   │   ├── lpips
│   │   │   │   ├── __init__.py
│   │   │   │   ├── loss
│   │   │   │   │   ├── .gitignore
│   │   │   │   │   ├── LICENSE
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── lpips.py
│   │   │   │   ├── model
│   │   │   │   │   ├── LICENSE
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── model.py
│   │   │   │   ├── util.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── regularizers
│   │   │       └── __init__.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── denoiser.py
│   │   │   ├── denoiser_scaling.py
│   │   │   ├── denoiser_weighting.py
│   │   │   ├── discretizer.py
│   │   │   ├── guiders.py
│   │   │   ├── loss.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   ├── sampling.py
│   │   │   ├── sampling_utils.py
│   │   │   ├── sigma_sampling.py
│   │   │   ├── util.py
│   │   │   └── wrappers.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   └── encoders
│   │       ├── __init__.py
│   │       └── modules.py
│   └── util.py
└── test.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
1 | Diff4R
--------------------------------------------------------------------------------
/.idea/SUPIR.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
28 |
29 |
30 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/webResources.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/CKPT_PTH.py:
--------------------------------------------------------------------------------
1 | LLAVA_CLIP_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/clip-vit-large-patch14-336'
2 | LLAVA_MODEL_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/llava-v1.5-13b'
3 | SDXL_CLIP1_PATH = '/opt/data/private/AIGC_pretrain/clip-vit-large-patch14'
4 | SDXL_CLIP2_CACHE_DIR = '/opt/data/private/AIGC_pretrain' # put package 'models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k' here
--------------------------------------------------------------------------------
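Note: the paths in CKPT_PTH.py above are absolute paths from the original author's machine. A minimal sketch of pointing them at a local checkout instead (the directory layout below is an assumption, not part of the repository):

    # CKPT_PTH.py, adapted to a hypothetical local 'models/' directory
    LLAVA_CLIP_PATH = './models/clip-vit-large-patch14-336'
    LLAVA_MODEL_PATH = './models/llava-v1.5-13b'
    SDXL_CLIP1_PATH = './models/clip-vit-large-patch14'
    SDXL_CLIP2_CACHE_DIR = './models'  # must contain 'models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k'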
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Fanghua-Yu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/SUPIR/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/SUPIR/__init__.py
--------------------------------------------------------------------------------
/SUPIR/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/SUPIR/models/__init__.py
--------------------------------------------------------------------------------
/SUPIR/modules/__init__.py:
--------------------------------------------------------------------------------
1 | SDXL_BASE_CHANNEL_DICT = {
2 | 'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3,
3 | 'project_channels': [160] * 4 + [320] * 3 + [640] * 3,
4 | 'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0]
5 | }
6 |
7 | SDXL_REFINE_CHANNEL_DICT = {
8 | 'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6,
9 | 'project_channels': [192] * 4 + [384] * 3 + [768] * 6,
10 | 'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0]
11 | }
--------------------------------------------------------------------------------
/SUPIR/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import cv2
5 | from PIL import Image
6 | from torch.nn.functional import interpolate
7 | from omegaconf import OmegaConf
8 | from sgm.util import instantiate_from_config
9 |
10 |
11 | def get_state_dict(d):
12 | return d.get('state_dict', d)
13 |
14 |
15 | def load_state_dict(ckpt_path, location='cpu'):
16 | _, extension = os.path.splitext(ckpt_path)
17 | if extension.lower() == ".safetensors":
18 | import safetensors.torch
19 | state_dict = safetensors.torch.load_file(ckpt_path, device=location)
20 | else:
21 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
22 | state_dict = get_state_dict(state_dict)
23 | print(f'Loaded state_dict from [{ckpt_path}]')
24 | return state_dict
25 |
26 |
27 | def create_model(config_path):
28 | config = OmegaConf.load(config_path)
29 | model = instantiate_from_config(config.model).cpu()
30 | print(f'Loaded model config from [{config_path}]')
31 | return model
32 |
33 |
34 | def create_SUPIR_model(config_path, SUPIR_sign=None):
35 | config = OmegaConf.load(config_path)
36 | model = instantiate_from_config(config.model).cpu()
37 | print(f'Loaded model config from [{config_path}]')
38 | if config.SDXL_CKPT is not None:
39 | model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
40 | if config.SUPIR_CKPT is not None:
41 | model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False)
42 | if SUPIR_sign is not None:
43 | assert SUPIR_sign in ['F', 'Q']
44 | if SUPIR_sign == 'F':
45 | model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
46 | elif SUPIR_sign == 'Q':
47 | model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)
48 | return model
49 |
50 | def load_QF_ckpt(config_path):
51 | config = OmegaConf.load(config_path)
52 | ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location='cpu')
53 | ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location='cpu')
54 | return ckpt_Q, ckpt_F
55 |
56 |
57 | def PIL2Tensor(img, upsacle=1, min_size=1024):
58 | '''
59 | PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
60 | '''
61 | # size
62 | w, h = img.size
63 | w *= upsacle
64 | h *= upsacle
65 | w0, h0 = round(w), round(h)
66 | if min(w, h) < min_size:
67 | _upsacle = min_size / min(w, h)
68 | w *= _upsacle
69 | h *= _upsacle
70 | else:
71 | _upsacle = 1
72 | w = int(np.round(w / 64.0)) * 64
73 | h = int(np.round(h / 64.0)) * 64
74 | x = img.resize((w, h), Image.BICUBIC)
75 | x = np.array(x).round().clip(0, 255).astype(np.uint8)
76 | x = x / 255 * 2 - 1
77 | x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
78 | return x, h0, w0
79 |
80 |
81 | def Tensor2PIL(x, h0, w0):
82 | '''
83 | Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
84 | '''
85 | x = x.unsqueeze(0)
86 | x = interpolate(x, size=(h0, w0), mode='bicubic')
87 | x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
88 | return Image.fromarray(x)
89 |
90 |
91 | def HWC3(x):
92 | assert x.dtype == np.uint8
93 | if x.ndim == 2:
94 | x = x[:, :, None]
95 | assert x.ndim == 3
96 | H, W, C = x.shape
97 | assert C == 1 or C == 3 or C == 4
98 | if C == 3:
99 | return x
100 | if C == 1:
101 | return np.concatenate([x, x, x], axis=2)
102 | if C == 4:
103 | color = x[:, :, 0:3].astype(np.float32)
104 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
105 | y = color * alpha + 255.0 * (1.0 - alpha)
106 | y = y.clip(0, 255).astype(np.uint8)
107 | return y
108 |
109 |
110 | def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
111 | H, W, C = input_image.shape
112 | H = float(H)
113 | W = float(W)
114 | H *= upscale
115 | W *= upscale
116 | if min_size is not None:
117 | if min(H, W) < min_size:
118 | _upsacle = min_size / min(W, H)
119 | W *= _upsacle
120 | H *= _upsacle
121 | H = int(np.round(H / unit_resolution)) * unit_resolution
122 | W = int(np.round(W / unit_resolution)) * unit_resolution
123 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
124 | img = img.round().clip(0, 255).astype(np.uint8)
125 | return img
126 |
127 |
128 | def fix_resize(input_image, size=512, unit_resolution=64):
129 | H, W, C = input_image.shape
130 | H = float(H)
131 | W = float(W)
132 | upscale = size / min(H, W)
133 | H *= upscale
134 | W *= upscale
135 | H = int(np.round(H / unit_resolution)) * unit_resolution
136 | W = int(np.round(W / unit_resolution)) * unit_resolution
137 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
138 | img = img.round().clip(0, 255).astype(np.uint8)
139 | return img
140 |
141 |
142 |
143 | def Numpy2Tensor(img):
144 | '''
145 | np.array[H, W, C], [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
146 | '''
147 | # size
148 | img = np.array(img) / 255 * 2 - 1
149 | img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
150 | return img
151 |
152 |
153 | def Tensor2Numpy(x, h0=None, w0=None):
154 | '''
155 | Tensor[C, H, W], RGB, [-1, 1] -> np.ndarray[H, W, C], RGB, [0, 255]
156 | '''
157 | if h0 is not None and w0 is not None:
158 | x = x.unsqueeze(0)
159 | x = interpolate(x, size=(h0, w0), mode='bicubic')
160 | x = x.squeeze(0)
161 | x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
162 | return x
163 |
164 |
165 | def convert_dtype(dtype_str):
166 | if dtype_str == 'fp32':
167 | return torch.float32
168 | elif dtype_str == 'fp16':
169 | return torch.float16
170 | elif dtype_str == 'bf16':
171 | return torch.bfloat16
172 | else:
173 | raise NotImplementedError
174 |
--------------------------------------------------------------------------------
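For orientation, a minimal usage sketch of the helpers in SUPIR/util.py. It is illustrative only: the config path matches options/SUPIR_v0.yaml from this repository, but the input image and the checkpoint paths referenced inside that YAML (SDXL_CKPT, SUPIR_CKPT_Q, ...) are assumed to exist locally.

    from PIL import Image
    from SUPIR.util import create_SUPIR_model, PIL2Tensor, Tensor2PIL, convert_dtype

    # Build the model from the shipped config; weights are loaded only if the
    # checkpoint paths inside the YAML point at real files.
    model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
    model = model.to(dtype=convert_dtype('bf16'))

    # Round-trip an image through the [-1, 1] tensor convention used above.
    img = Image.open('input.png').convert('RGB')            # hypothetical input file
    x, h0, w0 = PIL2Tensor(img, upsacle=2, min_size=1024)    # note: the parameter is spelled 'upsacle' in util.py
    restored = Tensor2PIL(x, h0, w0)                         # resize back to the pre-rounding target size
    restored.save('roundtrip.png')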
/SUPIR/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/SUPIR/utils/__init__.py
--------------------------------------------------------------------------------
/SUPIR/utils/colorfix.py:
--------------------------------------------------------------------------------
1 | '''
2 | # --------------------------------------------------------------------------------
3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4 | # --------------------------------------------------------------------------------
5 | '''
6 |
7 | import torch
8 | from PIL import Image
9 | from torch import Tensor
10 | from torch.nn import functional as F
11 |
12 | from torchvision.transforms import ToTensor, ToPILImage
13 |
14 | def adain_color_fix(target: Image, source: Image):
15 | # Convert images to tensors
16 | to_tensor = ToTensor()
17 | target_tensor = to_tensor(target).unsqueeze(0)
18 | source_tensor = to_tensor(source).unsqueeze(0)
19 |
20 | # Apply adaptive instance normalization
21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
22 |
23 | # Convert tensor back to image
24 | to_image = ToPILImage()
25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
26 |
27 | return result_image
28 |
29 | def wavelet_color_fix(target: Image, source: Image):
30 | # Convert images to tensors
31 | to_tensor = ToTensor()
32 | target_tensor = to_tensor(target).unsqueeze(0)
33 | source_tensor = to_tensor(source).unsqueeze(0)
34 |
35 | # Apply wavelet reconstruction
36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
37 |
38 | # Convert tensor back to image
39 | to_image = ToPILImage()
40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
41 |
42 | return result_image
43 |
44 | def calc_mean_std(feat: Tensor, eps=1e-5):
45 | """Calculate mean and std for adaptive_instance_normalization.
46 | Args:
47 | feat (Tensor): 4D tensor.
48 | eps (float): A small value added to the variance to avoid
49 | divide-by-zero. Default: 1e-5.
50 | """
51 | size = feat.size()
52 | assert len(size) == 4, 'The input feature should be 4D tensor.'
53 | b, c = size[:2]
54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
57 | return feat_mean, feat_std
58 |
59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
60 | """Adaptive instance normalization.
61 | Adjust the reference features to have color and illumination similar
62 | to those in the degraded features.
63 | Args:
64 | content_feat (Tensor): The reference feature.
65 | style_feat (Tensor): The degraded features.
66 | """
67 | size = content_feat.size()
68 | style_mean, style_std = calc_mean_std(style_feat)
69 | content_mean, content_std = calc_mean_std(content_feat)
70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
72 |
73 | def wavelet_blur(image: Tensor, radius: int):
74 | """
75 | Apply wavelet blur to the input tensor.
76 | """
77 | # input shape: (1, 3, H, W)
78 | # convolution kernel
79 | kernel_vals = [
80 | [0.0625, 0.125, 0.0625],
81 | [0.125, 0.25, 0.125],
82 | [0.0625, 0.125, 0.0625],
83 | ]
84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
85 | # add channel dimensions to the kernel to make it a 4D tensor
86 | kernel = kernel[None, None]
87 | # repeat the kernel across all input channels
88 | kernel = kernel.repeat(3, 1, 1, 1)
89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
90 | # apply convolution
91 | output = F.conv2d(image, kernel, groups=3, dilation=radius)
92 | return output
93 |
94 | def wavelet_decomposition(image: Tensor, levels=5):
95 | """
96 | Apply wavelet decomposition to the input tensor.
97 | This function only returns the low frequency & the high frequency.
98 | """
99 | high_freq = torch.zeros_like(image)
100 | for i in range(levels):
101 | radius = 2 ** i
102 | low_freq = wavelet_blur(image, radius)
103 | high_freq += (image - low_freq)
104 | image = low_freq
105 |
106 | return high_freq, low_freq
107 |
108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
109 | """
110 | Apply wavelet decomposition, so that the content will have the same color as the style.
111 | """
112 | # calculate the wavelet decomposition of the content feature
113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
114 | del content_low_freq
115 | # calculate the wavelet decomposition of the style feature
116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
117 | del style_high_freq
118 | # reconstruct the content feature with the style's high frequency
119 | return content_high_freq + style_low_freq
120 |
121 |
--------------------------------------------------------------------------------
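A minimal sketch of applying the color-fix helpers above to a restored output, using its low-quality source as the color reference (the file names are assumptions; for the wavelet variant both images must have the same size):

    from PIL import Image
    from SUPIR.utils.colorfix import wavelet_color_fix, adain_color_fix

    restored = Image.open('restored.png').convert('RGB')    # model output whose colors may have drifted
    source = Image.open('input_lq.png').convert('RGB')      # original low-quality input

    # Wavelet variant: keep the restored high frequencies, take the low-frequency colors from the source.
    wavelet_color_fix(restored, source).save('restored_wavelet_fix.png')

    # AdaIN variant: match the channel-wise mean/std statistics of the source instead.
    adain_color_fix(restored, source).save('restored_adain_fix.png')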
/assets/DemoGuide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/assets/DemoGuide.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/assets/framework.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/assets/teaser.png
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 | system_packages:
7 | - "libgl1-mesa-glx"
8 | - "libglib2.0-0"
9 | python_version: "3.11"
10 | python_packages:
11 | - sentencepiece==0.1.98
12 | - tokenizers==0.13.3
13 | - torch>=2.1.0
14 | - torchvision>=0.16.0
15 | - uvicorn==0.21.1
16 | - transformers==4.28.1
17 | - accelerate==0.18.0
18 | - scikit-learn==1.2.2
19 | - sentencepiece==0.1.98
20 | - einops==0.7.0
21 | - einops-exts==0.0.4
22 | - timm==0.9.8
23 | - openai-clip==1.0.1
24 | - kornia==0.6.9
25 | - matplotlib==3.7.1
26 | - ninja==1.11.1
27 | - omegaconf==2.3.0
28 | - open-clip-torch==2.17.1
29 | - opencv-python==4.7.0.72
30 | - pandas==2.0.1
31 | - Pillow==9.4.0
32 | - pytorch-lightning==2.1.2
33 | - PyYAML==6.0
34 | - scipy==1.12.0
35 | - tqdm==4.65.0
36 | - triton==2.1.0
37 | - webdataset==0.2.48
38 | - xformers>=0.0.20
39 | run:
40 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
41 | predict: "predict.py:Predictor"
42 |
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | NUM_SECONDS_TO_SLEEP = 3
11 |
12 | @ray.remote(num_cpus=4)
13 | def get_eval(content: str, max_tokens: int):
14 | while True:
15 | try:
16 | response = openai.ChatCompletion.create(
17 | model='gpt-4',
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except openai.error.RateLimitError:
30 | pass
31 | except Exception as e:
32 | print(e)
33 | time.sleep(NUM_SECONDS_TO_SLEEP)
34 |
35 | print('success!')
36 | return response['choices'][0]['message']['content']
37 |
38 |
39 | def parse_score(review):
40 | try:
41 | score_pair = review.split('\n')[0]
42 | score_pair = score_pair.replace(',', ' ')
43 | sp = score_pair.split(' ')
44 | if len(sp) == 2:
45 | return [float(sp[0]), float(sp[1])]
46 | else:
47 | print('error', review)
48 | return [-1, -1]
49 | except Exception as e:
50 | print(e)
51 | print('error', review)
52 | return [-1, -1]
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57 | parser.add_argument('-q', '--question')
58 | # parser.add_argument('-a', '--answer')
59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60 | parser.add_argument('-r', '--rule')
61 | parser.add_argument('-o', '--output')
62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63 | args = parser.parse_args()
64 |
65 | ray.init()
66 |
67 | f_q = open(os.path.expanduser(args.question))
68 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71 |
72 | review_file = open(f'{args.output}', 'w')
73 |
74 | js_list = []
75 | handles = []
76 | idx = 0
77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78 | # if idx == 1:
79 | # break
80 |
81 | ques = json.loads(ques_js)
82 | ans1 = json.loads(ans1_js)
83 | ans2 = json.loads(ans2_js)
84 |
85 | category = json.loads(ques_js)['category']
86 | if category in rule_dict:
87 | rule = rule_dict[category]
88 | else:
89 | rule = rule_dict['default']
90 | prompt = rule['prompt']
91 | role = rule['role']
92 | content = (f'[Question]\n{ques["text"]}\n\n'
93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95 | f'[System]\n{prompt}\n\n')
96 | js_list.append({
97 | 'id': idx+1,
98 | 'question_id': ques['question_id'],
99 | 'answer1_id': ans1['answer_id'],
100 | 'answer2_id': ans2['answer_id'],
101 | 'category': category})
102 | idx += 1
103 | handles.append(get_eval.remote(content, args.max_tokens))
104 | # To avoid the rate limit set by OpenAI
105 | time.sleep(NUM_SECONDS_TO_SLEEP)
106 |
107 | reviews = ray.get(handles)
108 | for idx, review in enumerate(reviews):
109 | scores = parse_score(review)
110 | js_list[idx]['content'] = review
111 | js_list[idx]['tuple'] = scores
112 | review_file.write(json.dumps(js_list[idx]) + '\n')
113 | review_file.close()
114 |
--------------------------------------------------------------------------------
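parse_score above assumes the reviewer model puts the two scores on the first line of its reply; a small sketch with a made-up review string:

    from llava.eval.eval_gpt_review import parse_score  # importing this module requires openai and ray to be installed

    review = "8 7\nAssistant 1 provided a more detailed and accurate answer ..."
    scores = parse_score(review)   # -> [8.0, 7.0]; an unparsable first line yields [-1, -1]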
/llava/eval/eval_gpt_review_bench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 |
86 | if isinstance(inst['caption'], list):
87 | cap_str = '\n'.join(inst['caption'])
88 | else:
89 | cap_str = inst['caption']
90 |
91 | category = 'llava_bench_' + json.loads(ques_js)['category']
92 | if category in rule_dict:
93 | rule = rule_dict[category]
94 | else:
95 | assert False, f"Visual QA category not found in rule file: {category}."
96 | prompt = rule['prompt']
97 | role = rule['role']
98 | content = (f'[Context]\n{cap_str}\n\n'
99 | f'[Question]\n{ques["text"]}\n\n'
100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102 | f'[System]\n{prompt}\n\n')
103 | cur_js = {
104 | 'id': idx+1,
105 | 'question_id': ques['question_id'],
106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
108 | 'category': category
109 | }
110 | if idx >= len(cur_reviews):
111 | review = get_eval(content, args.max_tokens)
112 | scores = parse_score(review)
113 | cur_js['content'] = review
114 | cur_js['tuple'] = scores
115 | review_file.write(json.dumps(cur_js) + '\n')
116 | review_file.flush()
117 | else:
118 | print(f'Skipping {idx} as we already have it.')
119 | idx += 1
120 | print(idx)
121 | review_file.close()
122 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 | cap_str = '\n'.join(inst['captions'])
86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87 |
88 | category = json.loads(ques_js)['category']
89 | if category in rule_dict:
90 | rule = rule_dict[category]
91 | else:
92 | assert False, f"Visual QA category not found in rule file: {category}."
93 | prompt = rule['prompt']
94 | role = rule['role']
95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96 | f'[Question]\n{ques["text"]}\n\n'
97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99 | f'[System]\n{prompt}\n\n')
100 | cur_js = {
101 | 'id': idx+1,
102 | 'question_id': ques['question_id'],
103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
105 | 'category': category
106 | }
107 | if idx >= len(cur_reviews):
108 | review = get_eval(content, args.max_tokens)
109 | scores = parse_score(review)
110 | cur_js['content'] = review
111 | cur_js['tuple'] = scores
112 | review_file.write(json.dumps(cur_js) + '\n')
113 | review_file.flush()
114 | else:
115 | print(f'Skipping {idx} as we already have it.')
116 | idx += 1
117 | print(idx)
118 | review_file.close()
119 |
--------------------------------------------------------------------------------
/llava/eval/eval_pope.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | def eval_pope(answers, label_file):
6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7 |
8 | for answer in answers:
9 | text = answer['text']
10 |
11 | # Only keep the first sentence
12 | if text.find('.') != -1:
13 | text = text.split('.')[0]
14 |
15 | text = text.replace(',', '')
16 | words = text.split(' ')
17 | if 'No' in words or 'not' in words or 'no' in words:
18 | answer['text'] = 'no'
19 | else:
20 | answer['text'] = 'yes'
21 |
22 | for i in range(len(label_list)):
23 | if label_list[i] == 'no':
24 | label_list[i] = 0
25 | else:
26 | label_list[i] = 1
27 |
28 | pred_list = []
29 | for answer in answers:
30 | if answer['text'] == 'no':
31 | pred_list.append(0)
32 | else:
33 | pred_list.append(1)
34 |
35 | pos = 1
36 | neg = 0
37 | yes_ratio = pred_list.count(1) / len(pred_list)
38 |
39 | TP, TN, FP, FN = 0, 0, 0, 0
40 | for pred, label in zip(pred_list, label_list):
41 | if pred == pos and label == pos:
42 | TP += 1
43 | elif pred == pos and label == neg:
44 | FP += 1
45 | elif pred == neg and label == neg:
46 | TN += 1
47 | elif pred == neg and label == pos:
48 | FN += 1
49 |
50 | print('TP\tFP\tTN\tFN\t')
51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52 |
53 | precision = float(TP) / float(TP + FP)
54 | recall = float(TP) / float(TP + FN)
55 | f1 = 2*precision*recall / (precision + recall)
56 | acc = (TP + TN) / (TP + TN + FP + FN)
57 | print('Accuracy: {}'.format(acc))
58 | print('Precision: {}'.format(precision))
59 | print('Recall: {}'.format(recall))
60 | print('F1 score: {}'.format(f1))
61 | print('Yes ratio: {}'.format(yes_ratio))
62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--annotation-dir", type=str)
67 | parser.add_argument("--question-file", type=str)
68 | parser.add_argument("--result-file", type=str)
69 | args = parser.parse_args()
70 |
71 | questions = [json.loads(line) for line in open(args.question_file)]
72 | questions = {question['question_id']: question for question in questions}
73 | answers = [json.loads(q) for q in open(args.result_file)]
74 | for file in os.listdir(args.annotation_dir):
75 | assert file.startswith('coco_pope_')
76 | assert file.endswith('.json')
77 | category = file[10:-5]
78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81 | print("====================================")
82 |
--------------------------------------------------------------------------------
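For reference, the confusion-matrix arithmetic used by eval_pope, worked through on a made-up toy example (1 = "yes", 0 = "no"):

    pred_list = [1, 1, 0, 0, 1]
    label_list = [1, 0, 0, 1, 1]

    TP = sum(p == 1 and l == 1 for p, l in zip(pred_list, label_list))   # 2
    FP = sum(p == 1 and l == 0 for p, l in zip(pred_list, label_list))   # 1
    TN = sum(p == 0 and l == 0 for p, l in zip(pred_list, label_list))   # 1
    FN = sum(p == 0 and l == 1 for p, l in zip(pred_list, label_list))   # 1
    precision = TP / (TP + FP)                           # 2/3
    recall = TP / (TP + FN)                              # 2/3
    f1 = 2 * precision * recall / (precision + recall)   # 2/3
    acc = (TP + TN) / (TP + TN + FP + FN)                # 3/5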
/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--base-dir', type=str)
11 | parser.add_argument('--result-file', type=str)
12 | parser.add_argument('--output-file', type=str)
13 | parser.add_argument('--output-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return -1
36 | # return random.choice(range(len(choices)))  # unreachable: both branches above already return
37 |
38 |
39 | if __name__ == "__main__":
40 | args = get_args()
41 |
42 | base_dir = args.base_dir
43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
45 | predictions = [json.loads(line) for line in open(args.result_file)]
46 | predictions = {pred['question_id']: pred for pred in predictions}
47 | split_problems = {idx: problems[idx] for idx in split_indices}
48 |
49 | results = {'correct': [], 'incorrect': []}
50 | sqa_results = {}
51 | sqa_results['acc'] = None
52 | sqa_results['correct'] = None
53 | sqa_results['count'] = None
54 | sqa_results['results'] = {}
55 | sqa_results['outputs'] = {}
56 |
57 | for prob_id, prob in split_problems.items():
58 | if prob_id not in predictions:
59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60 | pred_text = 'FAILED'
61 | else:
62 | pred = predictions[prob_id]
63 | pred_text = pred['text']
64 |
65 | if pred_text in args.options:
66 | answer = pred_text
67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68 | answer = pred_text[0]
69 | else:
70 | pattern = re.compile(r'The answer is ([A-Z]).')
71 | res = pattern.findall(pred_text)
72 | if len(res) == 1:
73 | answer = res[0] # 'A', 'B', ...
74 | else:
75 | answer = "FAILED"
76 |
77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78 |
79 | analysis = {
80 | 'question_id': prob_id,
81 | 'parsed_ans': answer,
82 | 'ground_truth': args.options[prob['answer']],
83 | 'question': pred['prompt'],
84 | 'pred': pred_text,
85 | 'is_multimodal': '<image>' in pred['prompt'],
86 | }
87 |
88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89 | sqa_results['outputs'][prob_id] = pred_text
90 |
91 | if pred_idx == prob['answer']:
92 | results['correct'].append(analysis)
93 | else:
94 | results['incorrect'].append(analysis)
95 |
96 | correct = len(results['correct'])
97 | total = len(results['correct']) + len(results['incorrect'])
98 |
99 | ###### IMG ######
100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102 | multimodal_total = multimodal_correct + multimodal_incorrect
103 | ###### IMG ######
104 |
105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106 |
107 | sqa_results['acc'] = correct / total * 100
108 | sqa_results['correct'] = correct
109 | sqa_results['count'] = total
110 |
111 | with open(args.output_file, 'w') as f:
112 | json.dump(results, f, indent=2)
113 | with open(args.output_result, 'w') as f:
114 | json.dump(sqa_results, f, indent=2)
115 |
--------------------------------------------------------------------------------
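The answer-parsing logic above (also used by the two GPT-4 comparison scripts that follow) hinges on the literal phrase 'The answer is X.'; a small sketch with a made-up prediction string:

    import re

    pattern = re.compile(r'The answer is ([A-Z]).')
    res = pattern.findall('Both options seem plausible. The answer is B.')
    # res == ['B']; a prediction without that exact phrasing yields [] and is scored as "FAILED".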
/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4_requery.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--requery-result', type=str)
14 | parser.add_argument('--our-result', type=str)
15 | parser.add_argument('--output-result', type=str)
16 | parser.add_argument('--split', type=str, default='test')
17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18 | return parser.parse_args()
19 |
20 |
21 | def convert_caps(results):
22 | fakecaps = []
23 | for result in results:
24 | image_id = result['question_id']
25 | caption = result['text']
26 | fakecaps.append({"image_id": int(image_id), "caption": caption})
27 | return fakecaps
28 |
29 |
30 | def get_pred_idx(prediction, choices, options):
31 | """
32 | Get the index (e.g. 2) from the prediction (e.g. 'C')
33 | """
34 | if prediction in options[:len(choices)]:
35 | return options.index(prediction)
36 | else:
37 | return random.choice(range(len(choices)))
38 |
39 |
40 | if __name__ == "__main__":
41 | args = get_args()
42 |
43 | base_dir = args.base_dir
44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
46 | our_predictions = [json.loads(line) for line in open(args.our_result)]
47 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
48 | split_problems = {idx: problems[idx] for idx in split_indices}
49 |
50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52 |
53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54 |
55 | results = defaultdict(lambda: 0)
56 |
57 | sqa_results = {}
58 | sqa_results['acc'] = None
59 | sqa_results['correct'] = None
60 | sqa_results['count'] = None
61 | sqa_results['results'] = {}
62 | sqa_results['outputs'] = {}
63 |
64 | for prob_id, prob in split_problems.items():
65 | if prob_id not in our_predictions:
66 | assert False
67 | if prob_id not in gpt4_predictions:
68 | assert False
69 | our_pred = our_predictions[prob_id]['text']
70 | gpt4_pred = gpt4_predictions[prob_id]
71 | if prob_id not in requery_predictions:
72 | results['missing_requery'] += 1
73 | requery_pred = "MISSING"
74 | else:
75 | requery_pred = requery_predictions[prob_id]['text']
76 |
77 | pattern = re.compile(r'The answer is ([A-Z]).')
78 | our_res = pattern.findall(our_pred)
79 | if len(our_res) == 1:
80 | our_answer = our_res[0] # 'A', 'B', ...
81 | else:
82 | our_answer = "FAILED"
83 |
84 | requery_res = pattern.findall(requery_pred)
85 | if len(requery_res) == 1:
86 | requery_answer = requery_res[0] # 'A', 'B', ...
87 | else:
88 | requery_answer = "FAILED"
89 |
90 | gpt4_res = pattern.findall(gpt4_pred)
91 | if len(gpt4_res) == 1:
92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93 | else:
94 | gpt4_answer = "FAILED"
95 |
96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99 |
100 | results['total'] += 1
101 |
102 | if gpt4_answer == 'FAILED':
103 | results['gpt4_failed'] += 1
104 | if gpt4_pred_idx == prob['answer']:
105 | results['gpt4_correct'] += 1
106 | if our_pred_idx == prob['answer']:
107 | results['gpt4_ourvisual_correct'] += 1
108 | elif gpt4_pred_idx == prob['answer']:
109 | results['gpt4_correct'] += 1
110 | results['gpt4_ourvisual_correct'] += 1
111 |
112 | if our_pred_idx == prob['answer']:
113 | results['our_correct'] += 1
114 |
115 | if requery_answer == 'FAILED':
116 | sqa_results['results'][prob_id] = our_pred_idx
117 | if our_pred_idx == prob['answer']:
118 | results['requery_correct'] += 1
119 | else:
120 | sqa_results['results'][prob_id] = requery_pred_idx
121 | if requery_pred_idx == prob['answer']:
122 | results['requery_correct'] += 1
123 | else:
124 | print(f"""
125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126 | Our ({our_answer}): {our_pred}
127 | GPT-4 ({gpt4_answer}): {gpt4_pred}
128 | Requery ({requery_answer}): {requery_pred}
129 | print("=====================================")
130 | """)
131 |
132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133 | results['correct_upperbound'] += 1
134 |
135 | total = results['total']
136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142 |
143 | sqa_results['acc'] = results["requery_correct"] / total * 100
144 | sqa_results['correct'] = results["requery_correct"]
145 | sqa_results['count'] = total
146 |
147 | with open(args.output_result, 'w') as f:
148 | json.dump(sqa_results, f, indent=2)
149 |
150 |
--------------------------------------------------------------------------------
/llava/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import re
5 |
6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--annotation-file', type=str)
12 | parser.add_argument('--result-file', type=str)
13 | parser.add_argument('--result-dir', type=str)
14 | return parser.parse_args()
15 |
16 |
17 | def prompt_processor(prompt):
18 | if prompt.startswith('OCR tokens: '):
19 | pattern = r"Question: (.*?) Short answer:"
20 | match = re.search(pattern, prompt, re.DOTALL)
21 | question = match.group(1)
22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 | if prompt.startswith('Reference OCR token:'):
24 | question = prompt.split('\n')[1]
25 | else:
26 | question = prompt.split('\n')[0]
27 | elif len(prompt.split('\n')) == 2:
28 | question = prompt.split('\n')[0]
29 | else:
30 | assert False
31 |
32 | return question.lower()
33 |
34 |
35 | def eval_single(annotation_file, result_file):
36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 | print(experiment_name)
38 | annotations = json.load(open(annotation_file))['data']
39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 | results = [json.loads(line) for line in open(result_file)]
41 |
42 | pred_list = []
43 | for result in results:
44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 | pred_list.append({
46 | "pred_answer": result['text'],
47 | "gt_answers": annotation['answers'],
48 | })
49 |
50 | evaluator = TextVQAAccuracyEvaluator()
51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 |
53 |
54 | if __name__ == "__main__":
55 | args = get_args()
56 |
57 | if args.result_file is not None:
58 | eval_single(args.annotation_file, args.result_file)
59 |
60 | if args.result_dir is not None:
61 | for result_file in sorted(os.listdir(args.result_dir)):
62 | if not result_file.endswith('.jsonl'):
63 | print(f'Skipping {result_file}')
64 | continue
65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 |
--------------------------------------------------------------------------------
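prompt_processor above recovers the raw question from the three prompt templates it recognizes; a small sketch with a made-up OCR-style prompt:

    from llava.eval.eval_textvqa import prompt_processor  # requires the llava package on PYTHONPATH

    prompt = 'OCR tokens: STOP, 25 Question: What does the sign say? Short answer:'
    question = prompt_processor(prompt)   # -> 'what does the sign say?'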
/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
1 | """Generate json file for webpage."""
2 | import json
3 | import os
4 | import re
5 |
6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
7 | models = ['vicuna']
8 |
9 |
10 | def read_jsonl(path: str, key: str=None):
11 | data = []
12 | with open(os.path.expanduser(path)) as f:
13 | for line in f:
14 | if not line:
15 | continue
16 | data.append(json.loads(line))
17 | if key is not None:
18 | data.sort(key=lambda x: x[key])
19 | data = {item[key]: item for item in data}
20 | return data
21 |
22 |
23 | def trim_hanging_lines(s: str, n: int) -> str:
24 | s = s.strip()
25 | for _ in range(n):
26 | s = s.split('\n', 1)[1].strip()
27 | return s
28 |
29 |
30 | if __name__ == '__main__':
31 | questions = read_jsonl('table/question.jsonl', key='question_id')
32 |
33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39 |
40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45 |
46 | records = []
47 | for qid in questions.keys():
48 | r = {
49 | 'id': qid,
50 | 'category': questions[qid]['category'],
51 | 'question': questions[qid]['text'],
52 | 'answers': {
53 | # 'alpaca': alpaca_answers[qid]['text'],
54 | # 'llama': llama_answers[qid]['text'],
55 | # 'bard': bard_answers[qid]['text'],
56 | # 'gpt35': gpt35_answers[qid]['text'],
57 | 'vicuna': vicuna_answers[qid]['text'],
58 | 'ours': ours_answers[qid]['text'],
59 | },
60 | 'evaluations': {
61 | # 'alpaca': review_alpaca[qid]['text'],
62 | # 'llama': review_llama[qid]['text'],
63 | # 'bard': review_bard[qid]['text'],
64 | 'vicuna': review_vicuna[qid]['content'],
65 | # 'gpt35': review_gpt35[qid]['text'],
66 | },
67 | 'scores': {
68 | 'vicuna': review_vicuna[qid]['tuple'],
69 | # 'alpaca': review_alpaca[qid]['score'],
70 | # 'llama': review_llama[qid]['score'],
71 | # 'bard': review_bard[qid]['score'],
72 | # 'gpt35': review_gpt35[qid]['score'],
73 | },
74 | }
75 |
76 | # cleanup data
77 | cleaned_evals = {}
78 | for k, v in r['evaluations'].items():
79 | v = v.strip()
80 | lines = v.split('\n')
81 | # trim the first line if it's a pair of numbers
82 | if re.match(r'\d+[, ]+\d+', lines[0]):
83 | lines = lines[1:]
84 | v = '\n'.join(lines)
85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86 |
87 | r['evaluations'] = cleaned_evals
88 | records.append(r)
89 |
90 | # Reorder the records, this is optional
91 | for r in records:
92 | if r['id'] <= 20:
93 | r['id'] += 60
94 | else:
95 | r['id'] -= 20
96 | for r in records:
97 | if r['id'] <= 50:
98 | r['id'] += 10
99 | elif 50 < r['id'] <= 60:
100 | r['id'] -= 50
101 | for r in records:
102 | if r['id'] == 7:
103 | r['id'] = 1
104 | elif r['id'] < 7:
105 | r['id'] += 1
106 |
107 | records.sort(key=lambda x: x['id'])
108 |
109 | # Write to file
110 | with open('webpage/data.json', 'w') as f:
111 | json.dump({'questions': records, 'models': models}, f, indent=2)
112 |
--------------------------------------------------------------------------------
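read_jsonl() above is the general-purpose piece of this script; a small sketch of its two return shapes on a hypothetical two-line file:

    from llava.eval.generate_webpage_data_from_table import read_jsonl

    # demo.jsonl (hypothetical):
    #   {"question_id": 2, "text": "b"}
    #   {"question_id": 1, "text": "a"}
    rows  = read_jsonl('demo.jsonl')                      # list of dicts, in file order
    by_id = read_jsonl('demo.jsonl', key='question_id')   # {1: {...}, 2: {...}}, sorted by the key
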
/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3 | import torch
4 | import os
5 | import json
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 |
12 |
13 | # new stopping implementation
14 | class KeywordsStoppingCriteria(StoppingCriteria):
15 | def __init__(self, keywords, tokenizer, input_ids):
16 | self.keywords = keywords
17 | self.tokenizer = tokenizer
18 | self.start_len = None
19 | self.input_ids = input_ids
20 |
21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
22 | if self.start_len is None:
23 | self.start_len = self.input_ids.shape[1]
24 | else:
25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
26 | for keyword in self.keywords:
27 | if keyword in outputs:
28 | return True
29 | return False
30 |
31 |
32 | @torch.inference_mode()
33 | def eval_model(model_name, questions_file, answers_file):
34 | # Model
35 | disable_torch_init()
36 | model_name = os.path.expanduser(model_name)
37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
38 | model = AutoModelForCausalLM.from_pretrained(model_name,
39 | torch_dtype=torch.float16).cuda()
40 |
41 |
42 | ques_file = open(os.path.expanduser(questions_file), "r")
43 | ans_file = open(os.path.expanduser(answers_file), "w")
44 | for i, line in enumerate(tqdm(ques_file)):
45 | idx = json.loads(line)["question_id"]
46 | qs = json.loads(line)["text"]
47 | cat = json.loads(line)["category"]
48 | conv = default_conversation.copy()
49 | conv.append_message(conv.roles[0], qs)
50 | prompt = conv.get_prompt()
51 | inputs = tokenizer([prompt])
52 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
54 | output_ids = model.generate(
55 | input_ids,
56 | do_sample=True,
57 | use_cache=True,
58 | temperature=0.7,
59 | max_new_tokens=1024,
60 | stopping_criteria=[stopping_criteria])
61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
62 | try:
63 | index = outputs.index(conv.sep, len(prompt))
64 | except ValueError:
65 | outputs += conv.sep
66 | index = outputs.index(conv.sep, len(prompt))
67 |
68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
69 | ans_id = shortuuid.uuid()
70 | ans_file.write(json.dumps({"question_id": idx,
71 | "text": outputs,
72 | "answer_id": ans_id,
73 | "model_id": model_name,
74 | "metadata": {}}) + "\n")
75 | ans_file.flush()
76 | ans_file.close()
77 |
78 | if __name__ == "__main__":
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
83 | args = parser.parse_args()
84 |
85 | eval_model(args.model_name, args.question_file, args.answers_file)
86 |
--------------------------------------------------------------------------------
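model_qa.py can also be driven programmatically; a hedged sketch with a placeholder checkpoint name (any causal LM loadable by AutoModelForCausalLM should work) and a question file providing the question_id / text / category fields read in the loop.

    from llava.eval.model_qa import eval_model

    eval_model(model_name='lmsys/vicuna-13b-v1.3',      # placeholder: any HF causal-LM checkpoint
               questions_file='table/question.jsonl',   # one {"question_id", "text", "category"} per line
               answers_file='answer.jsonl')
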
/llava/eval/model_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so the list splits into n (roughly) equal chunks
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def eval_model(args):
30 | # Model
31 | disable_torch_init()
32 | model_path = os.path.expanduser(args.model_path)
33 | model_name = get_model_name_from_path(model_path)
34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35 |
36 | meta_pth = '/opt/data/private/metas/unsplash_ISO300-_PIL_1024_x2x4_APEX.txt'
37 | img_pths = []
38 | with open(meta_pth, 'r') as f:
39 | for line in f.readlines():
40 | img_pths.append(line.split('\t')[0])
41 | f.close()
42 |
43 | img_pths = get_chunk(img_pths, args.num_chunks, args.chunk_idx)
44 |
45 |     # split into 8 roughly equal chunks
46 | img_pths = split_list(img_pths, 8)
47 |
48 |
49 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
50 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
51 | answers_file = os.path.expanduser(args.answers_file)
52 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
53 | ans_file = open(answers_file, "w")
54 | for line in tqdm(questions):
55 | idx = line["question_id"]
56 | image_file = line["image"]
57 | qs = line["text"]
58 | cur_prompt = qs
59 | if model.config.mm_use_im_start_end:
60 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
61 | else:
62 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
63 |
64 | conv = conv_templates[args.conv_mode].copy()
65 | conv.append_message(conv.roles[0], qs)
66 | conv.append_message(conv.roles[1], None)
67 | prompt = conv.get_prompt()
68 |
69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
70 |
71 | image = Image.open(os.path.join(args.image_folder, image_file))
72 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
73 |
74 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
75 | keywords = [stop_str]
76 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
77 |
78 | with torch.inference_mode():
79 | output_ids = model.generate(
80 | input_ids,
81 | images=image_tensor.unsqueeze(0).half().cuda(),
82 | do_sample=True if args.temperature > 0 else False,
83 | temperature=args.temperature,
84 | top_p=args.top_p,
85 | num_beams=args.num_beams,
86 | # no_repeat_ngram_size=3,
87 | max_new_tokens=1024,
88 | use_cache=True)
89 |
90 | input_token_len = input_ids.shape[1]
91 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
92 | if n_diff_input_output > 0:
93 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
94 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
95 | outputs = outputs.strip()
96 | if outputs.endswith(stop_str):
97 | outputs = outputs[:-len(stop_str)]
98 | outputs = outputs.strip()
99 |
100 | ans_id = shortuuid.uuid()
101 | ans_file.write(json.dumps({"question_id": idx,
102 | "prompt": cur_prompt,
103 | "text": outputs,
104 | "answer_id": ans_id,
105 | "model_id": model_name,
106 | "metadata": {}}) + "\n")
107 | ans_file.flush()
108 | ans_file.close()
109 |
110 | if __name__ == "__main__":
111 | parser = argparse.ArgumentParser()
112 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
113 | parser.add_argument("--model-base", type=str, default=None)
114 | parser.add_argument("--image-folder", type=str, default="")
115 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
116 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
117 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
118 | parser.add_argument("--num-chunks", type=int, default=1)
119 | parser.add_argument("--chunk-idx", type=int, default=0)
120 | parser.add_argument("--temperature", type=float, default=0.2)
121 | parser.add_argument("--top_p", type=float, default=None)
122 | parser.add_argument("--num_beams", type=int, default=1)
123 | args = parser.parse_args()
124 |
125 | eval_model(args)
126 |
--------------------------------------------------------------------------------
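eval_model() in model_vqa.py consumes the whole argparse namespace, so a programmatic call just builds one; the checkpoint and folders below are placeholders. Note that the function also reads the hard-coded meta_pth file near the top, so that path must exist (or be edited out) before the call succeeds.

    import argparse
    from llava.eval.model_vqa import eval_model

    eval_model(argparse.Namespace(
        model_path='liuhaotian/llava-v1.5-13b',          # placeholder LLaVA checkpoint
        model_base=None, image_folder='images/',
        question_file='table/question.jsonl', answers_file='answers/llava.jsonl',
        conv_mode='llava_v1', num_chunks=1, chunk_idx=0,
        temperature=0.2, top_p=None, num_beams=1))
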
/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import json
5 | import os
6 | import time
7 | import concurrent.futures
8 |
9 | import openai
10 | import tqdm
11 | import shortuuid
12 |
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 |
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 | ans = {
18 | 'answer_id': shortuuid.uuid(),
19 | 'question_id': question_id,
20 | 'model_id': MODEL_ID,
21 | }
22 | for _ in range(3):
23 | try:
24 | response = openai.ChatCompletion.create(
25 | model=MODEL,
26 | messages=[{
27 | 'role': 'system',
28 | 'content': 'You are a helpful assistant.'
29 | }, {
30 | 'role': 'user',
31 | 'content': question,
32 | }],
33 | max_tokens=max_tokens,
34 | )
35 | ans['text'] = response['choices'][0]['message']['content']
36 | return ans
37 | except Exception as e:
38 | print('[ERROR]', e)
39 | ans['text'] = '#ERROR#'
40 | time.sleep(1)
41 | return ans
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 | parser.add_argument('-q', '--question')
47 | parser.add_argument('-o', '--output')
48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 | args = parser.parse_args()
50 |
51 | questions_dict = {}
52 | with open(os.path.expanduser(args.question)) as f:
53 | for line in f:
54 | if not line:
55 | continue
56 | q = json.loads(line)
57 | questions_dict[q['question_id']] = q['text']
58 |
59 | answers = []
60 |
61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 | futures = []
63 | for qid, question in questions_dict.items():
64 | future = executor.submit(get_answer, qid, question, args.max_tokens)
65 | futures.append(future)
66 |
67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 | answers.append(future.result())
69 |
70 | answers.sort(key=lambda x: x['question_id'])
71 |
72 | with open(os.path.expanduser(args.output), 'w') as f:
73 | table = [json.dumps(ans) for ans in answers]
74 | f.write('\n'.join(table))
75 |
--------------------------------------------------------------------------------
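get_answer() from qa_baseline_gpt35.py can be exercised on its own; this sketch assumes the legacy openai 0.27 SDK mentioned in the header comment and an API key in the environment.

    import os
    import openai
    from llava.eval.qa_baseline_gpt35 import get_answer

    openai.api_key = os.environ['OPENAI_API_KEY']
    ans = get_answer(question_id=1, question='What is the capital of France?', max_tokens=64)
    print(ans['text'])
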
/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9 |
10 | from PIL import Image
11 |
12 | import requests
13 | from PIL import Image
14 | from io import BytesIO
15 |
16 |
17 | def load_image(image_file):
18 | if image_file.startswith('http') or image_file.startswith('https'):
19 | response = requests.get(image_file)
20 | image = Image.open(BytesIO(response.content)).convert('RGB')
21 | else:
22 | image = Image.open(image_file).convert('RGB')
23 | return image
24 |
25 |
26 | def eval_model(args):
27 | # Model
28 | disable_torch_init()
29 |
30 | model_name = get_model_name_from_path(args.model_path)
31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
32 |
33 | qs = args.query
34 | if model.config.mm_use_im_start_end:
35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
36 | else:
37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
38 |
39 | if 'llama-2' in model_name.lower():
40 | conv_mode = "llava_llama_2"
41 | elif "v1" in model_name.lower():
42 | conv_mode = "llava_v1"
43 | elif "mpt" in model_name.lower():
44 | conv_mode = "mpt"
45 | else:
46 | conv_mode = "llava_v0"
47 |
48 | if args.conv_mode is not None and conv_mode != args.conv_mode:
49 |         print('[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
50 | else:
51 | args.conv_mode = conv_mode
52 |
53 | conv = conv_templates[args.conv_mode].copy()
54 | conv.append_message(conv.roles[0], qs)
55 | conv.append_message(conv.roles[1], None)
56 | prompt = conv.get_prompt()
57 |
58 | image = load_image(args.image_file)
59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
60 |
61 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
62 |
63 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
64 | keywords = [stop_str]
65 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
66 |
67 | with torch.inference_mode():
68 | output_ids = model.generate(
69 | input_ids,
70 | images=image_tensor,
71 | do_sample=True,
72 | temperature=0.2,
73 | max_new_tokens=1024,
74 | use_cache=True,
75 | stopping_criteria=[stopping_criteria])
76 |
77 | input_token_len = input_ids.shape[1]
78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
79 | if n_diff_input_output > 0:
80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
82 | outputs = outputs.strip()
83 | if outputs.endswith(stop_str):
84 | outputs = outputs[:-len(stop_str)]
85 | outputs = outputs.strip()
86 | print(outputs)
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
91 | parser.add_argument("--model-base", type=str, default=None)
92 | parser.add_argument("--image-file", type=str, required=True)
93 | parser.add_argument("--query", type=str, required=True)
94 | parser.add_argument("--conv-mode", type=str, default=None)
95 | args = parser.parse_args()
96 |
97 | eval_model(args)
98 |
--------------------------------------------------------------------------------
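run_llava.py handles a single image and query; a programmatic sketch with placeholder checkpoint and image paths. Leaving conv_mode=None lets the model name pick the conversation template, as in the branch above.

    import argparse
    from llava.eval.run_llava import eval_model

    eval_model(argparse.Namespace(
        model_path='liuhaotian/llava-v1.5-13b',   # placeholder checkpoint
        model_base=None,
        image_file='example.jpg',                 # local path or http(s) URL, per load_image()
        query='Describe this image in detail.',
        conv_mode=None))
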
/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 | import argparse
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11 | parser.add_argument('-d', '--dir', default=None)
12 | parser.add_argument('-v', '--version', default=None)
13 | parser.add_argument('-s', '--select', nargs='*', default=None)
14 | parser.add_argument('-f', '--files', nargs='*', default=[])
15 | parser.add_argument('-i', '--ignore', nargs='*', default=[])
16 | return parser.parse_args()
17 |
18 |
19 | if __name__ == '__main__':
20 | args = parse_args()
21 |
22 | if args.ignore is not None:
23 | args.ignore = [int(x) for x in args.ignore]
24 |
25 | if len(args.files) > 0:
26 | review_files = args.files
27 | else:
28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29 |
30 | for review_file in sorted(review_files):
31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32 | if args.select is not None and any(x not in config for x in args.select):
33 | continue
34 | if '0613' in config:
35 | version = '0613'
36 | else:
37 | version = '0314'
38 | if args.version is not None and args.version != version:
39 | continue
40 | scores = defaultdict(list)
41 | print(config)
42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43 | for review_str in f:
44 | review = json.loads(review_str)
45 | if review['question_id'] in args.ignore:
46 | continue
47 | if 'category' in review:
48 | scores[review['category']].append(review['tuple'])
49 | scores['all'].append(review['tuple'])
50 | else:
51 | if 'tuple' in review:
52 | scores['all'].append(review['tuple'])
53 | else:
54 | scores['all'].append(review['score'])
55 | for k, v in sorted(scores.items()):
56 | stats = np.asarray(v).mean(0).tolist()
57 | stats = [round(x, 3) for x in stats]
58 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60 | print('=================================')
61 |
--------------------------------------------------------------------------------
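For reference, the review-line shapes summarize_gpt_review.py aggregates, and how one summary line is derived; by convention Assistant 1 is the reference answer and Assistant 2 the evaluated model, though the script itself does not assume that.

    # Hypothetical review lines (the 'score' field is an older name for the same score pair):
    #   {"question_id": 1, "category": "detail", "tuple": [7.0, 8.5]}
    #   {"question_id": 2, "score": [6.0, 6.0]}
    #
    # With category-averaged scores (7.0, 8.5), the printed line is:
    #   round(8.5 / 7.0 * 100, 1), round(7.0 * 10, 1), round(8.5 * 10, 1)  ->  121.4  70.0  85.0
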
/llava/eval/table/model.jsonl:
--------------------------------------------------------------------------------
1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
6 |
--------------------------------------------------------------------------------
/llava/eval/table/prompt.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"}
2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"}
3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"}
4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"}
5 |
--------------------------------------------------------------------------------
/llava/eval/table/reviewer.jsonl:
--------------------------------------------------------------------------------
1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for visual questions"}
5 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/alpaca.png
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/bard.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/llama.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/vicuna.jpeg
--------------------------------------------------------------------------------
/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | }
5 |
6 | .navbar-dark .navbar-nav .nav-link {
7 | color: #f1cf68;
8 | font-size: 1.1rem;
9 | padding: 0.5rem 0.6rem;
10 | }
11 |
12 | .card-header {
13 | font-weight: bold;
14 | }
15 |
16 | .card {
17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18 | transition: 0.3s;
19 | }
20 |
21 | .card:hover {
22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23 | }
24 |
25 | button {
26 | transition: background-color 0.3s;
27 | }
28 |
29 | button:hover {
30 | background-color: #007bff;
31 | }
32 |
33 | @media (max-width: 767px) {
34 | .form-row .form-group {
35 | margin-bottom: 10px;
36 | }
37 | }
38 |
39 | /* Extra styles */
40 |
41 | .expandable-card .card-text-container {
42 | max-height: 200px;
43 | overflow-y: hidden;
44 | position: relative;
45 | }
46 |
47 | .expandable-card.expanded .card-text-container {
48 | max-height: none;
49 | }
50 |
51 | .expand-btn {
52 | position: relative;
53 | display: none;
54 | background-color: rgba(255, 255, 255, 0.8);
55 | color: #510c75;
56 | border-color: transparent;
57 | }
58 |
59 | .expand-btn:hover {
60 | background-color: rgba(200, 200, 200, 0.8);
61 | text-decoration: none;
62 | border-color: transparent;
63 | color: #510c75;
64 | }
65 |
66 | .expand-btn:focus {
67 | outline: none;
68 | text-decoration: none;
69 | }
70 |
71 | .expandable-card:not(.expanded) .card-text-container:after {
72 | content: "";
73 | position: absolute;
74 | bottom: 0;
75 | left: 0;
76 | width: 100%;
77 | height: 90px;
78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79 | }
80 |
81 | .expandable-card:not(.expanded) .expand-btn {
82 | margin-top: -40px;
83 | }
84 |
85 | .card-body {
86 | padding-bottom: 5px;
87 | }
88 |
89 | .vertical-flex-layout {
90 | justify-content: center;
91 | align-items: center;
92 | height: 100%;
93 | display: flex;
94 | flex-direction: column;
95 | gap: 5px;
96 | }
97 |
98 | .figure-img {
99 | max-width: 100%;
100 | height: auto;
101 | }
102 |
103 | .adjustable-font-size {
104 | font-size: calc(0.5rem + 2vw);
105 | }
106 |
--------------------------------------------------------------------------------
/llava/llava_agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import json
4 | from tqdm import tqdm
5 |
6 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
7 | from llava.conversation import conv_templates, SeparatorStyle
8 | from llava.model.builder import load_pretrained_model
9 | from llava.utils import disable_torch_init
10 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
11 |
12 | from PIL import Image
13 | import math
14 | import time
15 | import glob as gb
16 |
17 |
18 | class LLavaAgent:
19 | def __init__(self, model_path, device='cuda', conv_mode='vicuna_v1'):
20 | self.device = device
21 | if torch.device(self.device).index is not None:
22 | device_map = {'model': torch.device(self.device).index, 'lm_head': torch.device(self.device).index}
23 | else:
24 | device_map = 'auto'
25 | model_path = os.path.expanduser(model_path)
26 | model_name = get_model_name_from_path(model_path)
27 | tokenizer, model, image_processor, context_len = load_pretrained_model(
28 | model_path, None, model_name, device=self.device, device_map=device_map)
29 | self.model = model
30 | self.image_processor = image_processor
31 | self.tokenizer = tokenizer
32 | self.context_len = context_len
33 | self.qs = 'Describe this image and its style in a very detailed manner.'
34 | self.conv_mode = conv_mode
35 |
36 | if self.model.config.mm_use_im_start_end:
37 | self.qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + self.qs
38 | else:
39 | self.qs = DEFAULT_IMAGE_TOKEN + '\n' + self.qs
40 |
41 | self.conv = conv_templates[self.conv_mode].copy()
42 | self.conv.append_message(self.conv.roles[0], self.qs)
43 | self.conv.append_message(self.conv.roles[1], None)
44 | prompt = self.conv.get_prompt()
45 | self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
46 | 0).to(self.device)
47 |
48 | def update_qs(self, qs=None):
49 | if qs is None:
50 | qs = self.qs
51 | else:
52 | if self.model.config.mm_use_im_start_end:
53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
54 | else:
55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
56 |
57 | self.conv = conv_templates[self.conv_mode].copy()
58 | self.conv.append_message(self.conv.roles[0], qs)
59 | self.conv.append_message(self.conv.roles[1], None)
60 | prompt = self.conv.get_prompt()
61 | self.input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
62 | 0).to(self.device)
63 |
64 | def gen_image_caption(self, imgs, temperature=0.2, top_p=0.7, num_beams=1, qs=None):
65 | '''
66 | [PIL.Image, ...]
67 | '''
68 | self.update_qs(qs)
69 |
70 | bs = len(imgs)
71 | input_ids = self.input_ids.repeat(bs, 1)
72 | img_tensor_list = []
73 | for image in imgs:
74 | _image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
75 | img_tensor_list.append(_image_tensor)
76 | image_tensor = torch.stack(img_tensor_list, dim=0).half().to(self.device)
77 | stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
78 |
79 | with torch.inference_mode():
80 | output_ids = self.model.generate(
81 | input_ids,
82 | images=image_tensor,
83 | do_sample=True if temperature > 0 else False,
84 | temperature=temperature,
85 | top_p=top_p,
86 | num_beams=num_beams,
87 | # no_repeat_ngram_size=3,
88 | max_new_tokens=512,
89 | use_cache=True)
90 |
91 | input_token_len = input_ids.shape[1]
92 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
93 |
94 | img_captions = []
95 | for output in outputs:
96 | output = output.strip()
97 | if output.endswith(stop_str):
98 | output = output[:-len(stop_str)]
99 | output = output.strip().replace('\n', ' ').replace('\r', ' ')
100 | img_captions.append(output)
101 | return img_captions
102 |
103 |
104 | if __name__ == '__main__':
105 | llava_agent = LLavaAgent("/opt/data/private/AIGC_pretrain/LLaVA1.5/llava-v1.5-13b")
106 | img = [Image.open('/opt/data/private/LV_Dataset/DiffGLV-Test-All/RealPhoto60/LQ/02.png')]
107 | caption = llava_agent.gen_image_caption(img)
108 |
--------------------------------------------------------------------------------
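A slightly fuller sketch of LLavaAgent than the __main__ block above: batched captioning with a custom question; the checkpoint and image paths are placeholders.

    from PIL import Image
    from llava.llava_agent import LLavaAgent

    agent = LLavaAgent('/path/to/llava-v1.5-13b', device='cuda:0')        # placeholder checkpoint
    imgs = [Image.open(p).convert('RGB') for p in ['a.png', 'b.png']]     # placeholder images
    captions = agent.gen_image_caption(imgs, temperature=0.2, top_p=0.7,
                                       qs='Describe the main subject in one sentence.')
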
/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from llava.constants import IMAGE_TOKEN_INDEX
8 |
9 |
10 | def load_image_from_base64(image):
11 | return Image.open(BytesIO(base64.b64decode(image)))
12 |
13 |
14 | def expand2square(pil_img, background_color):
15 | width, height = pil_img.size
16 | if width == height:
17 | return pil_img
18 | elif width > height:
19 | result = Image.new(pil_img.mode, (width, width), background_color)
20 | result.paste(pil_img, (0, (width - height) // 2))
21 | return result
22 | else:
23 | result = Image.new(pil_img.mode, (height, height), background_color)
24 | result.paste(pil_img, ((height - width) // 2, 0))
25 | return result
26 |
27 |
28 | def process_images(images, image_processor, model_cfg):
29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30 | new_images = []
31 | if image_aspect_ratio == 'pad':
32 | for image in images:
33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
35 | new_images.append(image)
36 | else:
37 | return image_processor(images, return_tensors='pt')['pixel_values']
38 | if all(x.shape == new_images[0].shape for x in new_images):
39 | new_images = torch.stack(new_images, dim=0)
40 | return new_images
41 |
42 |
43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
44 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
45 |
46 | def insert_separator(X, sep):
47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
48 |
49 | input_ids = []
50 | offset = 0
51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
52 | offset = 1
53 | input_ids.append(prompt_chunks[0][0])
54 |
55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
56 | input_ids.extend(x[offset:])
57 |
58 | if return_tensors is not None:
59 | if return_tensors == 'pt':
60 | return torch.tensor(input_ids, dtype=torch.long)
61 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
62 | return input_ids
63 |
64 |
65 | def get_model_name_from_path(model_path):
66 | model_path = model_path.strip("/")
67 | model_paths = model_path.split("/")
68 | if model_paths[-1].startswith('checkpoint-'):
69 | return model_paths[-2] + "_" + model_paths[-1]
70 | else:
71 | return model_paths[-1]
72 |
73 |
74 |
75 |
76 | class KeywordsStoppingCriteria(StoppingCriteria):
77 | def __init__(self, keywords, tokenizer, input_ids):
78 | self.keywords = keywords
79 | self.keyword_ids = []
80 | self.max_keyword_len = 0
81 | for keyword in keywords:
82 | cur_keyword_ids = tokenizer(keyword).input_ids
83 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
84 | cur_keyword_ids = cur_keyword_ids[1:]
85 | if len(cur_keyword_ids) > self.max_keyword_len:
86 | self.max_keyword_len = len(cur_keyword_ids)
87 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
88 | self.tokenizer = tokenizer
89 | self.start_len = input_ids.shape[1]
90 |
91 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
92 |         assert output_ids.shape[0] == 1, "Only batch size 1 is supported (for now)"  # TODO
93 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
94 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
95 | for keyword_id in self.keyword_ids:
96 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
97 | return True
98 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
99 | for keyword in self.keywords:
100 | if keyword in outputs:
101 | return True
102 | return False
--------------------------------------------------------------------------------
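A sketch of how the helpers above are combined elsewhere in this repo (e.g. run_llava.py): the prompt carries the literal '<image>' placeholder that tokenizer_image_token() splits on, and expand2square() pads with the processor's mean colour. tokenizer, image_processor and pil_img are assumed to come from load_pretrained_model() and PIL.

    from llava.constants import IMAGE_TOKEN_INDEX
    from llava.mm_utils import tokenizer_image_token, expand2square

    prompt = 'USER: <image>\nDescribe this photo. ASSISTANT:'
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                      return_tensors='pt').unsqueeze(0)
    padded = expand2square(pil_img, tuple(int(x * 255) for x in image_processor.image_mean))
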
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
3 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn import CrossEntropyLoss
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | LlamaConfig, LlamaModel, LlamaForCausalLM
24 |
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 |
27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaConfig(LlamaConfig):
31 | model_type = "llava"
32 |
33 |
34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35 | config_class = LlavaConfig
36 |
37 | def __init__(self, config: LlamaConfig):
38 | super(LlavaLlamaModel, self).__init__(config)
39 |
40 |
41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42 | config_class = LlavaConfig
43 |
44 | def __init__(self, config):
45 | super(LlamaForCausalLM, self).__init__(config)
46 | self.model = LlavaLlamaModel(config)
47 |
48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.model
55 |
56 | def forward(
57 | self,
58 | input_ids: torch.LongTensor = None,
59 | attention_mask: Optional[torch.Tensor] = None,
60 | past_key_values: Optional[List[torch.FloatTensor]] = None,
61 | inputs_embeds: Optional[torch.FloatTensor] = None,
62 | labels: Optional[torch.LongTensor] = None,
63 | use_cache: Optional[bool] = None,
64 | output_attentions: Optional[bool] = None,
65 | output_hidden_states: Optional[bool] = None,
66 | images: Optional[torch.FloatTensor] = None,
67 | return_dict: Optional[bool] = None,
68 | ) -> Union[Tuple, CausalLMOutputWithPast]:
69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70 | output_hidden_states = (
71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72 | )
73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74 |
75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
76 |
77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
78 | outputs = self.model(
79 | input_ids=input_ids,
80 | attention_mask=attention_mask,
81 | past_key_values=past_key_values,
82 | inputs_embeds=inputs_embeds,
83 | use_cache=use_cache,
84 | output_attentions=output_attentions,
85 | output_hidden_states=output_hidden_states,
86 | return_dict=return_dict
87 | )
88 |
89 | hidden_states = outputs[0]
90 | logits = self.lm_head(hidden_states)
91 |
92 | loss = None
93 | if labels is not None:
94 | # Shift so that tokens < n predict n
95 | shift_logits = logits[..., :-1, :].contiguous()
96 | shift_labels = labels[..., 1:].contiguous()
97 | # Flatten the tokens
98 | loss_fct = CrossEntropyLoss()
99 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
100 | shift_labels = shift_labels.view(-1)
101 | # Enable model/pipeline parallelism
102 | shift_labels = shift_labels.to(shift_logits.device)
103 | loss = loss_fct(shift_logits, shift_labels)
104 |
105 | if not return_dict:
106 | output = (logits,) + outputs[1:]
107 | return (loss,) + output if loss is not None else output
108 |
109 | return CausalLMOutputWithPast(
110 | loss=loss,
111 | logits=logits,
112 | past_key_values=outputs.past_key_values,
113 | hidden_states=outputs.hidden_states,
114 | attentions=outputs.attentions,
115 | )
116 |
117 | def prepare_inputs_for_generation(
118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
119 | ):
120 | if past_key_values:
121 | input_ids = input_ids[:, -1:]
122 |
123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
124 | if inputs_embeds is not None and past_key_values is None:
125 | model_inputs = {"inputs_embeds": inputs_embeds}
126 | else:
127 | model_inputs = {"input_ids": input_ids}
128 |
129 | model_inputs.update(
130 | {
131 | "past_key_values": past_key_values,
132 | "use_cache": kwargs.get("use_cache"),
133 | "attention_mask": attention_mask,
134 | "images": kwargs.get("images", None),
135 | }
136 | )
137 | return model_inputs
138 |
139 | AutoConfig.register("llava", LlavaConfig)
140 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
141 |
--------------------------------------------------------------------------------
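Because of the two register() calls at the bottom of llava_llama.py, a checkpoint whose config declares model_type == "llava" loads through the transformers Auto classes once this module has been imported; the path below is a placeholder.

    from transformers import AutoConfig, AutoModelForCausalLM
    import llava.model.language_model.llava_llama   # noqa: F401 -- runs the register() calls above

    cfg = AutoConfig.from_pretrained('/path/to/llava-checkpoint')              # placeholder
    assert cfg.model_type == 'llava'
    model = AutoModelForCausalLM.from_pretrained('/path/to/llava-checkpoint')  # placeholder
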
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple
17 | import warnings
18 |
19 | import torch
20 | import torch.nn.functional as F
21 | import math
22 |
23 | from transformers import AutoConfig, AutoModelForCausalLM
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 |
26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaMPTConfig(MPTConfig):
31 | model_type = "llava_mpt"
32 |
33 |
34 | class LlavaMPTModel(LlavaMetaModel, MPTModel):
35 | config_class = LlavaMPTConfig
36 |
37 | def __init__(self, config: MPTConfig):
38 | config.hidden_size = config.d_model
39 | super(LlavaMPTModel, self).__init__(config)
40 |
41 | def embed_tokens(self, x):
42 | return self.wte(x)
43 |
44 |
45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
46 | config_class = LlavaMPTConfig
47 | supports_gradient_checkpointing = True
48 |
49 | def __init__(self, config):
50 | super(MPTForCausalLM, self).__init__(config)
51 |
52 | if not config.tie_word_embeddings:
53 | raise ValueError('MPTForCausalLM only supports tied word embeddings')
54 | self.transformer = LlavaMPTModel(config)
55 | self.logit_scale = None
56 | if config.logit_scale is not None:
57 | logit_scale = config.logit_scale
58 | if isinstance(logit_scale, str):
59 | if logit_scale == 'inv_sqrt_d_model':
60 | logit_scale = 1 / math.sqrt(config.d_model)
61 | else:
62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
63 | self.logit_scale = logit_scale
64 |
65 | def get_model(self):
66 | return self.transformer
67 |
68 | def _set_gradient_checkpointing(self, module, value=False):
69 | if isinstance(module, LlavaMPTModel):
70 | module.gradient_checkpointing = value
71 |
72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
73 | return_dict = return_dict if return_dict is not None else self.config.return_dict
74 | use_cache = use_cache if use_cache is not None else self.config.use_cache
75 |
76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
80 | if self.logit_scale is not None:
81 | if self.logit_scale == 0:
82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
83 | logits *= self.logit_scale
84 | loss = None
85 | if labels is not None:
86 | labels = torch.roll(labels, shifts=-1)
87 | labels[:, -1] = -100
88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
90 |
91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
92 | if inputs_embeds is not None:
93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
94 | attention_mask = kwargs['attention_mask'].bool()
95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]:
96 | raise NotImplementedError('MPT does not support generation with right padding.')
97 | if self.transformer.attn_uses_sequence_id and self.training:
98 | sequence_id = torch.zeros_like(input_ids[:1])
99 | else:
100 | sequence_id = None
101 | if past_key_values is not None:
102 | input_ids = input_ids[:, -1].unsqueeze(-1)
103 | if self.transformer.prefix_lm:
104 | prefix_mask = torch.ones_like(attention_mask)
105 | if kwargs.get('use_cache') == False:
106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
107 | else:
108 | prefix_mask = None
109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
110 |
111 |
112 | AutoConfig.register("llava_mpt", LlavaMPTConfig)
113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
114 |
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/adapt_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4 | NUM_SENTINEL_TOKENS: int = 100
5 |
6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7 | """Adds sentinel tokens and padding token (if missing).
8 |
9 | Expands the tokenizer vocabulary to include sentinel tokens
10 | used in mixture-of-denoiser tasks as well as a padding token.
11 |
12 | All added tokens are added as special tokens. No tokens are
13 | added if sentinel tokens and padding token already exist.
14 | """
15 |     sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17 | if tokenizer.pad_token is None:
18 |         tokenizer.add_tokens('<pad>', special_tokens=True)
19 |         tokenizer.pad_token = '<pad>'
20 | assert tokenizer.pad_token_id is not None
21 |     sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23 | tokenizer.sentinel_token_ids = _sentinel_token_ids
24 |
25 | class AutoTokenizerForMOD(AutoTokenizer):
26 | """AutoTokenizer + Adaptation for MOD.
27 |
28 | A simple wrapper around AutoTokenizer to make instantiating
29 | an MOD-adapted tokenizer a bit easier.
30 |
31 |     MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32 | a padding token, and a property to get the token ids of the
33 | sentinel tokens.
34 | """
35 |
36 | @classmethod
37 | def from_pretrained(cls, *args, **kwargs):
38 | """See `AutoTokenizer.from_pretrained` docstring."""
39 | tokenizer = super().from_pretrained(*args, **kwargs)
40 | adapt_tokenizer_for_denoising(tokenizer)
41 | return tokenizer
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23 | del kwargs
24 | super().__init__()
25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27 | self.norm_1 = norm_class(d_model, device=device)
28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29 | self.norm_2 = norm_class(d_model, device=device)
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33 |
34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 | a = self.norm_1(x)
36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 | x = x + self.resid_attn_dropout(b)
38 | m = self.norm_2(x)
39 | n = self.ffn(m)
40 | x = x + self.resid_ffn_dropout(n)
41 | return (x, attn_weights, past_key_value)
--------------------------------------------------------------------------------
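A forward-pass sketch for `MPTBlock`, assuming the sibling `attention.py` (not reproduced here) registers a pure-PyTorch `'torch'` implementation as in upstream MPT; the default `attn_impl: 'triton'` would require Triton and a GPU.

```python
import torch
from llava.model.language_model.mpt.blocks import MPTBlock

attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0,
               'attn_impl': 'torch',   # assumed CPU-friendly fallback
               'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None,
               'prefix_lm': False, 'attn_uses_sequence_id': False,
               'alibi': False, 'alibi_bias_max': 8}

block = MPTBlock(d_model=64, n_heads=4, expansion_ratio=4, attn_config=attn_config)
x = torch.randn(2, 16, 64)                      # (batch, seq_len, d_model)
y, attn_weights, past_kv = block(x, is_causal=True)
print(y.shape)                                  # torch.Size([2, 16, 64])
```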
/llava/model/language_model/mpt/custom_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 | class SharedEmbedding(nn.Embedding):
7 |
8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
9 | if unembed:
10 | return F.linear(input, self.weight)
11 | return super().forward(input)
--------------------------------------------------------------------------------
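A short sketch of the weight tying that `SharedEmbedding` enables: the same matrix embeds token ids and, with `unembed=True`, projects hidden states back to vocabulary logits.

```python
import torch
from llava.model.language_model.mpt.custom_embedding import SharedEmbedding

emb = SharedEmbedding(num_embeddings=100, embedding_dim=32)
ids = torch.randint(0, 100, (2, 8))
hidden = emb(ids)                   # (2, 8, 32)  -- regular nn.Embedding lookup
logits = emb(hidden, unembed=True)  # (2, 8, 100) -- F.linear against the same weight
```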
/llava/model/language_model/mpt/meta_init_context.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import torch
3 | import torch.nn as nn
4 |
5 | @contextmanager
6 | def init_empty_weights(include_buffers: bool=False):
7 | """Meta initialization context manager.
8 |
9 | A context manager under which models are initialized with all parameters
10 | on the meta device, therefore creating an empty model. Useful when just
11 | initializing the model would blow the available RAM.
12 |
13 | Args:
14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
15 | not to also put all buffers on the meta device while initializing.
16 |
17 | Example:
18 | ```python
19 | import torch.nn as nn
20 |
21 | # Initialize a model with 100 billion parameters in no time and without using any RAM.
22 | with init_empty_weights():
23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
24 | ```
25 |
26 | <Tip warning={true}>
27 |
28 | Any model created under this context manager has no weights. As such you can't do something like
29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
30 |
31 | </Tip>
32 | """
33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34 | yield f
35 |
36 | @contextmanager
37 | def init_on_device(device: torch.device, include_buffers: bool=False):
38 | """Device initialization context manager.
39 |
40 | A context manager under which models are initialized with all parameters
41 | on the specified device.
42 |
43 | Args:
44 | device (`torch.device`): Device to initialize all parameters on.
45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
46 | not to also put all buffers on the specified device while initializing.
47 |
48 | Example:
49 | ```python
50 | import torch.nn as nn
51 |
52 | with init_on_device(device=torch.device("cuda")):
53 | tst = nn.Linear(100, 100) # on `cuda` device
54 | ```
55 | """
56 | old_register_parameter = nn.Module.register_parameter
57 | if include_buffers:
58 | old_register_buffer = nn.Module.register_buffer
59 |
60 | def register_empty_parameter(module, name, param):
61 | old_register_parameter(module, name, param)
62 | if param is not None:
63 | param_cls = type(module._parameters[name])
64 | kwargs = module._parameters[name].__dict__
65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66 |
67 | def register_empty_buffer(module, name, buffer):
68 | old_register_buffer(module, name, buffer)
69 | if buffer is not None:
70 | module._buffers[name] = module._buffers[name].to(device)
71 | if include_buffers:
72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
73 | else:
74 | tensor_constructors_to_patch = {}
75 |
76 | def patch_tensor_constructor(fn):
77 |
78 | def wrapper(*args, **kwargs):
79 | kwargs['device'] = device
80 | return fn(*args, **kwargs)
81 | return wrapper
82 | try:
83 | nn.Module.register_parameter = register_empty_parameter
84 | if include_buffers:
85 | nn.Module.register_buffer = register_empty_buffer
86 | for torch_function_name in tensor_constructors_to_patch.keys():
87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
88 | yield
89 | finally:
90 | nn.Module.register_parameter = old_register_parameter
91 | if include_buffers:
92 | nn.Module.register_buffer = old_register_buffer
93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94 | setattr(torch, torch_function_name, old_torch_function)
--------------------------------------------------------------------------------
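A sketch of what the context manager above buys you: parameters registered inside it live on the `meta` device, so a large module can be described without allocating real memory.

```python
import torch.nn as nn
from llava.model.language_model.mpt.meta_init_context import init_empty_weights

with init_empty_weights():
    big = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])

print(big[0].weight.device)     # meta -- shapes/dtypes exist, data does not
# big(torch.randn(1, 4096))     # would fail: meta tensors hold no values
```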
/llava/model/language_model/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | if weight is not None:
30 | return output * weight
31 | return output
32 |
33 | class RMSNorm(torch.nn.Module):
34 |
35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36 | super().__init__()
37 | self.eps = eps
38 | if weight:
39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
40 | else:
41 | self.register_parameter('weight', None)
42 |
43 | def forward(self, x):
44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45 |
46 | class LPRMSNorm(RMSNorm):
47 |
48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50 |
51 | def forward(self, x):
52 | downcast_x = _cast_if_autocast_enabled(x)
53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54 | with torch.autocast(enabled=False, device_type=x.device.type):
55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
--------------------------------------------------------------------------------
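A quick sketch of how `NORM_CLASS_REGISTRY` is meant to be consumed; this mirrors the lookup `MPTBlock` does with its `norm_type` argument.

```python
import torch
from llava.model.language_model.mpt.norm import NORM_CLASS_REGISTRY

norm_cls = NORM_CLASS_REGISTRY['low_precision_rmsnorm']   # -> LPRMSNorm
norm = norm_cls(normalized_shape=512)
x = torch.randn(2, 10, 512)
print(norm(x).shape)   # torch.Size([2, 10, 512]); output dtype matches the input
```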
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 | from CKPT_PTH import LLAVA_CLIP_PATH
6 |
7 |
8 | class CLIPVisionTower(nn.Module):
9 | def __init__(self, vision_tower, args, delay_load=False):
10 | super().__init__()
11 |
12 | self.is_loaded = False
13 |
14 | self.vision_tower_name = vision_tower
15 | print(f'Loading vision tower: {self.vision_tower_name}')
16 | self.select_layer = args.mm_vision_select_layer
17 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
18 |
19 | if not delay_load:
20 | self.load_model()
21 | else:
22 | # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23 | self.cfg_only = CLIPVisionConfig.from_pretrained(LLAVA_CLIP_PATH)
24 |
25 | def load_model(self):
26 | self.image_processor = CLIPImageProcessor.from_pretrained(LLAVA_CLIP_PATH)
27 | self.vision_tower = CLIPVisionModel.from_pretrained(LLAVA_CLIP_PATH)
28 | self.vision_tower.requires_grad_(False)
29 |
30 | self.is_loaded = True
31 |
32 | def feature_select(self, image_forward_outs):
33 | image_features = image_forward_outs.hidden_states[self.select_layer]
34 | if self.select_feature == 'patch':
35 | image_features = image_features[:, 1:]
36 | elif self.select_feature == 'cls_patch':
37 | image_features = image_features
38 | else:
39 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
40 | return image_features
41 |
42 | @torch.no_grad()
43 | def forward(self, images):
44 | if type(images) is list:
45 | image_features = []
46 | for image in images:
47 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
48 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
49 | image_features.append(image_feature)
50 | else:
51 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
52 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
53 |
54 | return image_features
55 |
56 | @property
57 | def dummy_feature(self):
58 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
59 |
60 | @property
61 | def dtype(self):
62 | return self.vision_tower.dtype
63 |
64 | @property
65 | def device(self):
66 | return self.vision_tower.device
67 |
68 | @property
69 | def config(self):
70 | if self.is_loaded:
71 | return self.vision_tower.config
72 | else:
73 | return self.cfg_only
74 |
75 | @property
76 | def hidden_size(self):
77 | return self.config.hidden_size
78 |
79 | @property
80 | def num_patches(self):
81 | return (self.config.image_size // self.config.patch_size) ** 2
82 |
--------------------------------------------------------------------------------
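A usage sketch for `CLIPVisionTower`, assuming `CKPT_PTH.LLAVA_CLIP_PATH` points at a locally available CLIP ViT-L/14-336 checkpoint (`load_model` above loads from that path regardless of the name passed in). The shapes in the comments are for that particular backbone.

```python
import torch
from types import SimpleNamespace
from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
tower = CLIPVisionTower('openai/clip-vit-large-patch14-336', args=args)

images = torch.randn(2, 3, 336, 336)   # already CLIP-preprocessed pixel values
feats = tower(images)                  # forward is wrapped in @torch.no_grad()
print(feats.shape)                     # e.g. torch.Size([2, 576, 1024]) for ViT-L/14-336
```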
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
--------------------------------------------------------------------------------
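A self-contained sketch of the projector builder; `SimpleNamespace` stands in for the usual transformers config object, and `mlp2x_gelu` is one of the projector types the regex above accepts.

```python
import torch
from types import SimpleNamespace
from llava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)   # Linear(1024->4096) -> GELU -> Linear(4096->4096)

vision_feats = torch.randn(2, 576, 1024)  # e.g. CLIP patch features
print(projector(vision_feats).shape)      # torch.Size([2, 576, 4096])
```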
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9 |
10 | from PIL import Image
11 |
12 | import requests
13 | from PIL import Image
14 | from io import BytesIO
15 | from transformers import TextStreamer
16 |
17 |
18 | def load_image(image_file):
19 | if image_file.startswith('http://') or image_file.startswith('https://'):
20 | response = requests.get(image_file)
21 | image = Image.open(BytesIO(response.content)).convert('RGB')
22 | else:
23 | image = Image.open(image_file).convert('RGB')
24 | return image
25 |
26 |
27 | def main(args):
28 | # Model
29 | disable_torch_init()
30 |
31 | model_name = get_model_name_from_path(args.model_path)
32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33 |
34 | if 'llama-2' in model_name.lower():
35 | conv_mode = "llava_llama_2"
36 | elif "v1" in model_name.lower():
37 | conv_mode = "llava_v1"
38 | elif "mpt" in model_name.lower():
39 | conv_mode = "mpt"
40 | else:
41 | conv_mode = "llava_v0"
42 |
43 | if args.conv_mode is not None and conv_mode != args.conv_mode:
44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
45 | else:
46 | args.conv_mode = conv_mode
47 |
48 | conv = conv_templates[args.conv_mode].copy()
49 | if "mpt" in model_name.lower():
50 | roles = ('user', 'assistant')
51 | else:
52 | roles = conv.roles
53 |
54 | image = load_image(args.image_file)
55 | # Similar operation in model_worker.py
56 | image_tensor = process_images([image], image_processor, args)
57 | if type(image_tensor) is list:
58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
59 | else:
60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
61 |
62 | while True:
63 | try:
64 | inp = input(f"{roles[0]}: ")
65 | except EOFError:
66 | inp = ""
67 | if not inp:
68 | print("exit...")
69 | break
70 |
71 | print(f"{roles[1]}: ", end="")
72 |
73 | if image is not None:
74 | # first message
75 | if model.config.mm_use_im_start_end:
76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
77 | else:
78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
79 | conv.append_message(conv.roles[0], inp)
80 | image = None
81 | else:
82 | # later messages
83 | conv.append_message(conv.roles[0], inp)
84 | conv.append_message(conv.roles[1], None)
85 | prompt = conv.get_prompt()
86 |
87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
89 | keywords = [stop_str]
90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
92 |
93 | with torch.inference_mode():
94 | output_ids = model.generate(
95 | input_ids,
96 | images=image_tensor,
97 | do_sample=True,
98 | temperature=args.temperature,
99 | max_new_tokens=args.max_new_tokens,
100 | streamer=streamer,
101 | use_cache=True,
102 | stopping_criteria=[stopping_criteria])
103 |
104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
105 | conv.messages[-1][-1] = outputs
106 |
107 | if args.debug:
108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
109 |
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
114 | parser.add_argument("--model-base", type=str, default=None)
115 | parser.add_argument("--image-file", type=str, required=True)
116 | parser.add_argument("--device", type=str, default="cuda")
117 | parser.add_argument("--conv-mode", type=str, default=None)
118 | parser.add_argument("--temperature", type=float, default=0.2)
119 | parser.add_argument("--max-new-tokens", type=int, default=512)
120 | parser.add_argument("--load-8bit", action="store_true")
121 | parser.add_argument("--load-4bit", action="store_true")
122 | parser.add_argument("--debug", action="store_true")
123 | parser.add_argument("--image-aspect-ratio", type=str, default='pad')
124 | args = parser.parse_args()
125 | main(args)
126 |
--------------------------------------------------------------------------------
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
--------------------------------------------------------------------------------
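A sketch of how the patch is intended to be applied (see `train_mem.py` below for the canonical call site). It assumes a CUDA-capable GPU and an installed `flash-attn` build; the checkpoint id is a placeholder.

```python
# Apply the monkey patch before the LLaMA model is constructed.
from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-hf',            # placeholder checkpoint id
    torch_dtype=torch.float16,
).cuda()
```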
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 | data = "{" + '"input": ' + f'"{text}"' + "}"
111 | data = data.encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
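A sketch of the logging helpers above: `build_logger` attaches a rotating file handler under `LOGDIR` and reroutes `stdout`/`stderr` through `StreamToLogger`, so plain `print` calls also end up in the log file.

```python
from llava.utils import build_logger, disable_torch_init

logger = build_logger('demo_worker', 'demo_worker.log')
logger.info('worker starting')
print('this line is captured by StreamToLogger and logged as well')

disable_torch_init()   # skip default weight init when checkpoints will be loaded anyway
```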
/options/SUPIR_v0.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: SUPIR.models.SUPIR_model.SUPIRModel
3 | params:
4 | ae_dtype: bf16
5 | diffusion_dtype: fp16
6 | scale_factor: 0.13025
7 | disable_first_stage_autocast: True
8 | network_wrapper: sgm.modules.diffusionmodules.wrappers.ControlWrapper
9 |
10 | denoiser_config:
11 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiserWithControl
12 | params:
13 | num_idx: 1000
14 | weighting_config:
15 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
16 | scaling_config:
17 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
18 | discretization_config:
19 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
20 |
21 | control_stage_config:
22 | target: SUPIR.modules.SUPIR_v0.GLVControl
23 | params:
24 | adm_in_channels: 2816
25 | num_classes: sequential
26 | use_checkpoint: True
27 | in_channels: 4
28 | out_channels: 4
29 | model_channels: 320
30 | attention_resolutions: [4, 2]
31 | num_res_blocks: 2
32 | channel_mult: [1, 2, 4]
33 | num_head_channels: 64
34 | use_spatial_transformer: True
35 | use_linear_in_transformer: True
36 | transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
37 | # transformer_depth: [1, 1, 4]
38 | context_dim: 2048
39 | spatial_transformer_attn_type: softmax-xformers
40 | legacy: False
41 | input_upscale: 1
42 |
43 | network_config:
44 | target: SUPIR.modules.SUPIR_v0.LightGLVUNet
45 | params:
46 | mode: XL-base
47 | project_type: ZeroSFT
48 | project_channel_scale: 2
49 | adm_in_channels: 2816
50 | num_classes: sequential
51 | use_checkpoint: True
52 | in_channels: 4
53 | out_channels: 4
54 | model_channels: 320
55 | attention_resolutions: [4, 2]
56 | num_res_blocks: 2
57 | channel_mult: [1, 2, 4]
58 | num_head_channels: 64
59 | use_spatial_transformer: True
60 | use_linear_in_transformer: True
61 | transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
62 | context_dim: 2048
63 | spatial_transformer_attn_type: softmax-xformers
64 | legacy: False
65 |
66 | conditioner_config:
67 | target: sgm.modules.GeneralConditionerWithControl
68 | params:
69 | emb_models:
70 | # crossattn cond
71 | - is_trainable: False
72 | input_key: txt
73 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74 | params:
75 | layer: hidden
76 | layer_idx: 11
77 | # crossattn and vector cond
78 | - is_trainable: False
79 | input_key: txt
80 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
81 | params:
82 | arch: ViT-bigG-14
83 | version: laion2b_s39b_b160k
84 | freeze: True
85 | layer: penultimate
86 | always_return_pooled: True
87 | legacy: False
88 | # vector cond
89 | - is_trainable: False
90 | input_key: original_size_as_tuple
91 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
92 | params:
93 | outdim: 256 # multiplied by two
94 | # vector cond
95 | - is_trainable: False
96 | input_key: crop_coords_top_left
97 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
98 | params:
99 | outdim: 256 # multiplied by two
100 | # vector cond
101 | - is_trainable: False
102 | input_key: target_size_as_tuple
103 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
104 | params:
105 | outdim: 256 # multiplied by two
106 |
107 | first_stage_config:
108 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
109 | params:
110 | ckpt_path: ~
111 | embed_dim: 4
112 | monitor: val/rec_loss
113 | ddconfig:
114 | attn_type: vanilla-xformers
115 | double_z: true
116 | z_channels: 4
117 | resolution: 256
118 | in_channels: 3
119 | out_ch: 3
120 | ch: 128
121 | ch_mult: [ 1, 2, 4, 4 ]
122 | num_res_blocks: 2
123 | attn_resolutions: [ ]
124 | dropout: 0.0
125 | lossconfig:
126 | target: torch.nn.Identity
127 |
128 | sampler_config:
129 | target: sgm.modules.diffusionmodules.sampling.RestoreEDMSampler
130 | params:
131 | num_steps: 100
132 | restore_cfg: 4.0
133 | s_churn: 0
134 | s_noise: 1.003
135 | discretization_config:
136 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
137 | guider_config:
138 | target: sgm.modules.diffusionmodules.guiders.LinearCFG
139 | params:
140 | scale: 7.5
141 | scale_min: 4.0
142 |
143 | p_p:
144 | 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera,
145 | hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing,
146 | skin pore detailing, hyper sharpness, perfect without deformations.'
147 | n_p:
148 | 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render,
149 | unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
150 | jpeg artifacts, deformed, lowres, over-smooth'
151 |
152 | SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors
153 | SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt
154 | SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt
155 | SUPIR_CKPT: ~
156 |
157 |
--------------------------------------------------------------------------------
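A sketch of how a config like this is typically consumed with OmegaConf plus the repo's `instantiate_from_config` helper (exported from `sgm`). Building the model pulls in the full SDXL-sized module tree, and the checkpoint paths at the bottom must be edited to match your local setup before weights can actually load.

```python
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load('options/SUPIR_v0.yaml')
print(config.SDXL_CKPT)                          # default path; point this at your local file

model = instantiate_from_config(config.model)    # builds SUPIR.models.SUPIR_model.SUPIRModel
```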
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.95.1
2 | gradio==4.16.0
3 | gradio_imageslider==0.0.17
4 | Markdown==3.4.1
5 | numpy==1.24.2
6 | requests==2.28.2
7 | sentencepiece==0.1.98
8 | tokenizers==0.13.3
9 | torch>=2.1.0
10 | torchvision>=0.16.0
11 | uvicorn==0.21.1
12 | wandb==0.14.0
13 | httpx==0.24.0
14 | transformers==4.28.1
15 | accelerate==0.18.0
16 | scikit-learn==1.2.2
17 | sentencepiece==0.1.98
18 | einops==0.7.0
19 | einops-exts==0.0.4
20 | timm==0.9.8
21 | gradio_client==0.1.3
22 | openai-clip==1.0.1
23 | fsspec==2023.4.0
24 | kornia==0.6.9
25 | matplotlib==3.7.1
26 | ninja==1.11.1
27 | omegaconf==2.3.0
28 | open-clip-torch==2.17.1
29 | opencv-python==4.7.0.72
30 | pandas==2.0.1
31 | Pillow==9.4.0
32 | pytorch-lightning==2.1.2
33 | PyYAML==6.0
34 | scipy==1.9.1
35 | tqdm==4.65.0
36 | triton==2.1.0
37 | urllib3==1.26.15
38 | webdataset==0.2.48
39 | xformers>=0.0.20
40 |
--------------------------------------------------------------------------------
/sgm/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import AutoencodingEngine, DiffusionEngine
2 | from .util import get_configs_path, instantiate_from_config
3 |
4 | __version__ = "0.1.0"
5 |
--------------------------------------------------------------------------------
/sgm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 |
9 | def __init__(
10 | self,
11 | warm_up_steps,
12 | lr_min,
13 | lr_max,
14 | lr_start,
15 | max_decay_steps,
16 | verbosity_interval=0,
17 | ):
18 | self.lr_warm_up_steps = warm_up_steps
19 | self.lr_start = lr_start
20 | self.lr_min = lr_min
21 | self.lr_max = lr_max
22 | self.lr_max_decay_steps = max_decay_steps
23 | self.last_lr = 0.0
24 | self.verbosity_interval = verbosity_interval
25 |
26 | def schedule(self, n, **kwargs):
27 | if self.verbosity_interval > 0:
28 | if n % self.verbosity_interval == 0:
29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30 | if n < self.lr_warm_up_steps:
31 | lr = (
32 | self.lr_max - self.lr_start
33 | ) / self.lr_warm_up_steps * n + self.lr_start
34 | self.last_lr = lr
35 | return lr
36 | else:
37 | t = (n - self.lr_warm_up_steps) / (
38 | self.lr_max_decay_steps - self.lr_warm_up_steps
39 | )
40 | t = min(t, 1.0)
41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42 | 1 + np.cos(t * np.pi)
43 | )
44 | self.last_lr = lr
45 | return lr
46 |
47 | def __call__(self, n, **kwargs):
48 | return self.schedule(n, **kwargs)
49 |
50 |
51 | class LambdaWarmUpCosineScheduler2:
52 | """
53 | supports repeated iterations, configurable via lists
54 | note: use with a base_lr of 1.0.
55 | """
56 |
57 | def __init__(
58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59 | ):
60 | assert (
61 | len(warm_up_steps)
62 | == len(f_min)
63 | == len(f_max)
64 | == len(f_start)
65 | == len(cycle_lengths)
66 | )
67 | self.lr_warm_up_steps = warm_up_steps
68 | self.f_start = f_start
69 | self.f_min = f_min
70 | self.f_max = f_max
71 | self.cycle_lengths = cycle_lengths
72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73 | self.last_f = 0.0
74 | self.verbosity_interval = verbosity_interval
75 |
76 | def find_in_interval(self, n):
77 | interval = 0
78 | for cl in self.cum_cycles[1:]:
79 | if n <= cl:
80 | return interval
81 | interval += 1
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0:
88 | print(
89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90 | f"current cycle {cycle}"
91 | )
92 | if n < self.lr_warm_up_steps[cycle]:
93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94 | cycle
95 | ] * n + self.f_start[cycle]
96 | self.last_f = f
97 | return f
98 | else:
99 | t = (n - self.lr_warm_up_steps[cycle]) / (
100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101 | )
102 | t = min(t, 1.0)
103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104 | 1 + np.cos(t * np.pi)
105 | )
106 | self.last_f = f
107 | return f
108 |
109 | def __call__(self, n, **kwargs):
110 | return self.schedule(n, **kwargs)
111 |
112 |
113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114 | def schedule(self, n, **kwargs):
115 | cycle = self.find_in_interval(n)
116 | n = n - self.cum_cycles[cycle]
117 | if self.verbosity_interval > 0:
118 | if n % self.verbosity_interval == 0:
119 | print(
120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121 | f"current cycle {cycle}"
122 | )
123 |
124 | if n < self.lr_warm_up_steps[cycle]:
125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126 | cycle
127 | ] * n + self.f_start[cycle]
128 | self.last_f = f
129 | return f
130 | else:
131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132 | self.cycle_lengths[cycle] - n
133 | ) / (self.cycle_lengths[cycle])
134 | self.last_f = f
135 | return f
136 |
--------------------------------------------------------------------------------
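These schedulers return a learning-rate multiplier (hence the "use with a base_lr of 1.0" note), so they plug straight into `torch.optim.lr_scheduler.LambdaLR`; a minimal sketch:

```python
import torch
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=1000)

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1e-4)     # effective lr = 1e-4 * multiplier
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

for _ in range(5):
    opt.step()
    scheduler.step()
print(opt.param_groups[0]['lr'])             # still ramping up during warm-up
```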
/sgm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .autoencoder import AutoencodingEngine
2 | from .diffusion import DiffusionEngine
3 |
--------------------------------------------------------------------------------
/sgm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoders.modules import GeneralConditioner
2 | from .encoders.modules import GeneralConditionerWithControl
3 | from .encoders.modules import PreparedConditioner
4 |
5 | UNCONDITIONAL_CONFIG = {
6 | "target": "sgm.modules.GeneralConditioner",
7 | "params": {"emb_models": []},
8 | }
9 |
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/__init__.py
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/lpips/__init__.py
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/loss/.gitignore:
--------------------------------------------------------------------------------
1 | vgg.pth
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/loss/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/lpips/loss/__init__.py
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/loss/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from ..util import get_ckpt_path
10 |
11 |
12 | class LPIPS(nn.Module):
13 | # Learned perceptual metric
14 | def __init__(self, use_dropout=True):
15 | super().__init__()
16 | self.scaling_layer = ScalingLayer()
17 | self.chns = [64, 128, 256, 512, 512] # vgg16 features
18 | self.net = vgg16(pretrained=True, requires_grad=False)
19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24 | self.load_from_pretrained()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def load_from_pretrained(self, name="vgg_lpips"):
29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30 | self.load_state_dict(
31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32 | )
33 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
34 |
35 | @classmethod
36 | def from_pretrained(cls, name="vgg_lpips"):
37 | if name != "vgg_lpips":
38 | raise NotImplementedError
39 | model = cls()
40 | ckpt = get_ckpt_path(name)
41 | model.load_state_dict(
42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43 | )
44 | return model
45 |
46 | def forward(self, input, target):
47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
49 | feats0, feats1, diffs = {}, {}, {}
50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51 | for kk in range(len(self.chns)):
52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53 | outs1[kk]
54 | )
55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56 |
57 | res = [
58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59 | for kk in range(len(self.chns))
60 | ]
61 | val = res[0]
62 | for l in range(1, len(self.chns)):
63 | val += res[l]
64 | return val
65 |
66 |
67 | class ScalingLayer(nn.Module):
68 | def __init__(self):
69 | super(ScalingLayer, self).__init__()
70 | self.register_buffer(
71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72 | )
73 | self.register_buffer(
74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75 | )
76 |
77 | def forward(self, inp):
78 | return (inp - self.shift) / self.scale
79 |
80 |
81 | class NetLinLayer(nn.Module):
82 | """A single linear layer which does a 1x1 conv"""
83 |
84 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
85 | super(NetLinLayer, self).__init__()
86 | layers = (
87 | [
88 | nn.Dropout(),
89 | ]
90 | if (use_dropout)
91 | else []
92 | )
93 | layers += [
94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95 | ]
96 | self.model = nn.Sequential(*layers)
97 |
98 |
99 | class vgg16(torch.nn.Module):
100 | def __init__(self, requires_grad=False, pretrained=True):
101 | super(vgg16, self).__init__()
102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103 | self.slice1 = torch.nn.Sequential()
104 | self.slice2 = torch.nn.Sequential()
105 | self.slice3 = torch.nn.Sequential()
106 | self.slice4 = torch.nn.Sequential()
107 | self.slice5 = torch.nn.Sequential()
108 | self.N_slices = 5
109 | for x in range(4):
110 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(4, 9):
112 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(9, 16):
114 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(16, 23):
116 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
117 | for x in range(23, 30):
118 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
119 | if not requires_grad:
120 | for param in self.parameters():
121 | param.requires_grad = False
122 |
123 | def forward(self, X):
124 | h = self.slice1(X)
125 | h_relu1_2 = h
126 | h = self.slice2(h)
127 | h_relu2_2 = h
128 | h = self.slice3(h)
129 | h_relu3_3 = h
130 | h = self.slice4(h)
131 | h_relu4_3 = h
132 | h = self.slice5(h)
133 | h_relu5_3 = h
134 | vgg_outputs = namedtuple(
135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136 | )
137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138 | return out
139 |
140 |
141 | def normalize_tensor(x, eps=1e-10):
142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143 | return x / (norm_factor + eps)
144 |
145 |
146 | def spatial_average(x, keepdim=True):
147 | return x.mean([2, 3], keepdim=keepdim)
148 |
--------------------------------------------------------------------------------
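A usage sketch for the LPIPS metric above. It downloads the torchvision VGG16 weights and the `vgg_lpips` linear-head checkpoint on first use (via `get_ckpt_path`); inputs are expected in `[-1, 1]`, and the output is a per-image distance of shape `(B, 1, 1, 1)`, lower meaning perceptually closer.

```python
import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()
a = torch.rand(1, 3, 256, 256) * 2 - 1   # pretend image in [-1, 1]
b = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(a, b)
print(dist.shape)                        # torch.Size([1, 1, 1, 1])
```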
/sgm/modules/autoencoding/lpips/model/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pix2pix --------------------------------
27 | BSD License
28 |
29 | For pix2pix software
30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/lpips/model/__init__.py
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/model/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 |
5 | from ..util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find("BatchNorm") != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 |
22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23 | """Construct a PatchGAN discriminator
24 | Parameters:
25 | input_nc (int) -- the number of channels in input images
26 | ndf (int) -- the number of filters in the last conv layer
27 | n_layers (int) -- the number of conv layers in the discriminator
28 | norm_layer -- normalization layer
29 | """
30 | super(NLayerDiscriminator, self).__init__()
31 | if not use_actnorm:
32 | norm_layer = nn.BatchNorm2d
33 | else:
34 | norm_layer = ActNorm
35 | if (
36 | type(norm_layer) == functools.partial
37 | ): # no need to use bias as BatchNorm2d has affine parameters
38 | use_bias = norm_layer.func != nn.BatchNorm2d
39 | else:
40 | use_bias = norm_layer != nn.BatchNorm2d
41 |
42 | kw = 4
43 | padw = 1
44 | sequence = [
45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46 | nn.LeakyReLU(0.2, True),
47 | ]
48 | nf_mult = 1
49 | nf_mult_prev = 1
50 | for n in range(1, n_layers): # gradually increase the number of filters
51 | nf_mult_prev = nf_mult
52 | nf_mult = min(2**n, 8)
53 | sequence += [
54 | nn.Conv2d(
55 | ndf * nf_mult_prev,
56 | ndf * nf_mult,
57 | kernel_size=kw,
58 | stride=2,
59 | padding=padw,
60 | bias=use_bias,
61 | ),
62 | norm_layer(ndf * nf_mult),
63 | nn.LeakyReLU(0.2, True),
64 | ]
65 |
66 | nf_mult_prev = nf_mult
67 | nf_mult = min(2**n_layers, 8)
68 | sequence += [
69 | nn.Conv2d(
70 | ndf * nf_mult_prev,
71 | ndf * nf_mult,
72 | kernel_size=kw,
73 | stride=1,
74 | padding=padw,
75 | bias=use_bias,
76 | ),
77 | norm_layer(ndf * nf_mult),
78 | nn.LeakyReLU(0.2, True),
79 | ]
80 |
81 | sequence += [
82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83 | ] # output 1 channel prediction map
84 | self.main = nn.Sequential(*sequence)
85 |
86 | def forward(self, input):
87 | """Standard forward."""
88 | return self.main(input)
89 |
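An illustrative sketch for the PatchGAN discriminator above (not part of the file; assumes the repo root is on PYTHONPATH). The output is a grid of per-patch realness logits rather than a single scalar:

import torch
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)                    # DCGAN-style init for Conv/BatchNorm layers
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)                         # torch.Size([1, 1, 30, 30]) for a 256x256 input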
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 |
4 | import requests
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10 |
11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12 |
13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14 |
15 |
16 | def download(url, local_path, chunk_size=1024):
17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18 | with requests.get(url, stream=True) as r:
19 | total_size = int(r.headers.get("content-length", 0))
20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21 | with open(local_path, "wb") as f:
22 | for data in r.iter_content(chunk_size=chunk_size):
23 | if data:
24 | f.write(data)
25 | pbar.update(chunk_size)
26 |
27 |
28 | def md5_hash(path):
29 | with open(path, "rb") as f:
30 | content = f.read()
31 | return hashlib.md5(content).hexdigest()
32 |
33 |
34 | def get_ckpt_path(name, root, check=False):
35 | assert name in URL_MAP
36 | path = os.path.join(root, CKPT_MAP[name])
37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39 | download(URL_MAP[name], path)
40 | md5 = md5_hash(path)
41 | assert md5 == MD5_MAP[name], md5
42 | return path
43 |
44 |
45 | class ActNorm(nn.Module):
46 | def __init__(
47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False
48 | ):
49 | assert affine
50 | super().__init__()
51 | self.logdet = logdet
52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54 | self.allow_reverse_init = allow_reverse_init
55 |
56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57 |
58 | def initialize(self, input):
59 | with torch.no_grad():
60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61 | mean = (
62 | flatten.mean(1)
63 | .unsqueeze(1)
64 | .unsqueeze(2)
65 | .unsqueeze(3)
66 | .permute(1, 0, 2, 3)
67 | )
68 | std = (
69 | flatten.std(1)
70 | .unsqueeze(1)
71 | .unsqueeze(2)
72 | .unsqueeze(3)
73 | .permute(1, 0, 2, 3)
74 | )
75 |
76 | self.loc.data.copy_(-mean)
77 | self.scale.data.copy_(1 / (std + 1e-6))
78 |
79 | def forward(self, input, reverse=False):
80 | if reverse:
81 | return self.reverse(input)
82 | if len(input.shape) == 2:
83 | input = input[:, :, None, None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | _, _, height, width = input.shape
89 |
90 | if self.training and self.initialized.item() == 0:
91 | self.initialize(input)
92 | self.initialized.fill_(1)
93 |
94 | h = self.scale * (input + self.loc)
95 |
96 | if squeeze:
97 | h = h.squeeze(-1).squeeze(-1)
98 |
99 | if self.logdet:
100 | log_abs = torch.log(torch.abs(self.scale))
101 | logdet = height * width * torch.sum(log_abs)
102 | logdet = logdet * torch.ones(input.shape[0]).to(input)
103 | return h, logdet
104 |
105 | return h
106 |
107 | def reverse(self, output):
108 | if self.training and self.initialized.item() == 0:
109 | if not self.allow_reverse_init:
110 | raise RuntimeError(
111 | "Initializing ActNorm in reverse direction is "
112 | "disabled by default. Use allow_reverse_init=True to enable."
113 | )
114 | else:
115 | self.initialize(output)
116 | self.initialized.fill_(1)
117 |
118 | if len(output.shape) == 2:
119 | output = output[:, :, None, None]
120 | squeeze = True
121 | else:
122 | squeeze = False
123 |
124 | h = output / self.scale - self.loc
125 |
126 | if squeeze:
127 | h = h.squeeze(-1).squeeze(-1)
128 | return h
129 |
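A small sketch of ActNorm's data-dependent initialization and its inverse (illustrative only; assumes the repo root is on PYTHONPATH):

import torch
from sgm.modules.autoencoding.lpips.util import ActNorm

act = ActNorm(num_features=8)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
act.train()
h = act(x)                                  # first call initializes loc/scale from this batch
print(h.mean().item(), h.std().item())      # roughly 0 and 1 after initialization
x_rec = act.reverse(h)
print(torch.allclose(x_rec, x, atol=1e-4))  # True: reverse() inverts forward()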
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/lpips/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def hinge_d_loss(logits_real, logits_fake):
6 | loss_real = torch.mean(F.relu(1.0 - logits_real))
7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8 | d_loss = 0.5 * (loss_real + loss_fake)
9 | return d_loss
10 |
11 |
12 | def vanilla_d_loss(logits_real, logits_fake):
13 | d_loss = 0.5 * (
14 | torch.mean(torch.nn.functional.softplus(-logits_real))
15 | + torch.mean(torch.nn.functional.softplus(logits_fake))
16 | )
17 | return d_loss
18 |
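A quick comparison of the two discriminator losses above on dummy logits (illustrative; assumes the repo root is on PYTHONPATH):

import torch
from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss

logits_real = torch.tensor([2.0, 0.5, -0.2])    # discriminator scores for real samples
logits_fake = torch.tensor([-1.5, 0.3, 0.1])    # discriminator scores for generated samples
print(hinge_d_loss(logits_real, logits_fake))   # 0.5 * (mean relu(1 - real) + mean relu(1 + fake))
print(vanilla_d_loss(logits_real, logits_fake)) # 0.5 * (mean softplus(-real) + mean softplus(fake))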
--------------------------------------------------------------------------------
/sgm/modules/autoencoding/regularizers/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution
9 |
10 |
11 | class AbstractRegularizer(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
16 | raise NotImplementedError()
17 |
18 | @abstractmethod
19 | def get_trainable_parameters(self) -> Any:
20 | raise NotImplementedError()
21 |
22 |
23 | class DiagonalGaussianRegularizer(AbstractRegularizer):
24 | def __init__(self, sample: bool = True):
25 | super().__init__()
26 | self.sample = sample
27 |
28 | def get_trainable_parameters(self) -> Any:
29 | yield from ()
30 |
31 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
32 | log = dict()
33 | posterior = DiagonalGaussianDistribution(z)
34 | if self.sample:
35 | z = posterior.sample()
36 | else:
37 | z = posterior.mode()
38 | kl_loss = posterior.kl()
39 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
40 | log["kl_loss"] = kl_loss
41 | return z, log
42 |
43 |
44 | def measure_perplexity(predicted_indices, num_centroids):
45 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
46 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
47 | encodings = (
48 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
49 | )
50 | avg_probs = encodings.mean(0)
51 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
52 | cluster_use = torch.sum(avg_probs > 0)
53 | return perplexity, cluster_use
54 |
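Illustrative sketch (assumes the repo root is on PYTHONPATH): the regularizer splits an encoder output channel-wise into mean and log-variance halves and reports the KL term, while measure_perplexity checks how evenly codebook indices are used:

import torch
from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer, measure_perplexity

reg = DiagonalGaussianRegularizer(sample=True)
z_params = torch.randn(2, 8, 16, 16)            # 4 mean channels + 4 log-variance channels
z, log = reg(z_params)
print(z.shape, log["kl_loss"].item())           # torch.Size([2, 4, 16, 16]) and the batch-mean KL

perplexity, cluster_use = measure_perplexity(torch.randint(0, 512, (4096,)), num_centroids=512)
print(perplexity.item(), cluster_use.item())    # both close to 512 for uniformly used codes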
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 | from .denoiser import Denoiser
2 | from .discretizer import Discretization
3 | from .loss import StandardDiffusionLoss
4 | from .model import Decoder, Encoder, Model
5 | from .openaimodel import UNetModel
6 | from .sampling import BaseDiffusionSampler
7 | from .wrappers import OpenAIWrapper
8 |
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/denoiser.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from ...util import append_dims, instantiate_from_config
4 |
5 |
6 | class Denoiser(nn.Module):
7 | def __init__(self, weighting_config, scaling_config):
8 | super().__init__()
9 |
10 | self.weighting = instantiate_from_config(weighting_config)
11 | self.scaling = instantiate_from_config(scaling_config)
12 |
13 | def possibly_quantize_sigma(self, sigma):
14 | return sigma
15 |
16 | def possibly_quantize_c_noise(self, c_noise):
17 | return c_noise
18 |
19 | def w(self, sigma):
20 | return self.weighting(sigma)
21 |
22 | def __call__(self, network, input, sigma, cond):
23 | sigma = self.possibly_quantize_sigma(sigma)
24 | sigma_shape = sigma.shape
25 | sigma = append_dims(sigma, input.ndim)
26 | c_skip, c_out, c_in, c_noise = self.scaling(sigma)
27 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
28 | return network(input * c_in, c_noise, cond) * c_out + input * c_skip
29 |
30 |
31 | class DiscreteDenoiser(Denoiser):
32 | def __init__(
33 | self,
34 | weighting_config,
35 | scaling_config,
36 | num_idx,
37 | discretization_config,
38 | do_append_zero=False,
39 | quantize_c_noise=True,
40 | flip=True,
41 | ):
42 | super().__init__(weighting_config, scaling_config)
43 | sigmas = instantiate_from_config(discretization_config)(
44 | num_idx, do_append_zero=do_append_zero, flip=flip
45 | )
46 | self.register_buffer("sigmas", sigmas)
47 | self.quantize_c_noise = quantize_c_noise
48 |
49 | def sigma_to_idx(self, sigma):
50 | dists = sigma - self.sigmas[:, None]
51 | return dists.abs().argmin(dim=0).view(sigma.shape)
52 |
53 | def idx_to_sigma(self, idx):
54 | return self.sigmas[idx]
55 |
56 | def possibly_quantize_sigma(self, sigma):
57 | return self.idx_to_sigma(self.sigma_to_idx(sigma))
58 |
59 | def possibly_quantize_c_noise(self, c_noise):
60 | if self.quantize_c_noise:
61 | return self.sigma_to_idx(c_noise)
62 | else:
63 | return c_noise
64 |
65 |
66 | class DiscreteDenoiserWithControl(DiscreteDenoiser):
67 | def __call__(self, network, input, sigma, cond, control_scale):
68 | sigma = self.possibly_quantize_sigma(sigma)
69 | sigma_shape = sigma.shape
70 | sigma = append_dims(sigma, input.ndim)
71 | c_skip, c_out, c_in, c_noise = self.scaling(sigma)
72 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
73 | return network(input * c_in, c_noise, cond, control_scale) * c_out + input * c_skip
74 |
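A minimal sketch of the preconditioning wrapper above, using a stand-in network and plain dict configs (assumes instantiate_from_config accepts a dict with a "target" key, as in the stock sgm utilities, and that the repo root is on PYTHONPATH):

import torch
from sgm.modules.diffusionmodules.denoiser import Denoiser

weighting_config = {"target": "sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting"}
scaling_config = {"target": "sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling"}
denoiser = Denoiser(weighting_config, scaling_config)

network = lambda x, c_noise, cond: torch.zeros_like(x)   # stand-in for the UNet
x = torch.randn(2, 4, 8, 8)
sigma = torch.tensor([0.5, 2.0])
out = denoiser(network, x, sigma, None)
print(out.shape)   # torch.Size([2, 4, 8, 8]); out = c_out * network(c_in * x, ...) + c_skip * x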
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/denoiser_scaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EDMScaling:
5 | def __init__(self, sigma_data=0.5):
6 | self.sigma_data = sigma_data
7 |
8 | def __call__(self, sigma):
9 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
10 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
11 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
12 | c_noise = 0.25 * sigma.log()
13 | return c_skip, c_out, c_in, c_noise
14 |
15 |
16 | class EpsScaling:
17 | def __call__(self, sigma):
18 | c_skip = torch.ones_like(sigma, device=sigma.device)
19 | c_out = -sigma
20 | c_in = 1 / (sigma**2 + 1.0) ** 0.5
21 | c_noise = sigma.clone()
22 | return c_skip, c_out, c_in, c_noise
23 |
24 |
25 | class VScaling:
26 | def __call__(self, sigma):
27 | c_skip = 1.0 / (sigma**2 + 1.0)
28 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5
29 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
30 | c_noise = sigma.clone()
31 | return c_skip, c_out, c_in, c_noise
32 |
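The EDM coefficients above interpolate between passing the input through (low sigma) and relying on the network output (high sigma); a quick numeric check (illustrative, assumes the repo root is on PYTHONPATH):

import torch
from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling

scaling = EDMScaling(sigma_data=0.5)
for s in (0.01, 0.5, 10.0):
    c_skip, c_out, c_in, c_noise = scaling(torch.tensor(s))
    print(f"sigma={s}: c_skip={c_skip.item():.3f} c_out={c_out.item():.3f} c_in={c_in.item():.3f}")
# c_skip -> 1 and c_out -> 0 as sigma -> 0; c_skip -> 0 as sigma grows large.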
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/denoiser_weighting.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class UnitWeighting:
4 | def __call__(self, sigma):
5 | return torch.ones_like(sigma, device=sigma.device)
6 |
7 |
8 | class EDMWeighting:
9 | def __init__(self, sigma_data=0.5):
10 | self.sigma_data = sigma_data
11 |
12 | def __call__(self, sigma):
13 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
14 |
15 |
16 | class VWeighting(EDMWeighting):
17 | def __init__(self):
18 | super().__init__(sigma_data=1.0)
19 |
20 |
21 | class EpsWeighting:
22 | def __call__(self, sigma):
23 | return sigma**-2
24 |
25 |
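For the EDM case the loss weight is exactly the inverse squared output scale, w(sigma) = 1 / c_out(sigma)^2, matching the Karras et al. formulation; a quick check (illustrative, assumes the repo root is on PYTHONPATH):

import torch
from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling
from sgm.modules.diffusionmodules.denoiser_weighting import EDMWeighting

sigma = torch.tensor([0.1, 1.0, 10.0])
_, c_out, _, _ = EDMScaling(sigma_data=0.5)(sigma)
w = EDMWeighting(sigma_data=0.5)(sigma)
print(torch.allclose(w, 1.0 / c_out**2))   # True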
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/discretizer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from functools import partial
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ...modules.diffusionmodules.util import make_beta_schedule
8 | from ...util import append_zero
9 |
10 |
11 | def generate_roughly_equally_spaced_steps(
12 | num_substeps: int, max_step: int
13 | ) -> np.ndarray:
14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
15 |
16 |
17 | class Discretization:
18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
19 | sigmas = self.get_sigmas(n, device=device)
20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas
21 | return sigmas if not flip else torch.flip(sigmas, (0,))
22 |
23 | @abstractmethod
24 | def get_sigmas(self, n, device):
25 | pass
26 |
27 |
28 | class EDMDiscretization(Discretization):
29 | def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
30 | self.sigma_min = sigma_min
31 | self.sigma_max = sigma_max
32 | self.rho = rho
33 |
34 | def get_sigmas(self, n, device="cpu"):
35 | ramp = torch.linspace(0, 1, n, device=device)
36 | min_inv_rho = self.sigma_min ** (1 / self.rho)
37 | max_inv_rho = self.sigma_max ** (1 / self.rho)
38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
39 | return sigmas
40 |
41 |
42 | class LegacyDDPMDiscretization(Discretization):
43 | def __init__(
44 | self,
45 | linear_start=0.00085,
46 | linear_end=0.0120,
47 | num_timesteps=1000,
48 | ):
49 | super().__init__()
50 | self.num_timesteps = num_timesteps
51 | betas = make_beta_schedule(
52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
53 | )
54 | alphas = 1.0 - betas
55 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
56 | self.to_torch = partial(torch.tensor, dtype=torch.float32)
57 |
58 | def get_sigmas(self, n, device="cpu"):
59 | if n < self.num_timesteps:
60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
61 | alphas_cumprod = self.alphas_cumprod[timesteps]
62 | elif n == self.num_timesteps:
63 | alphas_cumprod = self.alphas_cumprod
64 | else:
65 | raise ValueError
66 |
67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
69 | return torch.flip(sigmas, (0,))
70 |
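Illustrative use of the EDM (Karras) discretization above (assumes the repo root is on PYTHONPATH): it yields a descending sigma schedule, optionally with a trailing zero appended:

import torch
from sgm.modules.diffusionmodules.discretizer import EDMDiscretization

disc = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
sigmas = disc(10, do_append_zero=True)
print(sigmas.shape)                          # torch.Size([11])
print(sigmas[0].item(), sigmas[-2].item())   # ~80.0 down to ~0.02, followed by the appended 0.0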
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/guiders.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 |
5 | from ...util import default, instantiate_from_config
6 |
7 |
8 | class VanillaCFG:
9 | """
10 | implements classifier-free guidance (CFG) with the conditional and unconditional branches batched together
11 | """
12 |
13 | def __init__(self, scale, dyn_thresh_config=None):
14 | scale_schedule = lambda scale, sigma: scale # independent of step
15 | self.scale_schedule = partial(scale_schedule, scale)
16 | self.dyn_thresh = instantiate_from_config(
17 | default(
18 | dyn_thresh_config,
19 | {
20 | "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
21 | },
22 | )
23 | )
24 |
25 | def __call__(self, x, sigma):
26 | x_u, x_c = x.chunk(2)
27 | scale_value = self.scale_schedule(sigma)
28 | x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29 | return x_pred
30 |
31 | def prepare_inputs(self, x, s, c, uc):
32 | c_out = dict()
33 |
34 | for k in c:
35 | if k in ["vector", "crossattn", "concat", "control", 'control_vector', 'mask_x']:
36 | c_out[k] = torch.cat((uc[k], c[k]), 0)
37 | else:
38 | assert c[k] == uc[k]
39 | c_out[k] = c[k]
40 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
41 |
42 |
43 |
44 | class LinearCFG:
45 | def __init__(self, scale, scale_min=None, dyn_thresh_config=None):
46 | if scale_min is None:
47 | scale_min = scale
48 | scale_schedule = lambda scale, scale_min, sigma: (scale - scale_min) * sigma / 14.6146 + scale_min
49 | self.scale_schedule = partial(scale_schedule, scale, scale_min)
50 | self.dyn_thresh = instantiate_from_config(
51 | default(
52 | dyn_thresh_config,
53 | {
54 | "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
55 | },
56 | )
57 | )
58 |
59 | def __call__(self, x, sigma):
60 | x_u, x_c = x.chunk(2)
61 | scale_value = self.scale_schedule(sigma)
62 | x_pred = self.dyn_thresh(x_u, x_c, scale_value)
63 | return x_pred
64 |
65 | def prepare_inputs(self, x, s, c, uc):
66 | c_out = dict()
67 |
68 | for k in c:
69 | if k in ["vector", "crossattn", "concat", "control", 'control_vector', 'mask_x']:
70 | c_out[k] = torch.cat((uc[k], c[k]), 0)
71 | else:
72 | assert c[k] == uc[k]
73 | c_out[k] = c[k]
74 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
75 |
76 |
77 |
78 | class IdentityGuider:
79 | def __call__(self, x, sigma):
80 | return x
81 |
82 | def prepare_inputs(self, x, s, c, uc):
83 | c_out = dict()
84 |
85 | for k in c:
86 | c_out[k] = c[k]
87 |
88 | return x, s, c_out
89 |
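A sketch of the guider round-trip (illustrative shapes, assumes the repo root is on PYTHONPATH): prepare_inputs stacks the unconditional and conditional branches into one doubled batch, and the call recombines them as uncond + scale * (cond - uncond):

import torch
from sgm.modules.diffusionmodules.guiders import LinearCFG

guider = LinearCFG(scale=7.5, scale_min=4.0)   # guidance scale rises linearly with sigma, from scale_min to scale
x = torch.randn(2, 4, 8, 8)
sigma = torch.tensor([14.6146, 14.6146])
c = {"crossattn": torch.randn(2, 77, 1024), "vector": torch.randn(2, 2816)}   # illustrative SDXL-like shapes
uc = {"crossattn": torch.zeros(2, 77, 1024), "vector": torch.zeros(2, 2816)}

x_in, s_in, c_doubled = guider.prepare_inputs(x, sigma, c, uc)
print(x_in.shape, s_in.shape)                  # torch.Size([4, 4, 8, 8]) torch.Size([4])
denoised = torch.randn(4, 4, 8, 8)             # stand-in for the denoiser output on the doubled batch
print(guider(denoised, sigma).shape)           # torch.Size([2, 4, 8, 8])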
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/loss.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from omegaconf import ListConfig
6 |
7 | from ...util import append_dims, instantiate_from_config
8 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS
9 |
10 |
11 | class StandardDiffusionLoss(nn.Module):
12 | def __init__(
13 | self,
14 | sigma_sampler_config,
15 | type="l2",
16 | offset_noise_level=0.0,
17 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
18 | ):
19 | super().__init__()
20 |
21 | assert type in ["l2", "l1", "lpips"]
22 |
23 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
24 |
25 | self.type = type
26 | self.offset_noise_level = offset_noise_level
27 |
28 | if type == "lpips":
29 | self.lpips = LPIPS().eval()
30 |
31 | if not batch2model_keys:
32 | batch2model_keys = []
33 |
34 | if isinstance(batch2model_keys, str):
35 | batch2model_keys = [batch2model_keys]
36 |
37 | self.batch2model_keys = set(batch2model_keys)
38 |
39 | def __call__(self, network, denoiser, conditioner, input, batch):
40 | cond = conditioner(batch)
41 | additional_model_inputs = {
42 | key: batch[key] for key in self.batch2model_keys.intersection(batch)
43 | }
44 |
45 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
46 | noise = torch.randn_like(input)
47 | if self.offset_noise_level > 0.0:
48 | noise = noise + self.offset_noise_level * append_dims(
49 | torch.randn(input.shape[0], device=input.device), input.ndim
50 | )
51 | noised_input = input + noise * append_dims(sigmas, input.ndim)
52 | model_output = denoiser(
53 | network, noised_input, sigmas, cond, **additional_model_inputs
54 | )
55 | w = append_dims(denoiser.w(sigmas), input.ndim)
56 | return self.get_loss(model_output, input, w)
57 |
58 | def get_loss(self, model_output, target, w):
59 | if self.type == "l2":
60 | return torch.mean(
61 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
62 | )
63 | elif self.type == "l1":
64 | return torch.mean(
65 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
66 | )
67 | elif self.type == "lpips":
68 | loss = self.lpips(model_output, target).reshape(-1)
69 | return loss
70 |
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/sampling_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy import integrate
3 |
4 | from ...util import append_dims
5 |
6 |
7 | class NoDynamicThresholding:
8 | def __call__(self, uncond, cond, scale):
9 | return uncond + scale.view(-1, 1, 1, 1) * (cond - uncond)
10 |
11 |
12 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
13 | if order - 1 > i:
14 | raise ValueError(f"Order {order} too high for step {i}")
15 |
16 | def fn(tau):
17 | prod = 1.0
18 | for k in range(order):
19 | if j == k:
20 | continue
21 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
22 | return prod
23 |
24 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
25 |
26 |
27 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
28 | if not eta:
29 | return sigma_to, 0.0
30 | sigma_up = torch.minimum(
31 | sigma_to,
32 | eta
33 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
34 | )
35 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
36 | return sigma_down, sigma_up
37 |
38 |
39 | def to_d(x, sigma, denoised):
40 | return (x - denoised) / append_dims(sigma, x.ndim)
41 |
42 |
43 | def to_neg_log_sigma(sigma):
44 | return sigma.log().neg()
45 |
46 |
47 | def to_sigma(neg_log_sigma):
48 | return neg_log_sigma.neg().exp()
49 |
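Two small checks on the helpers above (illustrative, assumes the repo root is on PYTHONPATH): the log-sigma transforms are exact inverses, and the ancestral split satisfies sigma_down^2 + sigma_up^2 == sigma_to^2:

import torch
from sgm.modules.diffusionmodules.sampling_utils import get_ancestral_step, to_neg_log_sigma, to_sigma

sigma = torch.tensor([0.5, 2.0, 10.0])
print(torch.allclose(to_sigma(to_neg_log_sigma(sigma)), sigma))   # True

sigma_down, sigma_up = get_ancestral_step(torch.tensor(2.0), torch.tensor(1.0), eta=1.0)
print(sigma_down.item(), sigma_up.item())                         # 0.5 and ~0.866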
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/sigma_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ...util import default, instantiate_from_config
4 |
5 |
6 | class EDMSampling:
7 | def __init__(self, p_mean=-1.2, p_std=1.2):
8 | self.p_mean = p_mean
9 | self.p_std = p_std
10 |
11 | def __call__(self, n_samples, rand=None):
12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13 | return log_sigma.exp()
14 |
15 |
16 | class DiscreteSampling:
17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, idx_range=None):
18 | self.num_idx = num_idx
19 | self.sigmas = instantiate_from_config(discretization_config)(
20 | num_idx, do_append_zero=do_append_zero, flip=flip
21 | )
22 | self.idx_range = idx_range
23 |
24 | def idx_to_sigma(self, idx):
25 | # print(self.sigmas[idx])
26 | return self.sigmas[idx]
27 |
28 | def __call__(self, n_samples, rand=None):
29 | if self.idx_range is None:
30 | idx = default(
31 | rand,
32 | torch.randint(0, self.num_idx, (n_samples,)),
33 | )
34 | else:
35 | idx = default(
36 | rand,
37 | torch.randint(self.idx_range[0], self.idx_range[1], (n_samples,)),
38 | )
39 | return self.idx_to_sigma(idx)
40 |
41 |
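EDMSampling draws training sigmas from a log-normal distribution; a quick sanity check on its parameters (illustrative, assumes the repo root is on PYTHONPATH):

import torch
from sgm.modules.diffusionmodules.sigma_sampling import EDMSampling

sampler = EDMSampling(p_mean=-1.2, p_std=1.2)
sigmas = sampler(100000)
print(sigmas.log().mean().item(), sigmas.log().std().item())   # roughly -1.2 and 1.2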
--------------------------------------------------------------------------------
/sgm/modules/diffusionmodules/wrappers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from packaging import version
4 | # import torch._dynamo
5 | # torch._dynamo.config.suppress_errors = True
6 | # torch._dynamo.config.cache_size_limit = 512
7 |
8 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
9 |
10 |
11 | class IdentityWrapper(nn.Module):
12 | def __init__(self, diffusion_model, compile_model: bool = False):
13 | super().__init__()
14 | compile = (
15 | torch.compile
16 | if (version.parse(torch.__version__) >= version.parse("2.0.0"))
17 | and compile_model
18 | else lambda x: x
19 | )
20 | self.diffusion_model = compile(diffusion_model)
21 |
22 | def forward(self, *args, **kwargs):
23 | return self.diffusion_model(*args, **kwargs)
24 |
25 |
26 | class OpenAIWrapper(IdentityWrapper):
27 | def forward(
28 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
29 | ) -> torch.Tensor:
30 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
31 | return self.diffusion_model(
32 | x,
33 | timesteps=t,
34 | context=c.get("crossattn", None),
35 | y=c.get("vector", None),
36 | **kwargs,
37 | )
38 |
39 |
40 | class OpenAIHalfWrapper(IdentityWrapper):
41 | def __init__(self, *args, **kwargs):
42 | super().__init__(*args, **kwargs)
43 | self.diffusion_model = self.diffusion_model.half()
44 |
45 | def forward(
46 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
47 | ) -> torch.Tensor:
48 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
49 | _context = c.get("crossattn", None)
50 | _y = c.get("vector", None)
51 | if _context is not None:
52 | _context = _context.half()
53 | if _y is not None:
54 | _y = _y.half()
55 | x = x.half()
56 | t = t.half()
57 |
58 | out = self.diffusion_model(
59 | x,
60 | timesteps=t,
61 | context=_context,
62 | y=_y,
63 | **kwargs,
64 | )
65 | return out.float()
66 |
67 |
68 | class ControlWrapper(nn.Module):
69 | def __init__(self, diffusion_model, compile_model: bool = False, dtype=torch.float32):
70 | super().__init__()
71 | self.compile = (
72 | torch.compile
73 | if (version.parse(torch.__version__) >= version.parse("2.0.0"))
74 | and compile_model
75 | else lambda x: x
76 | )
77 | self.diffusion_model = self.compile(diffusion_model)
78 | self.control_model = None
79 | self.dtype = dtype
80 |
81 | def load_control_model(self, control_model):
82 | self.control_model = self.compile(control_model)
83 |
84 | def forward(
85 | self, x: torch.Tensor, t: torch.Tensor, c: dict, control_scale=1, **kwargs
86 | ) -> torch.Tensor:
87 | with torch.autocast("cuda", dtype=self.dtype):
88 | control = self.control_model(x=c.get("control", None), timesteps=t, xt=x,
89 | control_vector=c.get("control_vector", None),
90 | mask_x=c.get("mask_x", None),
91 | context=c.get("crossattn", None),
92 | y=c.get("vector", None))
93 | out = self.diffusion_model(
94 | x,
95 | timesteps=t,
96 | context=c.get("crossattn", None),
97 | y=c.get("vector", None),
98 | control=control,
99 | control_scale=control_scale,
100 | **kwargs,
101 | )
102 | return out.float()
103 |
104 |
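A sketch of how OpenAIWrapper routes the conditioning dict into the UNet call, using a stand-in module with the expected keyword interface (illustrative shapes, assumes the repo root is on PYTHONPATH):

import torch
import torch.nn as nn
from sgm.modules.diffusionmodules.wrappers import OpenAIWrapper

class ToyUNet(nn.Module):
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        return x   # stand-in: just echo the input

wrapper = OpenAIWrapper(ToyUNet(), compile_model=False)
x = torch.randn(1, 4, 8, 8)
c = {"concat": torch.randn(1, 1, 8, 8),       # appended to x along the channel dim
     "crossattn": torch.randn(1, 77, 1024),   # passed as context
     "vector": torch.randn(1, 2816)}          # passed as y
out = wrapper(x, torch.tensor([999]), c)
print(out.shape)   # torch.Size([1, 5, 8, 8])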
--------------------------------------------------------------------------------
/sgm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/sgm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.sum(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=[1, 2, 3]):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68 | dim=dims,
69 | )
70 |
71 | def mode(self):
72 | return self.mean
73 |
74 |
75 | def normal_kl(mean1, logvar1, mean2, logvar2):
76 | """
77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78 | Compute the KL divergence between two gaussians.
79 | Shapes are automatically broadcasted, so batches can be compared to
80 | scalars, among other use cases.
81 | """
82 | tensor = None
83 | for obj in (mean1, logvar1, mean2, logvar2):
84 | if isinstance(obj, torch.Tensor):
85 | tensor = obj
86 | break
87 | assert tensor is not None, "at least one argument must be a Tensor"
88 |
89 | # Force variances to be Tensors. Broadcasting helps convert scalars to
90 | # Tensors, but it does not work for torch.exp().
91 | logvar1, logvar2 = [
92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93 | for x in (logvar1, logvar2)
94 | ]
95 |
96 | return 0.5 * (
97 | -1.0
98 | + logvar2
99 | - logvar1
100 | + torch.exp(logvar1 - logvar2)
101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102 | )
103 |
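A short sketch of the diagonal Gaussian posterior above (illustrative, assumes the repo root is on PYTHONPATH): the parameter tensor is split channel-wise into mean and log-variance, and normal_kl of two identical Gaussians is zero:

import torch
from sgm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl

params = torch.randn(2, 8, 16, 16)            # 4 mean channels + 4 log-variance channels
dist = DiagonalGaussianDistribution(params)
print(dist.sample().shape, dist.kl().shape)   # torch.Size([2, 4, 16, 16]) torch.Size([2])

zero = torch.tensor(0.0)
print(normal_kl(zero, zero, zero, zero).item())   # 0.0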
--------------------------------------------------------------------------------
/sgm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int)
16 | if use_num_upates
17 | else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 | # remove as '.'-character is not allowed in buffers
23 | s_name = name.replace(".", "")
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def reset_num_updates(self):
30 | del self.num_updates
31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32 |
33 | def forward(self, model):
34 | decay = self.decay
35 |
36 | if self.num_updates >= 0:
37 | self.num_updates += 1
38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39 |
40 | one_minus_decay = 1.0 - decay
41 |
42 | with torch.no_grad():
43 | m_param = dict(model.named_parameters())
44 | shadow_params = dict(self.named_buffers())
45 |
46 | for key in m_param:
47 | if m_param[key].requires_grad:
48 | sname = self.m_name2s_name[key]
49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50 | shadow_params[sname].sub_(
51 | one_minus_decay * (shadow_params[sname] - m_param[key])
52 | )
53 | else:
54 | assert not key in self.m_name2s_name
55 |
56 | def copy_to(self, model):
57 | m_param = dict(model.named_parameters())
58 | shadow_params = dict(self.named_buffers())
59 | for key in m_param:
60 | if m_param[key].requires_grad:
61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62 | else:
63 | assert not key in self.m_name2s_name
64 |
65 | def store(self, parameters):
66 | """
67 | Save the current parameters for restoring later.
68 | Args:
69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70 | temporarily stored.
71 | """
72 | self.collected_params = [param.clone() for param in parameters]
73 |
74 | def restore(self, parameters):
75 | """
76 | Restore the parameters stored with the `store` method.
77 | Useful to validate the model with EMA parameters without affecting the
78 | original optimization process. Store the parameters before the
79 | `copy_to` method. After validation (or model saving), use this to
80 | restore the former parameters.
81 | Args:
82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83 | updated with the stored parameters.
84 | """
85 | for c_param, param in zip(self.collected_params, parameters):
86 | param.data.copy_(c_param.data)
87 |
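A minimal usage sketch for LitEma (illustrative, assumes the repo root is on PYTHONPATH), following the store / copy_to / restore pattern described in the docstrings above:

import torch
import torch.nn as nn
from sgm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)
with torch.no_grad():
    model.weight.add_(1.0)        # pretend an optimizer step changed the weights
ema(model)                        # update the shadow (EMA) copies
ema.store(model.parameters())     # stash the current training weights
ema.copy_to(model)                # evaluate or save with EMA weights
ema.restore(model.parameters())   # then restore the training weights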
--------------------------------------------------------------------------------
/sgm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 | import argparse
3 | from SUPIR.util import create_SUPIR_model, PIL2Tensor, Tensor2PIL, convert_dtype
4 | from PIL import Image
5 | from llava.llava_agent import LLavaAgent
6 | from CKPT_PTH import LLAVA_MODEL_PATH
7 | import os
8 | if torch.cuda.device_count() >= 2:
9 | SUPIR_device = 'cuda:0'
10 | LLaVA_device = 'cuda:1'
11 | elif torch.cuda.device_count() == 1:
12 | SUPIR_device = 'cuda:0'
13 | LLaVA_device = 'cuda:0'
14 | else:
15 | raise ValueError('Currently supports CUDA only.')
16 |
17 | # hyperparameters here
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--img_dir", type=str)
20 | parser.add_argument("--save_dir", type=str)
21 | parser.add_argument("--upscale", type=int, default=1)
22 | parser.add_argument("--SUPIR_sign", type=str, default='Q', choices=['F', 'Q'])
23 | parser.add_argument("--seed", type=int, default=1234)
24 | parser.add_argument("--min_size", type=int, default=1024)
25 | parser.add_argument("--edm_steps", type=int, default=50)
26 | parser.add_argument("--s_stage1", type=int, default=-1)
27 | parser.add_argument("--s_churn", type=int, default=5)
28 | parser.add_argument("--s_noise", type=float, default=1.003)
29 | parser.add_argument("--s_cfg", type=float, default=7.5)
30 | parser.add_argument("--s_stage2", type=float, default=1.)
31 | parser.add_argument("--num_samples", type=int, default=1)
32 | parser.add_argument("--a_prompt", type=str,
33 | default='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
34 | 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
35 | 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
36 | 'hyper sharpness, perfect without deformations.')
37 | parser.add_argument("--n_prompt", type=str,
38 | default='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
39 | 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
40 | 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
41 | 'deformed, lowres, over-smooth')
42 | parser.add_argument("--color_fix_type", type=str, default='Wavelet', choices=["None", "AdaIn", "Wavelet"])
43 | parser.add_argument("--linear_CFG", action='store_true', default=False)
44 | parser.add_argument("--linear_s_stage2", action='store_true', default=False)
45 | parser.add_argument("--spt_linear_CFG", type=float, default=1.0)
46 | parser.add_argument("--spt_linear_s_stage2", type=float, default=0.)
47 | parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16'])
48 | parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16'])
49 | parser.add_argument("--no_llava", action='store_true', default=False)
50 | args = parser.parse_args()
51 | print(args)
52 | use_llava = not args.no_llava
53 |
54 | # load SUPIR
55 | model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device)
56 | model.ae_dtype = convert_dtype(args.ae_dtype)
57 | model.model.dtype = convert_dtype(args.diff_dtype)
58 | # load LLaVA
59 | if use_llava:
60 | llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device)
61 | else:
62 | llava_agent = None
63 |
64 | os.makedirs(args.save_dir, exist_ok=True)
65 | for img_pth in os.listdir(args.img_dir):
66 | img_name = os.path.splitext(img_pth)[0]
67 |
68 | LQ_img = Image.open(os.path.join(args.img_dir, img_pth))
69 | LQ_img, h0, w0 = PIL2Tensor(LQ_img, upsacle=args.upscale, min_size=args.min_size)
70 | LQ_img = LQ_img.unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
71 |
72 | # step 1: Pre-denoise for LLaVA
73 | clean_imgs = model.batchify_denoise(LQ_img)
74 | clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0)
75 |
76 | # step 2: LLaVA
77 | if use_llava:
78 | captions = llava_agent.gen_image_caption([clean_PIL_img])
79 | else:
80 | captions = ['']
81 | print(captions)
82 |
83 | # step 3: Diffusion Process
84 | samples = model.batchify_sample(LQ_img, captions, num_steps=args.edm_steps, restoration_scale=args.s_stage1, s_churn=args.s_churn,
85 | s_noise=args.s_noise, cfg_scale=args.s_cfg, control_scale=args.s_stage2, seed=args.seed,
86 | num_samples=args.num_samples, p_p=args.a_prompt, n_p=args.n_prompt, color_fix_type=args.color_fix_type,
87 | use_linear_CFG=args.linear_CFG, use_linear_control_scale=args.linear_s_stage2,
88 | cfg_scale_start=args.spt_linear_CFG, control_scale_start=args.spt_linear_s_stage2)
89 | # save
90 | for _i, sample in enumerate(samples):
91 | Tensor2PIL(sample, h0, w0).save(f'{args.save_dir}/{img_name}_{_i}.png')
92 |
93 |
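A typical invocation of this script, using only flags defined in the argparse block above (directory paths are placeholders):

python test.py --img_dir inputs --save_dir results --SUPIR_sign Q --upscale 2 --linear_CFG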
--------------------------------------------------------------------------------