├── .idea ├── .gitignore ├── .name ├── SUPIR.iml ├── deployment.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── webResources.xml ├── CKPT_PTH.py ├── LICENSE ├── README.md ├── SUPIR ├── __init__.py ├── models │ ├── SUPIR_model.py │ └── __init__.py ├── modules │ ├── SUPIR_v0.py │ └── __init__.py ├── util.py └── utils │ ├── __init__.py │ └── colorfix.py ├── assets ├── DemoGuide.png ├── framework.png └── teaser.png ├── cog.yaml ├── gradio_demo.py ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_pope.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_textvqa.py │ ├── generate_webpage_data_from_table.py │ ├── m4c_evaluator.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_loader.py │ ├── model_vqa_mmbench.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── llava_agent.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ └── train_mem.py └── utils.py ├── options └── SUPIR_v0.yaml ├── predict.py ├── requirements.txt ├── sgm ├── __init__.py ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── autoencoder.py │ └── diffusion.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── autoencoding │ │ ├── __init__.py │ │ ├── losses │ │ │ └── __init__.py │ │ ├── lpips │ │ │ ├── __init__.py │ │ │ ├── loss │ │ │ │ ├── .gitignore │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ └── lpips.py │ │ │ ├── model │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ └── model.py │ │ │ ├── util.py │ │ │ └── vqperceptual.py │ │ └── regularizers │ │ │ └── __init__.py 
│ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── denoiser.py │ │ ├── denoiser_scaling.py │ │ ├── denoiser_weighting.py │ │ ├── discretizer.py │ │ ├── guiders.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── sampling.py │ │ ├── sampling_utils.py │ │ ├── sigma_sampling.py │ │ ├── util.py │ │ └── wrappers.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ └── encoders │ │ ├── __init__.py │ │ └── modules.py └── util.py └── test.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/.name: -------------------------------------------------------------------------------- 1 | Diff4R -------------------------------------------------------------------------------- /.idea/SUPIR.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 23 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 38 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/webResources.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /CKPT_PTH.py: -------------------------------------------------------------------------------- 1 | LLAVA_CLIP_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/clip-vit-large-patch14-336' 2 | LLAVA_MODEL_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/llava-v1.5-13b' 3 | SDXL_CLIP1_PATH = '/opt/data/private/AIGC_pretrain/clip-vit-large-patch14' 4 | SDXL_CLIP2_CACHE_DIR = '/opt/data/private/AIGC_pretrain' # put package 'models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k' here -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fanghua-Yu 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SUPIR/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/SUPIR/__init__.py -------------------------------------------------------------------------------- /SUPIR/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/SUPIR/models/__init__.py -------------------------------------------------------------------------------- /SUPIR/modules/__init__.py: -------------------------------------------------------------------------------- 1 | SDXL_BASE_CHANNEL_DICT = { 2 | 'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3, 3 | 'project_channels': [160] * 4 + [320] * 3 + [640] * 3, 4 | 'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0] 5 | } 6 | 7 | SDXL_REFINE_CHANNEL_DICT = { 8 | 'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6, 9 | 'project_channels': [192] * 4 + [384] * 3 + [768] * 6, 10 | 'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0] 11 | } -------------------------------------------------------------------------------- /SUPIR/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | from torch.nn.functional import interpolate 7 | from omegaconf import OmegaConf 8 | from sgm.util import instantiate_from_config 9 | 10 | 11 | def get_state_dict(d): 12 | return d.get('state_dict', d) 13 | 14 | 15 | def load_state_dict(ckpt_path, location='cpu'): 16 | _, extension = os.path.splitext(ckpt_path) 17 | if extension.lower() == ".safetensors": 18 | import safetensors.torch 19 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 20 | else: 21 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 22 | state_dict = get_state_dict(state_dict) 23 | print(f'Loaded state_dict from [{ckpt_path}]') 24 | return state_dict 25 | 26 | 27 | def create_model(config_path): 28 | config = OmegaConf.load(config_path) 29 | model = instantiate_from_config(config.model).cpu() 30 | print(f'Loaded model config from [{config_path}]') 31 | return model 32 | 33 | 34 | def create_SUPIR_model(config_path, SUPIR_sign=None): 35 
| config = OmegaConf.load(config_path) 36 | model = instantiate_from_config(config.model).cpu() 37 | print(f'Loaded model config from [{config_path}]') 38 | if config.SDXL_CKPT is not None: 39 | model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False) 40 | if config.SUPIR_CKPT is not None: 41 | model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False) 42 | if SUPIR_sign is not None: 43 | assert SUPIR_sign in ['F', 'Q'] 44 | if SUPIR_sign == 'F': 45 | model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False) 46 | elif SUPIR_sign == 'Q': 47 | model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False) 48 | return model 49 | 50 | def load_QF_ckpt(config_path): 51 | config = OmegaConf.load(config_path) 52 | ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location='cpu') 53 | ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location='cpu') 54 | return ckpt_Q, ckpt_F 55 | 56 | 57 | def PIL2Tensor(img, upsacle=1, min_size=1024): 58 | ''' 59 | PIL.Image -> Tensor[C, H, W], RGB, [-1, 1] 60 | ''' 61 | # size 62 | w, h = img.size 63 | w *= upsacle 64 | h *= upsacle 65 | w0, h0 = round(w), round(h) 66 | if min(w, h) < min_size: 67 | _upsacle = min_size / min(w, h) 68 | w *= _upsacle 69 | h *= _upsacle 70 | else: 71 | _upsacle = 1 72 | w = int(np.round(w / 64.0)) * 64 73 | h = int(np.round(h / 64.0)) * 64 74 | x = img.resize((w, h), Image.BICUBIC) 75 | x = np.array(x).round().clip(0, 255).astype(np.uint8) 76 | x = x / 255 * 2 - 1 77 | x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1) 78 | return x, h0, w0 79 | 80 | 81 | def Tensor2PIL(x, h0, w0): 82 | ''' 83 | Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image 84 | ''' 85 | x = x.unsqueeze(0) 86 | x = interpolate(x, size=(h0, w0), mode='bicubic') 87 | x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 88 | return Image.fromarray(x) 89 | 90 | 91 | def HWC3(x): 92 | assert x.dtype == np.uint8 93 | if x.ndim == 2: 94 | x = x[:, :, None] 95 | assert x.ndim == 3 96 | H, W, C = x.shape 97 | assert C == 1 or C == 3 or C == 4 98 | if C == 3: 99 | return x 100 | if C == 1: 101 | return np.concatenate([x, x, x], axis=2) 102 | if C == 4: 103 | color = x[:, :, 0:3].astype(np.float32) 104 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 105 | y = color * alpha + 255.0 * (1.0 - alpha) 106 | y = y.clip(0, 255).astype(np.uint8) 107 | return y 108 | 109 | 110 | def upscale_image(input_image, upscale, min_size=None, unit_resolution=64): 111 | H, W, C = input_image.shape 112 | H = float(H) 113 | W = float(W) 114 | H *= upscale 115 | W *= upscale 116 | if min_size is not None: 117 | if min(H, W) < min_size: 118 | _upsacle = min_size / min(W, H) 119 | W *= _upsacle 120 | H *= _upsacle 121 | H = int(np.round(H / unit_resolution)) * unit_resolution 122 | W = int(np.round(W / unit_resolution)) * unit_resolution 123 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA) 124 | img = img.round().clip(0, 255).astype(np.uint8) 125 | return img 126 | 127 | 128 | def fix_resize(input_image, size=512, unit_resolution=64): 129 | H, W, C = input_image.shape 130 | H = float(H) 131 | W = float(W) 132 | upscale = size / min(H, W) 133 | H *= upscale 134 | W *= upscale 135 | H = int(np.round(H / unit_resolution)) * unit_resolution 136 | W = int(np.round(W / unit_resolution)) * unit_resolution 137 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA) 138 | img = img.round().clip(0, 
255).astype(np.uint8) 139 | return img 140 | 141 | 142 | 143 | def Numpy2Tensor(img): 144 | ''' 145 | np.array[H, w, C] [0, 255] -> Tensor[C, H, W], RGB, [-1, 1] 146 | ''' 147 | # size 148 | img = np.array(img) / 255 * 2 - 1 149 | img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) 150 | return img 151 | 152 | 153 | def Tensor2Numpy(x, h0=None, w0=None): 154 | ''' 155 | Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image 156 | ''' 157 | if h0 is not None and w0 is not None: 158 | x = x.unsqueeze(0) 159 | x = interpolate(x, size=(h0, w0), mode='bicubic') 160 | x = x.squeeze(0) 161 | x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 162 | return x 163 | 164 | 165 | def convert_dtype(dtype_str): 166 | if dtype_str == 'fp32': 167 | return torch.float32 168 | elif dtype_str == 'fp16': 169 | return torch.float16 170 | elif dtype_str == 'bf16': 171 | return torch.bfloat16 172 | else: 173 | raise NotImplementedError 174 | -------------------------------------------------------------------------------- /SUPIR/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/SUPIR/utils/__init__.py -------------------------------------------------------------------------------- /SUPIR/utils/colorfix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # -------------------------------------------------------------------------------- 3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) 4 | # -------------------------------------------------------------------------------- 5 | ''' 6 | 7 | import torch 8 | from PIL import Image 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | 12 | from torchvision.transforms import ToTensor, ToPILImage 13 | 14 | def adain_color_fix(target: Image, source: Image): 15 | # Convert images to tensors 16 | to_tensor = ToTensor() 17 | target_tensor = to_tensor(target).unsqueeze(0) 18 | source_tensor = to_tensor(source).unsqueeze(0) 19 | 20 | # Apply adaptive instance normalization 21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) 22 | 23 | # Convert tensor back to image 24 | to_image = ToPILImage() 25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 26 | 27 | return result_image 28 | 29 | def wavelet_color_fix(target: Image, source: Image): 30 | # Convert images to tensors 31 | to_tensor = ToTensor() 32 | target_tensor = to_tensor(target).unsqueeze(0) 33 | source_tensor = to_tensor(source).unsqueeze(0) 34 | 35 | # Apply wavelet reconstruction 36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor) 37 | 38 | # Convert tensor back to image 39 | to_image = ToPILImage() 40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 41 | 42 | return result_image 43 | 44 | def calc_mean_std(feat: Tensor, eps=1e-5): 45 | """Calculate mean and std for adaptive_instance_normalization. 46 | Args: 47 | feat (Tensor): 4D tensor. 48 | eps (float): A small value added to the variance to avoid 49 | divide-by-zero. Default: 1e-5. 50 | """ 51 | size = feat.size() 52 | assert len(size) == 4, 'The input feature should be 4D tensor.' 
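# The statistics below are computed per (batch, channel): spatial dimensions are flattened,
# an eps is added to the variance for numerical stability, and the mean/std are reshaped to
# (b, c, 1, 1) so they broadcast over the original 4D feature map.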
53 | b, c = size[:2] 54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps 55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1) 56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) 57 | return feat_mean, feat_std 58 | 59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): 60 | """Adaptive instance normalization. 61 | Adjust the reference features to have color and illumination similar 62 | to those of the degraded features. 63 | Args: 64 | content_feat (Tensor): The reference feature. 65 | style_feat (Tensor): The degraded features. 66 | """ 67 | size = content_feat.size() 68 | style_mean, style_std = calc_mean_std(style_feat) 69 | content_mean, content_std = calc_mean_std(content_feat) 70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 72 | 73 | def wavelet_blur(image: Tensor, radius: int): 74 | """ 75 | Apply wavelet blur to the input tensor. 76 | """ 77 | # input shape: (1, 3, H, W) 78 | # convolution kernel 79 | kernel_vals = [ 80 | [0.0625, 0.125, 0.0625], 81 | [0.125, 0.25, 0.125], 82 | [0.0625, 0.125, 0.0625], 83 | ] 84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 85 | # add channel dimensions to the kernel to make it a 4D tensor 86 | kernel = kernel[None, None] 87 | # repeat the kernel across all input channels 88 | kernel = kernel.repeat(3, 1, 1, 1) 89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 90 | # apply convolution 91 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 92 | return output 93 | 94 | def wavelet_decomposition(image: Tensor, levels=5): 95 | """ 96 | Apply wavelet decomposition to the input tensor. 97 | This function only returns the low frequency & the high frequency. 98 | """ 99 | high_freq = torch.zeros_like(image) 100 | for i in range(levels): 101 | radius = 2 ** i 102 | low_freq = wavelet_blur(image, radius) 103 | high_freq += (image - low_freq) 104 | image = low_freq 105 | 106 | return high_freq, low_freq 107 | 108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 109 | """ 110 | Apply wavelet decomposition so that the content will have the same color as the style.
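Concretely, the content's high-frequency (detail) band is recombined below with the style's low-frequency (color and illumination) band.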
111 | """ 112 | # calculate the wavelet decomposition of the content feature 113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 114 | del content_low_freq 115 | # calculate the wavelet decomposition of the style feature 116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 117 | del style_high_freq 118 | # reconstruct the content feature with the style's high frequency 119 | return content_high_freq + style_low_freq 120 | 121 | -------------------------------------------------------------------------------- /assets/DemoGuide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/assets/DemoGuide.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/assets/framework.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/assets/teaser.png -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | system_packages: 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | python_version: "3.11" 10 | python_packages: 11 | - sentencepiece==0.1.98 12 | - tokenizers==0.13.3 13 | - torch>=2.1.0 14 | - torchvision>=0.16.0 15 | - uvicorn==0.21.1 16 | - transformers==4.28.1 17 | - accelerate==0.18.0 18 | - scikit-learn==1.2.2 19 | - sentencepiece==0.1.98 20 | - einops==0.7.0 21 | - einops-exts==0.0.4 22 | - timm==0.9.8 23 | - openai-clip==1.0.1 24 | - kornia==0.6.9 25 | - matplotlib==3.7.1 26 | - ninja==1.11.1 27 | - omegaconf==2.3.0 28 | - open-clip-torch==2.17.1 29 | - opencv-python==4.7.0.72 30 | - pandas==2.0.1 31 | - Pillow==9.4.0 32 | - pytorch-lightning==2.1.2 33 | - PyYAML==6.0 34 | - scipy==1.12.0 35 | - tqdm==4.65.0 36 | - triton==2.1.0 37 | - webdataset==0.2.48 38 | - xformers>=0.0.20 39 | run: 40 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 41 | predict: "predict.py:Predictor" 42 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = 
ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 
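# Assemble the single review prompt sent to GPT-4: the image caption(s) as [Context], the
# question, both candidate answers, and the grading instructions taken from the rule file.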
96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, 
ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str) 67 | 
parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 
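# Worked example of the answer-parsing convention shared by these eval_science_qa* scripts
# (hypothetical values, shown for illustration only):
#   re.compile(r'The answer is ([A-Z]).').findall('The answer is C.') returns ['C']
#   get_pred_idx('C', choices=['cat', 'dog', 'fox'], options=["A", "B", "C", "D", "E"]) returns 2
#   get_pred_idx('E', choices=['cat', 'dog', 'fox'], options=["A", "B", "C", "D", "E"]) returns a random index in range(3)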
21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', 
key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = 
os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | meta_pth = '/opt/data/private/metas/unsplash_ISO300-_PIL_1024_x2x4_APEX.txt' 37 | img_pths = [] 38 | with open(meta_pth, 'r') as f: 39 | for line in f.readlines(): 40 | 
img_pths.append(line.split('\t')[0]) 41 | f.close() 42 | 43 | img_pths = get_chunk(img_pths, args.num_chunks, args.chunk_idx) 44 | 45 | # split to batch 8 46 | img_pths = split_list(img_pths, 8) 47 | 48 | 49 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 50 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 51 | answers_file = os.path.expanduser(args.answers_file) 52 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 53 | ans_file = open(answers_file, "w") 54 | for line in tqdm(questions): 55 | idx = line["question_id"] 56 | image_file = line["image"] 57 | qs = line["text"] 58 | cur_prompt = qs 59 | if model.config.mm_use_im_start_end: 60 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 61 | else: 62 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 63 | 64 | conv = conv_templates[args.conv_mode].copy() 65 | conv.append_message(conv.roles[0], qs) 66 | conv.append_message(conv.roles[1], None) 67 | prompt = conv.get_prompt() 68 | 69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 70 | 71 | image = Image.open(os.path.join(args.image_folder, image_file)) 72 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 73 | 74 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 75 | keywords = [stop_str] 76 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 77 | 78 | with torch.inference_mode(): 79 | output_ids = model.generate( 80 | input_ids, 81 | images=image_tensor.unsqueeze(0).half().cuda(), 82 | do_sample=True if args.temperature > 0 else False, 83 | temperature=args.temperature, 84 | top_p=args.top_p, 85 | num_beams=args.num_beams, 86 | # no_repeat_ngram_size=3, 87 | max_new_tokens=1024, 88 | use_cache=True) 89 | 90 | input_token_len = input_ids.shape[1] 91 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 92 | if n_diff_input_output > 0: 93 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 94 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 95 | outputs = outputs.strip() 96 | if outputs.endswith(stop_str): 97 | outputs = outputs[:-len(stop_str)] 98 | outputs = outputs.strip() 99 | 100 | ans_id = shortuuid.uuid() 101 | ans_file.write(json.dumps({"question_id": idx, 102 | "prompt": cur_prompt, 103 | "text": outputs, 104 | "answer_id": ans_id, 105 | "model_id": model_name, 106 | "metadata": {}}) + "\n") 107 | ans_file.flush() 108 | ans_file.close() 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 113 | parser.add_argument("--model-base", type=str, default=None) 114 | parser.add_argument("--image-folder", type=str, default="") 115 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 116 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 117 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 118 | parser.add_argument("--num-chunks", type=int, default=1) 119 | parser.add_argument("--chunk-idx", type=int, default=0) 120 | parser.add_argument("--temperature", type=float, default=0.2) 121 | parser.add_argument("--top_p", type=float, default=None) 122 | parser.add_argument("--num_beams", type=int, default=1) 123 | args = parser.parse_args() 124 | 125 | 
eval_model(args) 126 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http') or image_file.startswith('https'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 | image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def eval_model(args): 27 | # Model 28 | disable_torch_init() 29 
| 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | qs = args.query 34 | if model.config.mm_use_im_start_end: 35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 36 | else: 37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 38 | 39 | if 'llama-2' in model_name.lower(): 40 | conv_mode = "llava_llama_2" 41 | elif "v1" in model_name.lower(): 42 | conv_mode = "llava_v1" 43 | elif "mpt" in model_name.lower(): 44 | conv_mode = "mpt" 45 | else: 46 | conv_mode = "llava_v0" 47 | 48 | if args.conv_mode is not None and conv_mode != args.conv_mode: 49 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 50 | else: 51 | args.conv_mode = conv_mode 52 | 53 | conv = conv_templates[args.conv_mode].copy() 54 | conv.append_message(conv.roles[0], qs) 55 | conv.append_message(conv.roles[1], None) 56 | prompt = conv.get_prompt() 57 | 58 | image = load_image(args.image_file) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 60 | 61 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 62 | 63 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 64 | keywords = [stop_str] 65 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 66 | 67 | with torch.inference_mode(): 68 | output_ids = model.generate( 69 | input_ids, 70 | images=image_tensor, 71 | do_sample=True, 72 | temperature=0.2, 73 | max_new_tokens=1024, 74 | use_cache=True, 75 | stopping_criteria=[stopping_criteria]) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | print(outputs) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 91 | parser.add_argument("--model-base", type=str, default=None) 92 | parser.add_argument("--image-file", type=str, required=True) 93 | parser.add_argument("--query", type=str, required=True) 94 | parser.add_argument("--conv-mode", type=str, default=None) 95 | args = parser.parse_args() 96 | 97 | eval_model(args) 98 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return 
parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. 
Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. 
The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /llava/llava_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | 6 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 7 | from llava.conversation import conv_templates, SeparatorStyle 8 | from llava.model.builder import load_pretrained_model 9 | from llava.utils import disable_torch_init 10 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 11 | 12 | from PIL import Image 13 | import math 14 | import time 15 | import glob as gb 16 | 17 | 18 | class LLavaAgent: 19 | def __init__(self, model_path, device='cuda', conv_mode='vicuna_v1'): 20 | self.device = device 21 | if torch.device(self.device).index is not None: 22 | device_map = {'model': torch.device(self.device).index, 'lm_head': torch.device(self.device).index} 23 | else: 24 | device_map = 'auto' 25 | model_path = 
os.path.expanduser(model_path) 26 | model_name = get_model_name_from_path(model_path) 27 | tokenizer, model, image_processor, context_len = load_pretrained_model( 28 | model_path, None, model_name, device=self.device, device_map=device_map) 29 | self.model = model 30 | self.image_processor = image_processor 31 | self.tokenizer = tokenizer 32 | self.context_len = context_len 33 | self.qs = 'Describe this image and its style in a very detailed manner.' 34 | self.conv_mode = conv_mode 35 | 36 | if self.model.config.mm_use_im_start_end: 37 | self.qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + self.qs 38 | else: 39 | self.qs = DEFAULT_IMAGE_TOKEN + '\n' + self.qs 40 | 41 | self.conv = conv_templates[self.conv_mode].copy() 42 | self.conv.append_message(self.conv.roles[0], self.qs) 43 | self.conv.append_message(self.conv.roles[1], None) 44 | prompt = self.conv.get_prompt() 45 | self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze( 46 | 0).to(self.device) 47 | 48 | def update_qs(self, qs=None): 49 | if qs is None: 50 | qs = self.qs 51 | else: 52 | if self.model.config.mm_use_im_start_end: 53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 54 | else: 55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 56 | 57 | self.conv = conv_templates[self.conv_mode].copy() 58 | self.conv.append_message(self.conv.roles[0], qs) 59 | self.conv.append_message(self.conv.roles[1], None) 60 | prompt = self.conv.get_prompt() 61 | self.input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze( 62 | 0).to(self.device) 63 | 64 | def gen_image_caption(self, imgs, temperature=0.2, top_p=0.7, num_beams=1, qs=None): 65 | ''' 66 | [PIL.Image, ...] 
67 | ''' 68 | self.update_qs(qs) 69 | 70 | bs = len(imgs) 71 | input_ids = self.input_ids.repeat(bs, 1) 72 | img_tensor_list = [] 73 | for image in imgs: 74 | _image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 75 | img_tensor_list.append(_image_tensor) 76 | image_tensor = torch.stack(img_tensor_list, dim=0).half().to(self.device) 77 | stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 78 | 79 | with torch.inference_mode(): 80 | output_ids = self.model.generate( 81 | input_ids, 82 | images=image_tensor, 83 | do_sample=True if temperature > 0 else False, 84 | temperature=temperature, 85 | top_p=top_p, 86 | num_beams=num_beams, 87 | # no_repeat_ngram_size=3, 88 | max_new_tokens=512, 89 | use_cache=True) 90 | 91 | input_token_len = input_ids.shape[1] 92 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True) 93 | 94 | img_captions = [] 95 | for output in outputs: 96 | output = output.strip() 97 | if output.endswith(stop_str): 98 | output = output[:-len(stop_str)] 99 | output = output.strip().replace('\n', ' ').replace('\r', ' ') 100 | img_captions.append(output) 101 | return img_captions 102 | 103 | 104 | if __name__ == '__main__': 105 | llava_agent = LLavaAgent("/opt/data/private/AIGC_pretrain/LLaVA1.5/llava-v1.5-13b") 106 | img = [Image.open('/opt/data/private/LV_Dataset/DiffGLV-Test-All/RealPhoto60/LQ/02.png')] 107 | caption = llava_agent.gen_image_caption(img) 108 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 44 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] 45 | 46 | def insert_separator(X, sep): 47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 48 | 49 | input_ids = [] 50 | offset = 0 51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and
prompt_chunks[0][0] == tokenizer.bos_token_id: 52 | offset = 1 53 | input_ids.append(prompt_chunks[0][0]) 54 | 55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 56 | input_ids.extend(x[offset:]) 57 | 58 | if return_tensors is not None: 59 | if return_tensors == 'pt': 60 | return torch.tensor(input_ids, dtype=torch.long) 61 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 62 | return input_ids 63 | 64 | 65 | def get_model_name_from_path(model_path): 66 | model_path = model_path.strip("/") 67 | model_paths = model_path.split("/") 68 | if model_paths[-1].startswith('checkpoint-'): 69 | return model_paths[-2] + "_" + model_paths[-1] 70 | else: 71 | return model_paths[-1] 72 | 73 | 74 | 75 | 76 | class KeywordsStoppingCriteria(StoppingCriteria): 77 | def __init__(self, keywords, tokenizer, input_ids): 78 | self.keywords = keywords 79 | self.keyword_ids = [] 80 | self.max_keyword_len = 0 81 | for keyword in keywords: 82 | cur_keyword_ids = tokenizer(keyword).input_ids 83 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 84 | cur_keyword_ids = cur_keyword_ids[1:] 85 | if len(cur_keyword_ids) > self.max_keyword_len: 86 | self.max_keyword_len = len(cur_keyword_ids) 87 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 88 | self.tokenizer = tokenizer 89 | self.start_len = input_ids.shape[1] 90 | 91 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 92 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 93 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 94 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 95 | for keyword_id in self.keyword_ids: 96 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 97 | return True 98 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 99 | for keyword in self.keywords: 100 | if keyword in outputs: 101 | return True 102 | return False -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 
| assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | images: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 76 | 77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 78 | outputs = self.model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | inputs_embeds=inputs_embeds, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | hidden_states = outputs[0] 90 | logits = self.lm_head(hidden_states) 91 | 92 | loss = None 93 | if labels is not None: 94 | # Shift so that tokens < n predict n 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_labels = labels[..., 1:].contiguous() 97 | # Flatten the tokens 98 | loss_fct = CrossEntropyLoss() 99 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 100 | shift_labels = shift_labels.view(-1) 101 | # Enable model/pipeline parallelism 102 | shift_labels = shift_labels.to(shift_logits.device) 103 | loss = loss_fct(shift_logits, shift_labels) 104 | 105 | if not return_dict: 106 | output = (logits,) + outputs[1:] 107 | return (loss,) + output if loss is not None else output 108 | 109 | return CausalLMOutputWithPast( 110 | loss=loss, 111 | logits=logits, 112 | past_key_values=outputs.past_key_values, 113 | hidden_states=outputs.hidden_states, 114 
| attentions=outputs.attentions, 115 | ) 116 | 117 | def prepare_inputs_for_generation( 118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 119 | ): 120 | if past_key_values: 121 | input_ids = input_ids[:, -1:] 122 | 123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 124 | if inputs_embeds is not None and past_key_values is None: 125 | model_inputs = {"inputs_embeds": inputs_embeds} 126 | else: 127 | model_inputs = {"input_ids": input_ids} 128 | 129 | model_inputs.update( 130 | { 131 | "past_key_values": past_key_values, 132 | "use_cache": kwargs.get("use_cache"), 133 | "attention_mask": attention_mask, 134 | "images": kwargs.get("images", None), 135 | } 136 | ) 137 | return model_inputs 138 | 139 | AutoConfig.register("llava", LlavaConfig) 140 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 141 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, 
LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'],
qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Linear(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True,
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, 
args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | from CKPT_PTH import LLAVA_CLIP_PATH 6 | 7 | 8 | class CLIPVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower 15 | print(f'Loading vision tower: {self.vision_tower_name}') 16 | self.select_layer = args.mm_vision_select_layer 17 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 18 | 19 | if not delay_load: 20 | self.load_model() 21 | else: 22 | # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | self.cfg_only = CLIPVisionConfig.from_pretrained(LLAVA_CLIP_PATH) 24 | 25 | def load_model(self): 26 | self.image_processor = CLIPImageProcessor.from_pretrained(LLAVA_CLIP_PATH) 27 | self.vision_tower = CLIPVisionModel.from_pretrained(LLAVA_CLIP_PATH) 28 | self.vision_tower.requires_grad_(False) 29 | 30 | self.is_loaded = True 31 | 32 | def feature_select(self, image_forward_outs): 33 | image_features = image_forward_outs.hidden_states[self.select_layer] 34 | if self.select_feature == 'patch': 35 | image_features = image_features[:, 1:] 36 | elif self.select_feature == 'cls_patch': 37 | image_features = image_features 38 | else: 39 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 40 | return image_features 41 | 42 | @torch.no_grad() 43 | def forward(self, images): 44 | if type(images) is list: 45 | image_features = [] 46 | for image in images: 47 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 48 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 49 | image_features.append(image_feature) 50 | else: 51 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 52 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 53 | 54 | return image_features 55 | 56 | @property 57 | def dummy_feature(self): 58 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 59 | 60 | @property 61 | def dtype(self): 62 | return self.vision_tower.dtype 63 | 64 | @property 65 | def device(self): 66 | return self.vision_tower.device 67 | 68 | @property 69 | def config(self): 70 | if self.is_loaded: 71 | return self.vision_tower.config 72 | else: 73 | return self.cfg_only 74 | 75 | @property 76 | def 
hidden_size(self): 77 | return self.config.hidden_size 78 | 79 | @property 80 | def num_patches(self): 81 | return (self.config.image_size // self.config.patch_size) ** 2 82 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | # Similar operation in model_worker.py 56 | image_tensor = process_images([image], image_processor, args) 57 | if type(image_tensor) is list: 58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 59 | else: 60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 61 | 62 | while True: 63 | try: 64 | inp = input(f"{roles[0]}: ") 65 | except EOFError: 66 | inp = "" 67 | if not inp: 68 | print("exit...") 69 | break 70 | 71 | print(f"{roles[1]}: ", end="") 72 | 73 | if image is not None: 74 | # first message 75 | if model.config.mm_use_im_start_end: 76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 77 | else: 78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 
79 | conv.append_message(conv.roles[0], inp) 80 | image = None 81 | else: 82 | # later messages 83 | conv.append_message(conv.roles[0], inp) 84 | conv.append_message(conv.roles[1], None) 85 | prompt = conv.get_prompt() 86 | 87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 89 | keywords = [stop_str] 90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 92 | 93 | with torch.inference_mode(): 94 | output_ids = model.generate( 95 | input_ids, 96 | images=image_tensor, 97 | do_sample=True, 98 | temperature=args.temperature, 99 | max_new_tokens=args.max_new_tokens, 100 | streamer=streamer, 101 | use_cache=True, 102 | stopping_criteria=[stopping_criteria]) 103 | 104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 105 | conv.messages[-1][-1] = outputs 106 | 107 | if args.debug: 108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 114 | parser.add_argument("--model-base", type=str, default=None) 115 | parser.add_argument("--image-file", type=str, required=True) 116 | parser.add_argument("--device", type=str, default="cuda") 117 | parser.add_argument("--conv-mode", type=str, default=None) 118 | parser.add_argument("--temperature", type=float, default=0.2) 119 | parser.add_argument("--max-new-tokens", type=int, default=512) 120 | parser.add_argument("--load-8bit", action="store_true") 121 | parser.add_argument("--load-4bit", action="store_true") 122 | parser.add_argument("--debug", action="store_true") 123 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 124 | args = parser.parse_args() 125 | main(args) 126 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import 
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | 
cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 
63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /options/SUPIR_v0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: SUPIR.models.SUPIR_model.SUPIRModel 3 | params: 4 | ae_dtype: bf16 5 | diffusion_dtype: fp16 6 | scale_factor: 0.13025 7 | disable_first_stage_autocast: True 8 | network_wrapper: sgm.modules.diffusionmodules.wrappers.ControlWrapper 9 | 10 | denoiser_config: 11 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiserWithControl 12 | params: 13 | num_idx: 1000 14 | weighting_config: 15 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 16 | scaling_config: 17 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 18 | discretization_config: 19 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 20 | 21 | control_stage_config: 22 | target: SUPIR.modules.SUPIR_v0.GLVControl 23 | params: 24 | adm_in_channels: 2816 25 | num_classes: sequential 26 | use_checkpoint: True 27 | in_channels: 4 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: [4, 2] 31 | num_res_blocks: 2 32 | channel_mult: [1, 2, 4] 33 | num_head_channels: 64 34 | use_spatial_transformer: True 35 | use_linear_in_transformer: True 36 | transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 
--> 64, 32, 16 37 | # transformer_depth: [1, 1, 4] 38 | context_dim: 2048 39 | spatial_transformer_attn_type: softmax-xformers 40 | legacy: False 41 | input_upscale: 1 42 | 43 | network_config: 44 | target: SUPIR.modules.SUPIR_v0.LightGLVUNet 45 | params: 46 | mode: XL-base 47 | project_type: ZeroSFT 48 | project_channel_scale: 2 49 | adm_in_channels: 2816 50 | num_classes: sequential 51 | use_checkpoint: True 52 | in_channels: 4 53 | out_channels: 4 54 | model_channels: 320 55 | attention_resolutions: [4, 2] 56 | num_res_blocks: 2 57 | channel_mult: [1, 2, 4] 58 | num_head_channels: 64 59 | use_spatial_transformer: True 60 | use_linear_in_transformer: True 61 | transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 62 | context_dim: 2048 63 | spatial_transformer_attn_type: softmax-xformers 64 | legacy: False 65 | 66 | conditioner_config: 67 | target: sgm.modules.GeneralConditionerWithControl 68 | params: 69 | emb_models: 70 | # crossattn cond 71 | - is_trainable: False 72 | input_key: txt 73 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 74 | params: 75 | layer: hidden 76 | layer_idx: 11 77 | # crossattn and vector cond 78 | - is_trainable: False 79 | input_key: txt 80 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 81 | params: 82 | arch: ViT-bigG-14 83 | version: laion2b_s39b_b160k 84 | freeze: True 85 | layer: penultimate 86 | always_return_pooled: True 87 | legacy: False 88 | # vector cond 89 | - is_trainable: False 90 | input_key: original_size_as_tuple 91 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 92 | params: 93 | outdim: 256 # multiplied by two 94 | # vector cond 95 | - is_trainable: False 96 | input_key: crop_coords_top_left 97 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 98 | params: 99 | outdim: 256 # multiplied by two 100 | # vector cond 101 | - is_trainable: False 102 | input_key: target_size_as_tuple 103 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 104 | params: 105 | outdim: 256 # multiplied by two 106 | 107 | first_stage_config: 108 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 109 | params: 110 | ckpt_path: ~ 111 | embed_dim: 4 112 | monitor: val/rec_loss 113 | ddconfig: 114 | attn_type: vanilla-xformers 115 | double_z: true 116 | z_channels: 4 117 | resolution: 256 118 | in_channels: 3 119 | out_ch: 3 120 | ch: 128 121 | ch_mult: [ 1, 2, 4, 4 ] 122 | num_res_blocks: 2 123 | attn_resolutions: [ ] 124 | dropout: 0.0 125 | lossconfig: 126 | target: torch.nn.Identity 127 | 128 | sampler_config: 129 | target: sgm.modules.diffusionmodules.sampling.RestoreEDMSampler 130 | params: 131 | num_steps: 100 132 | restore_cfg: 4.0 133 | s_churn: 0 134 | s_noise: 1.003 135 | discretization_config: 136 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 137 | guider_config: 138 | target: sgm.modules.diffusionmodules.guiders.LinearCFG 139 | params: 140 | scale: 7.5 141 | scale_min: 4.0 142 | 143 | p_p: 144 | 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, 145 | hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, 146 | skin pore detailing, hyper sharpness, perfect without deformations.' 
147 | n_p: 148 | 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, 149 | unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, 150 | jpeg artifacts, deformed, lowres, over-smooth' 151 | 152 | SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors 153 | SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt 154 | SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt 155 | SUPIR_CKPT: ~ 156 | 157 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.95.1 2 | gradio==4.16.0 3 | gradio_imageslider==0.0.17 4 | Markdown==3.4.1 5 | numpy==1.24.2 6 | requests==2.28.2 7 | sentencepiece==0.1.98 8 | tokenizers==0.13.3 9 | torch>=2.1.0 10 | torchvision>=0.16.0 11 | uvicorn==0.21.1 12 | wandb==0.14.0 13 | httpx==0.24.0 14 | transformers==4.28.1 15 | accelerate==0.18.0 16 | scikit-learn==1.2.2 17 | sentencepiece==0.1.98 18 | einops==0.7.0 19 | einops-exts==0.0.4 20 | timm==0.9.8 21 | gradio_client==0.1.3 22 | openai-clip==1.0.1 23 | fsspec==2023.4.0 24 | kornia==0.6.9 25 | matplotlib==3.7.1 26 | ninja==1.11.1 27 | omegaconf==2.3.0 28 | open-clip-torch==2.17.1 29 | opencv-python==4.7.0.72 30 | pandas==2.0.1 31 | Pillow==9.4.0 32 | pytorch-lightning==2.1.2 33 | PyYAML==6.0 34 | scipy==1.9.1 35 | tqdm==4.65.0 36 | triton==2.1.0 37 | urllib3==1.26.15 38 | webdataset==0.2.48 39 | xformers>=0.0.20 40 | -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 
55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | from .encoders.modules import GeneralConditionerWithControl 3 | from .encoders.modules import PreparedConditioner 4 | 5 | UNCONDITIONAL_CONFIG = { 6 | "target": "sgm.modules.GeneralConditioner", 7 | "params": {"emb_models": []}, 8 | } 9 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/__init__.py 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if (use_dropout) 91 
| else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": 
"d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 
113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution 9 | 10 | 11 | class AbstractRegularizer(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 16 | raise NotImplementedError() 17 | 18 | @abstractmethod 19 | def get_trainable_parameters(self) -> Any: 20 | raise NotImplementedError() 21 | 22 | 23 | class DiagonalGaussianRegularizer(AbstractRegularizer): 24 | def __init__(self, sample: bool = True): 25 | super().__init__() 26 | self.sample = sample 27 | 28 | def get_trainable_parameters(self) -> Any: 29 | yield from () 30 | 31 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 32 | log = dict() 33 | posterior = DiagonalGaussianDistribution(z) 34 | if self.sample: 35 | z = posterior.sample() 36 | else: 37 | z = posterior.mode() 38 | kl_loss = posterior.kl() 39 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 40 | log["kl_loss"] = kl_loss 41 | return z, log 42 | 43 | 44 | def measure_perplexity(predicted_indices, num_centroids): 45 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 46 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 47 | encodings = ( 48 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 49 | ) 50 | avg_probs = encodings.mean(0) 51 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 52 | cluster_use = torch.sum(avg_probs > 0) 53 | return perplexity, cluster_use 54 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | from .discretizer import Discretization 3 | from .loss import StandardDiffusionLoss 4 | from .model import Decoder, Encoder, Model 5 | from .openaimodel import UNetModel 6 | from .sampling import BaseDiffusionSampler 7 | from .wrappers import OpenAIWrapper 8 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ...util import append_dims, instantiate_from_config 4 | 5 | 6 | class Denoiser(nn.Module): 7 | def __init__(self, weighting_config, scaling_config): 8 | super().__init__() 9 | 10 | self.weighting = instantiate_from_config(weighting_config) 11 | self.scaling = instantiate_from_config(scaling_config) 12 | 13 | def possibly_quantize_sigma(self, sigma): 14 | return sigma 15 | 16 | def possibly_quantize_c_noise(self, c_noise): 17 | return c_noise 18 | 19 | def w(self, sigma): 20 | return self.weighting(sigma) 21 | 22 | def __call__(self, network, input, sigma, cond): 23 | sigma = self.possibly_quantize_sigma(sigma) 24 | sigma_shape = sigma.shape 25 | sigma = append_dims(sigma, input.ndim) 26 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 27 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 28 | return network(input * c_in, c_noise, cond) * c_out + input * c_skip 29 | 30 | 31 | class DiscreteDenoiser(Denoiser): 32 | def __init__( 33 | self, 34 | weighting_config, 35 | scaling_config, 36 | num_idx, 37 | discretization_config, 38 | do_append_zero=False, 39 | quantize_c_noise=True, 40 | flip=True, 41 | ): 42 | super().__init__(weighting_config, scaling_config) 43 | sigmas = instantiate_from_config(discretization_config)( 44 | num_idx, do_append_zero=do_append_zero, flip=flip 45 | ) 46 | self.register_buffer("sigmas", sigmas) 47 | self.quantize_c_noise = quantize_c_noise 48 | 49 | def sigma_to_idx(self, sigma): 50 | dists = sigma - self.sigmas[:, None] 51 | return dists.abs().argmin(dim=0).view(sigma.shape) 52 | 53 | def idx_to_sigma(self, idx): 54 | return self.sigmas[idx] 55 | 56 | def possibly_quantize_sigma(self, sigma): 57 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 58 | 59 | def possibly_quantize_c_noise(self, c_noise): 60 | if self.quantize_c_noise: 61 | return self.sigma_to_idx(c_noise) 62 | else: 63 | return c_noise 64 | 65 | 66 | class DiscreteDenoiserWithControl(DiscreteDenoiser): 67 | def __call__(self, network, input, sigma, cond, control_scale): 68 | sigma = self.possibly_quantize_sigma(sigma) 69 | sigma_shape = sigma.shape 70 | sigma = append_dims(sigma, input.ndim) 71 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 72 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 73 | return network(input * c_in, c_noise, cond, control_scale) * c_out + input * c_skip 74 | 
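A minimal sketch (not part of the repository) of the preconditioning that Denoiser wraps around the network call, written out with the EDM scaling defined in denoiser_scaling.py below. The toy_network stand-in, the tensor shapes, and the sigma values are illustrative assumptions only:

import torch

def toy_network(x, c_noise, cond):
    # stand-in for the UNet; a real network would condition on c_noise and cond
    return x

sigma_data = 0.5
x = torch.randn(2, 4, 8, 8)
sigma = torch.tensor([1.0, 2.0])

sigma_b = sigma.view(-1, 1, 1, 1)  # equivalent of append_dims(sigma, x.ndim)
c_skip = sigma_data**2 / (sigma_b**2 + sigma_data**2)
c_out = sigma_b * sigma_data / (sigma_b**2 + sigma_data**2) ** 0.5
c_in = 1 / (sigma_b**2 + sigma_data**2) ** 0.5
c_noise = 0.25 * sigma.log()

# D(x, sigma) = c_out * network(c_in * x, c_noise, cond) + c_skip * x
denoised = toy_network(x * c_in, c_noise, cond=None) * c_out + x * c_skip
print(denoised.shape)  # torch.Size([2, 4, 8, 8])

DiscreteDenoiser follows the same pattern but first snaps sigma (and optionally c_noise) to the nearest entry of a precomputed sigma table, and DiscreteDenoiserWithControl additionally threads a control_scale argument through to the network.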
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EDMScaling: 5 | def __init__(self, sigma_data=0.5): 6 | self.sigma_data = sigma_data 7 | 8 | def __call__(self, sigma): 9 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 10 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 11 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 12 | c_noise = 0.25 * sigma.log() 13 | return c_skip, c_out, c_in, c_noise 14 | 15 | 16 | class EpsScaling: 17 | def __call__(self, sigma): 18 | c_skip = torch.ones_like(sigma, device=sigma.device) 19 | c_out = -sigma 20 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 21 | c_noise = sigma.clone() 22 | return c_skip, c_out, c_in, c_noise 23 | 24 | 25 | class VScaling: 26 | def __call__(self, sigma): 27 | c_skip = 1.0 / (sigma**2 + 1.0) 28 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 29 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 30 | c_noise = sigma.clone() 31 | return c_skip, c_out, c_in, c_noise 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class UnitWeighting: 4 | def __call__(self, sigma): 5 | return torch.ones_like(sigma, device=sigma.device) 6 | 7 | 8 | class EDMWeighting: 9 | def __init__(self, sigma_data=0.5): 10 | self.sigma_data = sigma_data 11 | 12 | def __call__(self, sigma): 13 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 14 | 15 | 16 | class VWeighting(EDMWeighting): 17 | def __init__(self): 18 | super().__init__(sigma_data=1.0) 19 | 20 | 21 | class EpsWeighting: 22 | def __call__(self, sigma): 23 | return sigma**-2 24 | 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 
46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | from ...util import default, instantiate_from_config 6 | 7 | 8 | class VanillaCFG: 9 | """ 10 | implements parallelized CFG 11 | """ 12 | 13 | def __init__(self, scale, dyn_thresh_config=None): 14 | scale_schedule = lambda scale, sigma: scale # independent of step 15 | self.scale_schedule = partial(scale_schedule, scale) 16 | self.dyn_thresh = instantiate_from_config( 17 | default( 18 | dyn_thresh_config, 19 | { 20 | "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" 21 | }, 22 | ) 23 | ) 24 | 25 | def __call__(self, x, sigma): 26 | x_u, x_c = x.chunk(2) 27 | scale_value = self.scale_schedule(sigma) 28 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 29 | return x_pred 30 | 31 | def prepare_inputs(self, x, s, c, uc): 32 | c_out = dict() 33 | 34 | for k in c: 35 | if k in ["vector", "crossattn", "concat", "control", 'control_vector', 'mask_x']: 36 | c_out[k] = torch.cat((uc[k], c[k]), 0) 37 | else: 38 | assert c[k] == uc[k] 39 | c_out[k] = c[k] 40 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 41 | 42 | 43 | 44 | class LinearCFG: 45 | def __init__(self, scale, scale_min=None, dyn_thresh_config=None): 46 | if scale_min is None: 47 | scale_min = scale 48 | scale_schedule = lambda scale, scale_min, sigma: (scale - scale_min) * sigma / 14.6146 + scale_min 49 | self.scale_schedule = partial(scale_schedule, scale, scale_min) 50 | self.dyn_thresh = instantiate_from_config( 51 | default( 52 | dyn_thresh_config, 53 | { 54 | "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" 55 | }, 56 | ) 57 | ) 58 | 59 | def __call__(self, x, sigma): 60 | x_u, x_c = x.chunk(2) 61 | scale_value = self.scale_schedule(sigma) 62 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 63 | return x_pred 64 | 65 | def prepare_inputs(self, x, s, c, uc): 66 | c_out = dict() 67 | 68 | for k in c: 69 | if k in ["vector", "crossattn", "concat", "control", 'control_vector', 'mask_x']: 70 | c_out[k] = torch.cat((uc[k], c[k]), 0) 71 | else: 72 | assert c[k] == uc[k] 73 | c_out[k] = c[k] 74 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 75 | 76 | 77 | 78 | class IdentityGuider: 79 | def __call__(self, x, sigma): 80 | return x 81 | 82 | def prepare_inputs(self, x, s, c, uc): 83 | c_out = dict() 84 | 85 | for k in c: 86 | c_out[k] = c[k] 87 | 88 | return x, s, c_out 89 | 
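A minimal sketch (not part of the repository) of the classifier-free guidance blend that VanillaCFG performs together with NoDynamicThresholding from sampling_utils.py below: prepare_inputs runs the unconditional and conditional branches as one batch, and the guider then splits the prediction and mixes the two halves. The toy tensors and the fixed scale value are illustrative assumptions:

import torch

scale = 7.5
pred_u = torch.randn(1, 4, 8, 8)   # prediction under the unconditional branch
pred_c = torch.randn(1, 4, 8, 8)   # prediction under the conditional branch
x = torch.cat([pred_u, pred_c], dim=0)  # both branches evaluated as one batch

x_u, x_c = x.chunk(2)
guided = x_u + scale * (x_c - x_u)  # NoDynamicThresholding-style blend
print(guided.shape)  # torch.Size([1, 4, 8, 8])

LinearCFG differs only in the scale schedule: the guidance weight is interpolated between scale and scale_min as a linear function of sigma instead of being held constant.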
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from omegaconf import ListConfig 6 | 7 | from ...util import append_dims, instantiate_from_config 8 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 9 | 10 | 11 | class StandardDiffusionLoss(nn.Module): 12 | def __init__( 13 | self, 14 | sigma_sampler_config, 15 | type="l2", 16 | offset_noise_level=0.0, 17 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, 18 | ): 19 | super().__init__() 20 | 21 | assert type in ["l2", "l1", "lpips"] 22 | 23 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 24 | 25 | self.type = type 26 | self.offset_noise_level = offset_noise_level 27 | 28 | if type == "lpips": 29 | self.lpips = LPIPS().eval() 30 | 31 | if not batch2model_keys: 32 | batch2model_keys = [] 33 | 34 | if isinstance(batch2model_keys, str): 35 | batch2model_keys = [batch2model_keys] 36 | 37 | self.batch2model_keys = set(batch2model_keys) 38 | 39 | def __call__(self, network, denoiser, conditioner, input, batch): 40 | cond = conditioner(batch) 41 | additional_model_inputs = { 42 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 43 | } 44 | 45 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device) 46 | noise = torch.randn_like(input) 47 | if self.offset_noise_level > 0.0: 48 | noise = noise + self.offset_noise_level * append_dims( 49 | torch.randn(input.shape[0], device=input.device), input.ndim 50 | ) 51 | noised_input = input + noise * append_dims(sigmas, input.ndim) 52 | model_output = denoiser( 53 | network, noised_input, sigmas, cond, **additional_model_inputs 54 | ) 55 | w = append_dims(denoiser.w(sigmas), input.ndim) 56 | return self.get_loss(model_output, input, w) 57 | 58 | def get_loss(self, model_output, target, w): 59 | if self.type == "l2": 60 | return torch.mean( 61 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 62 | ) 63 | elif self.type == "l1": 64 | return torch.mean( 65 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 66 | ) 67 | elif self.type == "lpips": 68 | loss = self.lpips(model_output, target).reshape(-1) 69 | return loss 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | class NoDynamicThresholding: 8 | def __call__(self, uncond, cond, scale): 9 | return uncond + scale.view(-1, 1, 1, 1) * (cond - uncond) 10 | 11 | 12 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 13 | if order - 1 > i: 14 | raise ValueError(f"Order {order} too high for step {i}") 15 | 16 | def fn(tau): 17 | prod = 1.0 18 | for k in range(order): 19 | if j == k: 20 | continue 21 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 22 | return prod 23 | 24 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 25 | 26 | 27 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 28 | if not eta: 29 | return sigma_to, 0.0 30 | sigma_up = torch.minimum( 31 | sigma_to, 32 | eta 33 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 34 | ) 35 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 36 | 
return sigma_down, sigma_up 37 | 38 | 39 | def to_d(x, sigma, denoised): 40 | return (x - denoised) / append_dims(sigma, x.ndim) 41 | 42 | 43 | def to_neg_log_sigma(sigma): 44 | return sigma.log().neg() 45 | 46 | 47 | def to_sigma(neg_log_sigma): 48 | return neg_log_sigma.neg().exp() 49 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, idx_range=None): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | self.idx_range = idx_range 23 | 24 | def idx_to_sigma(self, idx): 25 | # print(self.sigmas[idx]) 26 | return self.sigmas[idx] 27 | 28 | def __call__(self, n_samples, rand=None): 29 | if self.idx_range is None: 30 | idx = default( 31 | rand, 32 | torch.randint(0, self.num_idx, (n_samples,)), 33 | ) 34 | else: 35 | idx = default( 36 | rand, 37 | torch.randint(self.idx_range[0], self.idx_range[1], (n_samples,)), 38 | ) 39 | return self.idx_to_sigma(idx) 40 | 41 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | # import torch._dynamo 5 | # torch._dynamo.config.suppress_errors = True 6 | # torch._dynamo.config.cache_size_limit = 512 7 | 8 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 9 | 10 | 11 | class IdentityWrapper(nn.Module): 12 | def __init__(self, diffusion_model, compile_model: bool = False): 13 | super().__init__() 14 | compile = ( 15 | torch.compile 16 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 17 | and compile_model 18 | else lambda x: x 19 | ) 20 | self.diffusion_model = compile(diffusion_model) 21 | 22 | def forward(self, *args, **kwargs): 23 | return self.diffusion_model(*args, **kwargs) 24 | 25 | 26 | class OpenAIWrapper(IdentityWrapper): 27 | def forward( 28 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 29 | ) -> torch.Tensor: 30 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 31 | return self.diffusion_model( 32 | x, 33 | timesteps=t, 34 | context=c.get("crossattn", None), 35 | y=c.get("vector", None), 36 | **kwargs, 37 | ) 38 | 39 | 40 | class OpenAIHalfWrapper(IdentityWrapper): 41 | def __init__(self, *args, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | self.diffusion_model = self.diffusion_model.half() 44 | 45 | def forward( 46 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 47 | ) -> torch.Tensor: 48 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 49 | _context = c.get("crossattn", None) 50 | _y = c.get("vector", None) 51 | if _context is not None: 52 | _context = _context.half() 53 | if _y is not None: 54 | _y = _y.half() 55 | x = x.half() 56 | t 
= t.half() 57 | 58 | out = self.diffusion_model( 59 | x, 60 | timesteps=t, 61 | context=_context, 62 | y=_y, 63 | **kwargs, 64 | ) 65 | return out.float() 66 | 67 | 68 | class ControlWrapper(nn.Module): 69 | def __init__(self, diffusion_model, compile_model: bool = False, dtype=torch.float32): 70 | super().__init__() 71 | self.compile = ( 72 | torch.compile 73 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 74 | and compile_model 75 | else lambda x: x 76 | ) 77 | self.diffusion_model = self.compile(diffusion_model) 78 | self.control_model = None 79 | self.dtype = dtype 80 | 81 | def load_control_model(self, control_model): 82 | self.control_model = self.compile(control_model) 83 | 84 | def forward( 85 | self, x: torch.Tensor, t: torch.Tensor, c: dict, control_scale=1, **kwargs 86 | ) -> torch.Tensor: 87 | with torch.autocast("cuda", dtype=self.dtype): 88 | control = self.control_model(x=c.get("control", None), timesteps=t, xt=x, 89 | control_vector=c.get("control_vector", None), 90 | mask_x=c.get("mask_x", None), 91 | context=c.get("crossattn", None), 92 | y=c.get("vector", None)) 93 | out = self.diffusion_model( 94 | x, 95 | timesteps=t, 96 | context=c.get("crossattn", None), 97 | y=c.get("vector", None), 98 | control=control, 99 | control_scale=control_scale, 100 | **kwargs, 101 | ) 102 | return out.float() 103 | 104 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = 
np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 
68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxwh/SUPIR/a5d4389fdb73abca6c286d3f4283080f4d8b0479/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | import argparse 3 | from SUPIR.util import create_SUPIR_model, PIL2Tensor, Tensor2PIL, convert_dtype 4 | from PIL import Image 5 | from llava.llava_agent import LLavaAgent 6 | from CKPT_PTH import LLAVA_MODEL_PATH 7 | import os 8 | if torch.cuda.device_count() >= 2: 9 | SUPIR_device = 'cuda:0' 10 | LLaVA_device = 'cuda:1' 11 | elif torch.cuda.device_count() == 1: 12 | SUPIR_device = 'cuda:0' 13 | LLaVA_device = 'cuda:0' 14 | else: 15 | raise ValueError('Currently support CUDA only.') 16 | 17 | # hyparams here 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--img_dir", type=str) 20 | parser.add_argument("--save_dir", type=str) 21 | parser.add_argument("--upscale", type=int, default=1) 22 | parser.add_argument("--SUPIR_sign", type=str, default='Q', choices=['F', 'Q']) 23 | parser.add_argument("--seed", type=int, default=1234) 24 | parser.add_argument("--min_size", type=int, default=1024) 25 | parser.add_argument("--edm_steps", type=int, default=50) 26 | parser.add_argument("--s_stage1", type=int, default=-1) 27 | parser.add_argument("--s_churn", type=int, default=5) 28 | parser.add_argument("--s_noise", type=float, default=1.003) 29 | parser.add_argument("--s_cfg", type=float, default=7.5) 30 | parser.add_argument("--s_stage2", type=float, default=1.) 
31 | parser.add_argument("--num_samples", type=int, default=1) 32 | parser.add_argument("--a_prompt", type=str, 33 | default='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R ' 34 | 'camera, hyper detailed photo - realistic maximum detail, 32k, Color ' 35 | 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, ' 36 | 'hyper sharpness, perfect without deformations.') 37 | parser.add_argument("--n_prompt", type=str, 38 | default='painting, oil painting, illustration, drawing, art, sketch, oil painting, ' 39 | 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, ' 40 | 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, ' 41 | 'deformed, lowres, over-smooth') 42 | parser.add_argument("--color_fix_type", type=str, default='Wavelet', choices=["None", "AdaIn", "Wavelet"]) 43 | parser.add_argument("--linear_CFG", action='store_true', default=False) 44 | parser.add_argument("--linear_s_stage2", action='store_true', default=False) 45 | parser.add_argument("--spt_linear_CFG", type=float, default=1.0) 46 | parser.add_argument("--spt_linear_s_stage2", type=float, default=0.) 47 | parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16']) 48 | parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16']) 49 | parser.add_argument("--no_llava", action='store_true', default=False) 50 | args = parser.parse_args() 51 | print(args) 52 | use_llava = not args.no_llava 53 | 54 | # load SUPIR 55 | model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device) 56 | model.ae_dtype = convert_dtype(args.ae_dtype) 57 | model.model.dtype = convert_dtype(args.diff_dtype) 58 | # load LLaVA 59 | if use_llava: 60 | llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device) 61 | else: 62 | llava_agent = None 63 | 64 | os.makedirs(args.save_dir, exist_ok=True) 65 | for img_pth in os.listdir(args.img_dir): 66 | img_name = os.path.splitext(img_pth)[0] 67 | 68 | LQ_img = Image.open(os.path.join(args.img_dir, img_pth)) 69 | LQ_img, h0, w0 = PIL2Tensor(LQ_img, upsacle=args.upscale, min_size=args.min_size) 70 | LQ_img = LQ_img.unsqueeze(0).to(SUPIR_device)[:, :3, :, :] 71 | 72 | # step 1: Pre-denoise for LLaVA) 73 | clean_imgs = model.batchify_denoise(LQ_img) 74 | clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0) 75 | 76 | # step 2: LLaVA 77 | if use_llava: 78 | captions = llava_agent.gen_image_caption([clean_PIL_img]) 79 | else: 80 | captions = [''] 81 | print(captions) 82 | 83 | # # step 3: Diffusion Process 84 | samples = model.batchify_sample(LQ_img, captions, num_steps=args.edm_steps, restoration_scale=args.s_stage1, s_churn=args.s_churn, 85 | s_noise=args.s_noise, cfg_scale=args.s_cfg, control_scale=args.s_stage2, seed=args.seed, 86 | num_samples=args.num_samples, p_p=args.a_prompt, n_p=args.n_prompt, color_fix_type=args.color_fix_type, 87 | use_linear_CFG=args.linear_CFG, use_linear_control_scale=args.linear_s_stage2, 88 | cfg_scale_start=args.spt_linear_CFG, control_scale_start=args.spt_linear_s_stage2) 89 | # save 90 | for _i, sample in enumerate(samples): 91 | Tensor2PIL(sample, h0, w0).save(f'{args.save_dir}/{img_name}_{_i}.png') 92 | 93 | --------------------------------------------------------------------------------
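As a usage note, test.py is driven from the command line with the flags defined above; the directory paths in this example are placeholders, and all other values shown are among the documented defaults or choices:

python test.py --img_dir /path/to/LQ_images --save_dir /path/to/results \
    --SUPIR_sign Q --upscale 2 --edm_steps 50 --s_cfg 7.5 --linear_CFG

With two CUDA devices available the script places SUPIR on cuda:0 and LLaVA on cuda:1; with a single device both share cuda:0, and a CPU-only machine raises an error. Passing --no_llava skips caption generation and restores with an empty prompt.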