├── llava ├── __init__.py ├── model │ ├── __init__.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── language_model │ │ ├── float_utils.py │ │ └── llava_llama.py │ ├── utils.py │ ├── multimodal_projector │ │ └── builder.py │ ├── builder.py │ └── llava_arch.py ├── constants.py ├── utils.py ├── eval │ ├── inference.py │ └── evaluate.py ├── mm_utils.py ├── train │ ├── llava_trainer.py │ └── train.py └── conversation.py ├── Dockerfile ├── .gitignore ├── environment.yml ├── scripts └── zero2.json ├── inference.py ├── train.py ├── README.md └── LICENSE /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mambaorg/micromamba:1.5.8-focal-cuda-12.3.1 2 | COPY --chown=$MAMBA_USER:$MAMBA_USER environment.yml /tmp/environment.yml 3 | RUN micromamba install -y -n base -f /tmp/environment.yml \ 4 | && micromamba clean --all --yes 5 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | *.json 11 | *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | 16 | # Editor 17 | .idea 18 | *.swp 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | 25 | checkpoints 26 | ckpts* 27 | 28 | .ipynb_checkpoints 29 | *.ipynb 30 | 31 | # DevContainer 32 | !.devcontainer/* 33 | 34 | # Demo 35 | serve_images/ 36 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | 15 | FLOAT_TOKEN = "\u6570" # 30354 数 16 | FLOAT_TOKEN_ID = 30354 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: igllm 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - python 8 | - pip 9 | - git 10 | - accelerate 11 | - deepspeed 12 | - einops 13 | - fastapi 14 | - gradio 15 | - ipykernel 16 | - numpy 17 | - pandas 18 | - pytorch::pytorch 19 | - pytorch::pytorch-cuda=12.1 20 | - requests 21 | - scipy 22 | - sentencepiece 23 | - tqdm 24 | - transformers 25 | - uvicorn 26 | - wandb 27 | - xformers::xformers 28 | - pip: 29 | - bitsandbytes 30 | - peft 31 | - sglang[all] 32 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /llava/model/language_model/float_utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from transformers.activations import GELUActivation 3 | import torch 4 | 5 | 6 | def get_float_head(config): 7 | if config.float_head_type == "linear": 8 | return nn.Linear(config.hidden_size, 1, bias=True) 9 | elif config.float_head_type == "tanh_mlp_gelu": 10 | return nn.Sequential( 11 | nn.Tanh(), 12 | nn.Linear(config.hidden_size, config.hidden_size, bias=True), 13 | GELUActivation(), 14 | nn.Linear(config.hidden_size, 1, bias=True), 15 | ) 16 | else: 17 | print("Not using a float head:", config.float_head_type) 18 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "gradient_clipping": "auto", 17 | "zero_optimization": { 18 | "stage": 2, 19 | "overlap_comm": true, 20 | "contiguous_gradients": true, 21 | "sub_group_size": 1e9, 22 | "reduce_bucket_size": "auto" 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def 
auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import argparse 4 | from pathlib import Path 5 | from subprocess import run 6 | import json 7 | import gzip 8 | 9 | ROOT_DIR = Path("/is/cluster/fast/pkulits/code/IG-LLM") 10 | assert ROOT_DIR.exists() 11 | 12 | prompt_version = "v1" 13 | model_version = "336px-pretrain-vicuna-7b-v1.3" 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--images_tar", type=str) 17 | parser.add_argument("--model-path", type=str, required=True) 18 | parser.add_argument("--model-base", type=str, default="lmsys/vicuna-7b-v1.3") 19 | parser.add_argument("--model-name", type=str, default="llava-lora") 20 | parser.add_argument("--image-folder", type=str, default="/tmp/images") 21 | parser.add_argument("--out_path", type=str) 22 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 23 | parser.add_argument("--slice-start", type=int, default=None) 24 | 
parser.add_argument("--slice-end", type=int, default=None) 25 | parser.add_argument("--slice-step", type=int, default=None) 26 | parser.add_argument("--image_aspect_ratio", type=str, default="square") 27 | parser.add_argument("--is_2d", default=0, type=int) 28 | parser.add_argument("--num_beams", default=1, type=int) 29 | parser.add_argument("--annotations_path", type=str, default=None) 30 | 31 | args = parser.parse_args() 32 | 33 | assert (Path(ROOT_DIR) / args.out_path).parent.exists(), args.out_path 34 | 35 | assert args.is_2d or args.images_tar is not None, args.images_tar 36 | assert args.is_2d or Path(args.images_tar).exists(), args.images_tar 37 | 38 | if args.annotations_path is not None: 39 | assert Path(args.annotations_path).exists(), args.annotations_path 40 | 41 | if not args.is_2d: 42 | Path("/tmp/images").mkdir() 43 | 44 | if args.annotations_path is not None: 45 | with gzip.open(args.annotations_path) as f: 46 | image_ids = {a["image_id"] for a in json.load(f)[slice(args.slice_start, args.slice_end, args.slice_step)]} 47 | 48 | print(len(image_ids), flush=True) 49 | 50 | print("Extracting images", flush=True) 51 | run(["tar", "xf", args.images_tar, "-C", "/tmp/images/"], check=True) 52 | 53 | if args.annotations_path is not None: 54 | for image_path in Path("/tmp/images").glob("*"): 55 | if image_path.name not in image_ids: 56 | image_path.unlink() 57 | elif args.slice_start is not None or args.slice_end is not None or args.slice_step is not None: 58 | valid_paths = sorted(Path("/tmp/images").glob("*"))[slice(args.slice_start, args.slice_end, args.slice_step)] 59 | for image_path in Path("/tmp/images").glob("*"): 60 | if image_path not in valid_paths: 61 | image_path.unlink() 62 | 63 | print("Extracted images", flush=True) 64 | 65 | del args.images_tar 66 | del args.slice_start 67 | del args.slice_end 68 | del args.slice_step 69 | del args.annotations_path 70 | 71 | run( 72 | [ 73 | "python", 74 | "llava/eval/inference.py", 75 | ] 76 | + [f"--{k}={v}" for k, v in vars(args).items() if v is not None], 77 | env={ 78 | "PYTHONPATH": ".", 79 | }, 80 | check=True, 81 | ) 82 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | 29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 31 | self.vision_tower.requires_grad_(False) 32 | 33 | self.is_loaded = True 34 | 35 | def feature_select(self, image_forward_outs): 36 | 
image_features = image_forward_outs.hidden_states[self.select_layer] 37 | if self.select_feature == 'patch': 38 | image_features = image_features[:, 1:] 39 | elif self.select_feature == 'cls_patch': 40 | image_features = image_features 41 | else: 42 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 43 | return image_features 44 | 45 | @torch.no_grad() 46 | def forward(self, images): 47 | if type(images) is list: 48 | image_features = [] 49 | for image in images: 50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 51 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 52 | image_features.append(image_feature) 53 | else: 54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 55 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 56 | 57 | return image_features 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.device 70 | 71 | @property 72 | def config(self): 73 | if self.is_loaded: 74 | return self.vision_tower.config 75 | else: 76 | return self.cfg_only 77 | 78 | @property 79 | def hidden_size(self): 80 | return self.config.hidden_size 81 | 82 | @property 83 | def num_patches_per_side(self): 84 | return self.config.image_size // self.config.patch_size 85 | 86 | @property 87 | def num_patches(self): 88 | return (self.config.image_size // self.config.patch_size) ** 2 89 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from subprocess import run 3 | import argparse 4 | import socket 5 | 6 | print(__file__) 7 | 8 | with socket.socket() as s: 9 | s.bind(("", 0)) 10 | main_process_port = s.getsockname()[1] 11 | 12 | prompt_version = "v1" 13 | model_version = "336px-pretrain-vicuna-7b-v1.3" 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--images_tar", type=str) 17 | parser.add_argument("--images_val_tar", type=str) 18 | parser.add_argument("--data_path", type=str, required=True) 19 | parser.add_argument("--data_path_val", type=str) 20 | parser.add_argument("--output_dir", type=str, required=True) 21 | parser.add_argument("--deepspeed", default="./scripts/zero2.json", type=str) 22 | parser.add_argument("--lora_enable", default=True, type=bool) 23 | parser.add_argument("--lora_r", default=128, type=int) 24 | parser.add_argument("--lora_alpha", default=256, type=int) 25 | parser.add_argument("--model_name_or_path", default="lmsys/vicuna-7b-v1.3", type=str) 26 | parser.add_argument("--version", default=prompt_version, type=str) 27 | parser.add_argument("--image_folder", default="/tmp/images", type=str) 28 | parser.add_argument("--image_folder_val", default="/tmp/images_val", type=str) 29 | parser.add_argument("--vision_tower", default="openai/clip-vit-large-patch14-336", type=str) 30 | parser.add_argument( 31 | "--pretrain_mm_mlp_adapter", default=f"./checkpoints/llava-{model_version}/mm_projector.bin", type=str 32 | ) 33 | parser.add_argument("--mm_vision_select_layer", default=-2, type=int) 34 | parser.add_argument("--mm_use_im_start_end", default=False, type=bool) 35 | parser.add_argument("--mm_use_im_patch_token", default=False, type=bool) 36 | parser.add_argument("--bf16", default=True, type=bool) 37 | parser.add_argument("--bits", default=16, type=int) 38 | parser.add_argument("--num_train_epochs", default=30, type=int) 39 | parser.add_argument("--max_steps", type=int) 40 | parser.add_argument("--per_device_train_batch_size", default=32, type=int) 41 | parser.add_argument("--per_device_eval_batch_size", default=32, type=int) 42 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int) 43 | parser.add_argument("--save_strategy", default="steps", type=str) 44 | parser.add_argument("--save_steps", default=500, type=int) 45 | parser.add_argument("--save_total_limit", default=1, type=int) 46 | parser.add_argument("--learning_rate", default=2e-5, type=float) 47 | parser.add_argument("--weight_decay", default=0.0, type=float) 48 | parser.add_argument("--mm_projector_lr", default=2e-5, type=float) 49 | 
parser.add_argument("--float_head_lr", default=2e-4, type=float) 50 | parser.add_argument("--warmup_ratio", default=0.03, type=float) 51 | parser.add_argument("--lr_scheduler_type", default="cosine", type=str) 52 | parser.add_argument("--logging_steps", default=1, type=int) 53 | parser.add_argument("--tf32", default=True, type=bool) 54 | parser.add_argument("--model_max_length", default=2048, type=int) 55 | parser.add_argument("--gradient_checkpointing", default=True, type=bool) 56 | parser.add_argument("--lazy_preprocess", default=True, type=bool) 57 | parser.add_argument("--dataloader_num_workers", default=20, type=int) 58 | parser.add_argument("--report_to", default="wandb", type=str) 59 | parser.add_argument("--shuffle_attributes", default=None, type=eval) 60 | parser.add_argument("--float_head_type", default="none", type=str) 61 | parser.add_argument("--evaluation_strategy", default="steps", type=str) 62 | parser.add_argument("--use_synonyms", default=True, type=bool) 63 | parser.add_argument("--eval_steps", default=100, type=int) 64 | parser.add_argument("--rotation_rep", type=str) 65 | parser.add_argument("--mm_projector_type", type=str) 66 | parser.add_argument("--image_aspect_ratio", default="square", type=str) 67 | parser.add_argument("--is_2d", default=False, type=bool) 68 | parser.add_argument("--num_samples", type=int) 69 | parser.add_argument("--float_w", type=float) 70 | 71 | args = parser.parse_args() 72 | 73 | assert args.is_2d or args.images_tar is not None, args.images_tar 74 | assert args.is_2d or Path(args.images_tar).exists(), args.images_tar 75 | assert args.is_2d or Path(args.data_path).exists(), args.data_path 76 | 77 | if not args.is_2d: 78 | Path("/tmp/images").mkdir() 79 | run(["tar", "xf", args.images_tar, "-C", "/tmp/images/"], check=True) 80 | run(["rsync", "-a", args.data_path, "/tmp/train.json"], check=True) 81 | args.data_path = "/tmp/train.json" 82 | 83 | if args.images_val_tar: 84 | assert Path(args.images_val_tar).exists(), args.images_val_tar 85 | assert Path(args.data_path_val).exists(), args.data_path_val 86 | Path("/tmp/images_val").mkdir() 87 | run(["tar", "xf", args.images_val_tar, "-C", "/tmp/images_val/"], check=True) 88 | run(["rsync", "-a", args.data_path_val, "/tmp/val.json"], check=True) 89 | args.data_path_val = "/tmp/val.json" 90 | else: 91 | print("No images_val_tar passed.") 92 | 93 | del args.images_tar 94 | del args.images_val_tar 95 | 96 | run( 97 | [ 98 | "deepspeed", 99 | "--master_port", 100 | str(main_process_port), 101 | "llava/train/train.py", 102 | ] 103 | + [f"--{k}={v}" for k, v in vars(args).items() if v is not None], 104 | env={ 105 | "PYTHONPATH": ".", 106 | }, 107 | check=True, 108 | ) 109 | -------------------------------------------------------------------------------- /llava/eval/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import orjson 7 | import torch 8 | import tqdm.auto as tqdm 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | from llava.conversation import conv_templates 11 | from llava.mm_utils import get_model_name_from_path, tokenizer_image_token 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from PIL import Image, UnidentifiedImageError 15 | from llava.train.train import draw_dot 16 | 17 | 18 | def main(args): 19 | disable_torch_init() 20 | model_path = os.path.expanduser(args.model_path) 21 | if 
args.model_name: 22 | model_name = args.model_name 23 | else: 24 | model_name = get_model_name_from_path(model_path) 25 | tokenizer, model, image_processor, _ = load_pretrained_model( 26 | model_path, args.model_base, model_name, device_map=args.device 27 | ) 28 | 29 | answers = {} 30 | for image_path in tqdm.tqdm( 31 | sorted(Path(args.image_folder).rglob("*")) 32 | if not args.is_2d 33 | else np.load("data/2d.npz")["random"] 34 | ): 35 | if args.is_2d: 36 | x, y = image_path 37 | image = draw_dot(x, y).convert("RGB") 38 | else: 39 | try: 40 | image = Image.open(image_path) 41 | except (IsADirectoryError, UnidentifiedImageError): 42 | continue 43 | 44 | conv = conv_templates[args.conv_mode].copy() 45 | conv.append_message( 46 | conv.roles[0], 47 | "\nWhat Python Blender code could be used to produce the scene?", 48 | ) 49 | conv.append_message(conv.roles[1], None) 50 | prompt = conv.get_prompt() 51 | 52 | input_ids = ( 53 | tokenizer_image_token( 54 | prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" 55 | ) 56 | .unsqueeze(0) 57 | .cuda() 58 | ) 59 | 60 | image = image.convert("RGB") 61 | if args.image_aspect_ratio == "pad": 62 | 63 | def expand2square(pil_img, background_color): 64 | width, height = pil_img.size 65 | if width == height: 66 | return pil_img 67 | elif width > height: 68 | result = Image.new(pil_img.mode, (width, width), background_color) 69 | result.paste(pil_img, (0, (width - height) // 2)) 70 | return result 71 | else: 72 | result = Image.new(pil_img.mode, (height, height), background_color) 73 | result.paste(pil_img, ((height - width) // 2, 0)) 74 | return result 75 | 76 | image = expand2square( 77 | image, tuple(int(x * 255) for x in image_processor.image_mean) 78 | ) 79 | image_tensor = image_processor.preprocess(image, return_tensors="pt")[ 80 | "pixel_values" 81 | ][0] 82 | else: 83 | image_tensor = image_processor.preprocess(image, return_tensors="pt")[ 84 | "pixel_values" 85 | ][0] 86 | 87 | with torch.inference_mode(): 88 | output_ids = model.generate( 89 | input_ids, 90 | images=image_tensor.unsqueeze(0).half().cuda(), 91 | do_sample=False, 92 | temperature=0, 93 | top_p=None, 94 | num_beams=args.num_beams, 95 | max_new_tokens=1024, 96 | use_cache=True, 97 | tokenizer=tokenizer, 98 | bad_words_ids=tokenizer( 99 | [ 100 | "(\n", 101 | "( ", 102 | " )", 103 | ], 104 | add_special_tokens=False, 105 | ).input_ids, 106 | ) 107 | 108 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 109 | print(image_path, outputs, flush=True) 110 | 111 | answers[f"{image_path}"] = { 112 | "outputs": outputs, 113 | } 114 | 115 | Path(args.out_path).write_bytes( 116 | orjson.dumps( 117 | { 118 | "model_path": args.model_path, 119 | "model_base": args.model_base, 120 | "model_name": args.model_name, 121 | "answers": answers, 122 | }, 123 | option=orjson.OPT_SERIALIZE_NUMPY, 124 | ) 125 | ) 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--model_path", type=str, default="facebook/opt-350m") 131 | parser.add_argument("--model_base", type=str, default=None) 132 | parser.add_argument("--model_name", type=str, default=None) 133 | parser.add_argument("--image_folder", type=str, default="") 134 | parser.add_argument("--out_path", type=str) 135 | parser.add_argument("--conv_mode", type=str, default="llava_v1") 136 | parser.add_argument("--image_aspect_ratio", type=str, default="square") 137 | parser.add_argument("--is_2d", default=0, type=int) 138 | parser.add_argument("--num_beams", default=1, 
type=int) 139 | parser.add_argument("--device", default="cuda:0", type=str) 140 | args = parser.parse_args() 141 | 142 | if os.path.exists(args.out_path): 143 | print(f"File {args.out_path} already exists. Skipping.") 144 | exit(0) 145 | 146 | main(args) 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Re-Thinking Inverse Graphics With Large Language Models

2 | 3 |

Peter Kulits*, Haiwen Feng*, Weiyang Liu, Victoria Abrevaya, Michael J. Black

4 | 5 |

[Project Page] [TMLR]

6 | 7 |

## Summary

8 | We present the Inverse-Graphics Large Language Model (IG-LLM) framework, a general approach to solving inverse-graphics problems. We instruction-tune an LLM to decode a visual (CLIP) embedding into graphics code that can be used to reproduce the observed scene using a standard graphics engine. Leveraging the broad reasoning abilities of LLMs, we demonstrate that our framework exhibits natural generalization across a variety of distribution shifts without the use of special inductive biases. 9 |

10 | 11 | ![image](https://ig-llm.is.tue.mpg.de/media/upload/header.jpeg) 12 | 13 |
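For intuition, the "graphics code" the model produces is a short program describing each object in the observed scene. The snippet below is purely illustrative: the attribute names mirror the parser in `llava/eval/evaluate.py`, while the specific values are made up.

```py
add(size='large', color='red', material='metal', shape='cube', loc=(1.2, -0.4, 0.7), rotation=0.5)
add(size='small', color='blue', material='rubber', shape='sphere', loc=(-2.1, 0.9, 0.35), rotation=0.0)
```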

## Data

14 | Training and evaluation data can be found at https://ig-llm.is.tue.mpg.de/download.php after registering on the project page. The following is an outline of the data available: 15 |
16 | 17 | ```sh 18 | ├── CLEVR 19 | │   ├── images 20 | │   │   ├── train.tar 21 | │   │   ├── val_ID.tar 22 | │   │   └── val_OOD.tar 23 | │   └── labels 24 | │   ├── train.json 25 | │   ├── val_ID.json 26 | │   └── val_OOD.json 27 | ├── 2D 28 | │   └── 2d.npz 29 | ├── SO3 30 | │ ├── images 31 | │ │   ├── train.tar 32 | │ │   ├── val_ID.tar 33 | │ │   └── val_OOD.tar 34 | │ └── labels 35 | │ ├── train.json 36 | │ ├── val_ID.json 37 | │ └── val_OOD.json 38 | ├── 6DoF 39 | │   ├── images 40 | │   │   ├── train.tar 41 | │   │   └── val_ID.tar 42 | │   └── labels 43 | │   ├── train.json 44 | │   └── val_ID.json 45 | └── ShapeNet 46 |    ├── images 47 |    │   ├── train.tar 48 |    │   ├── val_ID.tar 49 |    │   ├── val_OOD_texture.tar 50 |    │   └── val_OOD_shape.tar 51 |    └── labels 52 |    ├── train.json 53 |    ├── val_ID.json 54 |    ├── val_OOD_texture.json 55 |    └── val_OOD_shape.json 56 | ``` 57 |
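The training and inference commands below assume the downloaded archives and label files are placed under `data/` following this layout (e.g., `data/CLEVR/images/train.tar`). An optional sanity check after downloading is to list an archive's contents:

```sh
tar tf data/CLEVR/images/val_OOD.tar | head
```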
58 | 59 |

## Setup

60 | The environment can be configured with conda/micromamba from environment.yml or using the Dockerfile. 61 | 62 |
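For example, with micromamba (conda works the same way via `conda env create -f environment.yml`); the Docker image tag below is arbitrary:

```sh
micromamba create -y -f environment.yml
micromamba activate igllm

# or, to build the container image instead
docker build -t ig-llm .
```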

## Training

63 | After the data has been downloaded, training can be initiated with the following: 64 | 65 |
    66 |
  • CLEVR 67 |
    68 | 69 | ```sh 70 | python train.py \ 71 | --images_tar data/CLEVR/images/train.tar \ 72 | --data_path data/CLEVR/labels/train.json \ 73 | --images_val_tar data/CLEVR/images/val_OOD.tar \ 74 | --data_path_val data/CLEVR/labels/val_OOD.json \ 75 | --per_device_train_batch_size X \ 76 | --output_dir ./checkpoints/clevr-Y \ 77 | --max_steps 40000 \ 78 | --float_head_type (none|tanh_mlp_gelu) \ 79 | --image_aspect_ratio pad \ 80 | --num_samples 4000 81 | ``` 82 | 
    83 |
  • 2D 86 |
    87 | 2d.npz is expected to be at data/2d.npz prior to running train.py. 88 | 89 | ```sh 90 | python train.py \ 91 | --data_path checkerboard_sparse \ 92 | --data_path_val random \ 93 | --per_device_train_batch_size X \ 94 | --output_dir ./checkpoints/2d-Y \ 95 | --max_steps 40000 \ 96 | --float_head_type (none|tanh_mlp_gelu) \ 97 | --image_aspect_ratio pad \ 98 | --is_2d True 99 | ``` 100 |
    101 |
  • SO(3) 104 |
    105 | 106 | ```sh 107 | python train.py \ 108 | --images_tar data/SO3/images/train.tar \ 109 | --data_path data/SO3/labels/train.json \ 110 | --images_val_tar data/SO3/images/val_OOD.tar \ 111 | --data_path_val data/SO3/labels/val_OOD.json \ 112 | --per_device_train_batch_size X \ 113 | --output_dir ./checkpoints/so3-Y \ 114 | --max_steps 40000 \ 115 | --float_head_type (none|tanh_mlp_gelu) \ 116 | --image_aspect_ratio pad \ 117 | --rotation_rep (euler_int|euler|aa|6d) 118 | ``` 119 | 
    120 |
  • 6-DoF 123 |
    124 | 125 | ```sh 126 | python train.py \ 127 | --images_tar data/6DoF/images/train.tar \ 128 | --data_path data/6DoF/labels/train.json \ 129 | --images_val_tar data/6DoF/images/val_ID.tar \ 130 | --data_path_val data/6DoF/labels/val_ID.json \ 131 | --per_device_train_batch_size X \ 132 | --output_dir ./checkpoints/6dof-Y \ 133 | --max_steps 200000 \ 134 | --float_head_type (none|tanh_mlp_gelu) \ 135 | --image_aspect_ratio pad \ 136 | --rotation_rep (euler_int|euler|aa|6d) 137 | ``` 138 | 
    139 |
  • ShapeNet 142 |
    143 | 144 | ```sh 145 | python train.py \ 146 | --images_tar data/ShapeNet/images/train.tar \ 147 | --data_path data/ShapeNet/labels/train.json \ 148 | --images_val_tar data/ShapeNet/images/val_OOD_texture.tar \ 149 | --data_path_val data/ShapeNet/labels/val_OOD_texture.json \ 150 | --per_device_train_batch_size X \ 151 | --output_dir ./checkpoints/shapenet-Y \ 152 | --max_steps 500000 \ 153 | --float_head_type (none|tanh_mlp_gelu) \ 154 | --image_aspect_ratio pad \ 155 | --rotation_rep (euler_int|euler|aa|6d) 156 | ``` 157 | 
    158 |
160 | 161 |

## Inference

162 | 163 | ```sh 164 | python inference.py \ 165 | --model-path ./checkpoints/clevr-Y \ 166 | --images_tar data/CLEVR/images/val_OOD.tar \ 167 | --out_path ./out/clevr-Y-val_OOD.json \ 168 | --image_aspect_ratio pad 169 | ``` 170 | 171 |
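Predictions can then be scored against the corresponding ground-truth labels with `llava/eval/evaluate.py`; the paths below simply follow the example above:

```sh
python llava/eval/evaluate.py \
    --gt_path data/CLEVR/labels/val_OOD.json \
    --pred_path ./out/clevr-Y-val_OOD.json
```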

## License

172 | We build off the LLaVA codebase to perform our experiments. As such, inherited code falls under the original Apache 2.0 license. Additions and modifications are released under a different license in accordance with institute requirements which has been prepended to LICENSE. 173 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | from pathlib import Path 20 | 21 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 22 | import torch 23 | from llava.model import * 24 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | 27 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 28 | kwargs = {"device_map": device_map, **kwargs} 29 | 30 | if device != "cuda": 31 | kwargs['device_map'] = {"": device} 32 | 33 | if load_8bit: 34 | kwargs['load_in_8bit'] = True 35 | elif load_4bit: 36 | kwargs['load_in_4bit'] = True 37 | kwargs['quantization_config'] = BitsAndBytesConfig( 38 | load_in_4bit=True, 39 | bnb_4bit_compute_dtype=torch.float16, 40 | bnb_4bit_use_double_quant=True, 41 | bnb_4bit_quant_type='nf4' 42 | ) 43 | else: 44 | kwargs['torch_dtype'] = torch.float16 45 | 46 | if use_flash_attn: 47 | kwargs['attn_implementation'] = 'flash_attention_2' 48 | 49 | if 'llava' in model_name.lower(): 50 | # Load LLaVA model 51 | if 'lora' in model_name.lower() and model_base is None: 52 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 53 | if 'lora' in model_name.lower() and model_base is not None: 54 | from llava.model.language_model.llava_llama import LlavaConfig 55 | if Path(model_path).name.startswith('checkpoint'): 56 | lora_cfg_pretrained = AutoConfig.from_pretrained(str(Path(model_path).parent)) 57 | else: 58 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 59 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 60 | print('Loading LLaVA from base model...') 61 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 62 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 63 | if model.lm_head.weight.shape[0] != token_num: 64 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 65 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 66 | 67 | print('Loading additional LLaVA weights...') 68 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 69 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 70 | elif os.path.exists(model_path): # Non-final checkpoints don't follow the above format 71 | p = next(Path(model_path).glob("global_*/mp_rank_00_model_states.pt")) 72 | nlt_keys = ['base_model.model.model.mm_projector.weight', 'base_model.model.model.mm_projector.bias', 'base_model.model.float_head.original_module.weight', 'base_model.model.float_head.modules_to_save.default.weight'] 73 | non_lora_trainables = {k: v.float() for k, v in torch.load(p)['module'].items() if k in nlt_keys or 'float_head' in k} 74 | else: 75 | # this is probably from HF Hub 76 | from huggingface_hub import hf_hub_download 77 | def load_from_hf(repo_id, filename, subfolder=None): 78 | cache_file = hf_hub_download( 79 | repo_id=repo_id, 80 | filename=filename, 81 | subfolder=subfolder) 82 | return torch.load(cache_file, map_location='cpu') 83 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 84 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 85 | if any(k.startswith('model.model.') for k in non_lora_trainables): 86 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 87 | model.load_state_dict(non_lora_trainables, strict=False) 88 | 89 | from peft import PeftModel 90 | print('Loading LoRA weights...') 91 | model = PeftModel.from_pretrained(model, model_path) 92 | print('Merging LoRA weights...') 93 | model = model.merge_and_unload() 94 | print('Model is loaded...') 95 | elif model_base is not None: 96 | # this may be mm projector only 97 | print('Loading LLaVA from base model...') 98 | if 'mpt' in model_name.lower(): 99 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 100 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 101 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 102 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 103 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 104 | else: 105 | tokenizer = 
AutoTokenizer.from_pretrained(model_base, use_fast=False) 106 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 107 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 108 | 109 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 110 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 111 | model.load_state_dict(mm_projector_weights, strict=False) 112 | else: 113 | if 'mpt' in model_name.lower(): 114 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 115 | model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 116 | elif 'mistral' in model_name.lower(): 117 | tokenizer = AutoTokenizer.from_pretrained(model_path) 118 | model = LlavaMistralForCausalLM.from_pretrained( 119 | model_path, 120 | low_cpu_mem_usage=True, 121 | **kwargs 122 | ) 123 | else: 124 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 125 | model = LlavaLlamaForCausalLM.from_pretrained( 126 | model_path, 127 | low_cpu_mem_usage=True, 128 | **kwargs 129 | ) 130 | else: 131 | # Load language model 132 | if model_base is not None: 133 | # PEFT model 134 | from peft import PeftModel 135 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 136 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 137 | print(f"Loading LoRA weights from {model_path}") 138 | model = PeftModel.from_pretrained(model, model_path) 139 | print(f"Merging weights") 140 | model = model.merge_and_unload() 141 | print('Convert to FP16...') 142 | model.to(torch.float16) 143 | else: 144 | use_fast = False 145 | if 'mpt' in model_name.lower(): 146 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 147 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 148 | else: 149 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 150 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 151 | 152 | image_processor = None 153 | 154 | if 'llava' in model_name.lower(): 155 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 156 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 157 | if mm_use_im_patch_token: 158 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 159 | if mm_use_im_start_end: 160 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 161 | model.resize_token_embeddings(len(tokenizer)) 162 | 163 | vision_tower = model.get_vision_tower() 164 | if not vision_tower.is_loaded: 165 | vision_tower.load_model(device_map=device_map) 166 | if device_map != 'auto': 167 | vision_tower.to(device=device_map, dtype=torch.float16) 168 | image_processor = vision_tower.image_processor 169 | 170 | if hasattr(model.config, "max_sequence_length"): 171 | context_len = model.config.max_sequence_length 172 | else: 173 | context_len = 2048 174 | 175 | return tokenizer, model, image_processor, context_len 176 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import 
StoppingCriteria 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 
105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | return new_images 183 | 184 | 185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 186 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 187 | 188 | def insert_separator(X, sep): 189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 190 | 191 | input_ids = [] 
192 | offset = 0 193 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 194 | offset = 1 195 | input_ids.append(prompt_chunks[0][0]) 196 | 197 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 198 | input_ids.extend(x[offset:]) 199 | 200 | if return_tensors is not None: 201 | if return_tensors == 'pt': 202 | return torch.tensor(input_ids, dtype=torch.long) 203 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 204 | return input_ids 205 | 206 | 207 | def get_model_name_from_path(model_path): 208 | model_path = model_path.strip("/") 209 | model_paths = model_path.split("/") 210 | if model_paths[-1].startswith('checkpoint-'): 211 | return model_paths[-2] + "_" + model_paths[-1] 212 | else: 213 | return model_paths[-1] 214 | 215 | class KeywordsStoppingCriteria(StoppingCriteria): 216 | def __init__(self, keywords, tokenizer, input_ids): 217 | self.keywords = keywords 218 | self.keyword_ids = [] 219 | self.max_keyword_len = 0 220 | for keyword in keywords: 221 | cur_keyword_ids = tokenizer(keyword).input_ids 222 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 223 | cur_keyword_ids = cur_keyword_ids[1:] 224 | if len(cur_keyword_ids) > self.max_keyword_len: 225 | self.max_keyword_len = len(cur_keyword_ids) 226 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 227 | self.tokenizer = tokenizer 228 | self.start_len = input_ids.shape[1] 229 | 230 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 231 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 232 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 233 | for keyword_id in self.keyword_ids: 234 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 235 | if torch.equal(truncated_output_ids, keyword_id): 236 | return True 237 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 238 | for keyword in self.keywords: 239 | if keyword in outputs: 240 | return True 241 | return False 242 | 243 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 244 | outputs = [] 245 | for i in range(output_ids.shape[0]): 246 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 247 | return all(outputs) 248 | -------------------------------------------------------------------------------- /llava/eval/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import numpy as np 4 | import orjson 5 | import re 6 | import scipy.optimize 7 | import torch 8 | import tqdm.auto as tqdm 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | synonyms = { 12 | "sphere": ["sphere", "ball"], 13 | "cube": ["cube", "block"], 14 | "large": ["large", "big"], 15 | "small": ["small", "tiny"], 16 | "metal": ["metallic", "metal", "shiny"], 17 | "rubber": ["rubber", "matte"], 18 | "code": ["code", "Python code", "Python", "Python script", "script"], 19 | "produce": ["produce", "create", "generate", "synthesize"], 20 | } 21 | 22 | synonyms_inv = {n: k for k, v in synonyms.items() for n in v} 23 | 24 | def compute_geodesic_distance_from_two_matrices(m1, m2): 25 | m = np.matmul(m1, m2.transpose(0, 2, 1)) # batch*3*3 26 | cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 27 | cos = np.minimum(cos, np.ones(cos.shape)) 28 | cos = 
np.maximum(cos, np.ones(cos.shape) * -1) 29 | return np.arccos(cos) 30 | 31 | 32 | def or_none(x): 33 | return x[0] if x else None 34 | 35 | 36 | def parse(text, n_rot=1): 37 | objects = [] 38 | 39 | rot_s = "rotation=" 40 | if n_rot != 1: 41 | rot_s += "\(" 42 | rot_s += ", ".join(["([^,)]+)"] * n_rot) 43 | if n_rot != 1: 44 | rot_s += "\)" 45 | 46 | for line in text.split("\n"): 47 | line = line.strip() 48 | if not line.startswith("add("): 49 | continue 50 | 51 | objects.append( 52 | { 53 | "size": or_none(re.findall("size='([^',]+)'", line)), 54 | "color": or_none(re.findall("color='([^',]+)'", line)), 55 | "material": or_none(re.findall("material='([^',]+)'", line)), 56 | "shape": or_none(re.findall("shape='([^',]+)'", line)), 57 | "3d_coords": or_none( 58 | re.findall("loc=\(([^,)]+), ([^,)]+), ([^,)]+)\)", line) 59 | ), 60 | "rotation": or_none(re.findall(rot_s, line)), 61 | } 62 | ) 63 | 64 | return { 65 | "objects": objects, 66 | } 67 | 68 | 69 | def to_num(n): 70 | try: 71 | return float(n) 72 | except: 73 | return 0 74 | 75 | 76 | def match_euclidean(objects, objects_pred): 77 | cost_matrix = scipy.spatial.distance.cdist( 78 | np.array([n["3d_coords"] for n in objects]), 79 | np.array([n["3d_coords"] for n in objects_pred]), 80 | ) 81 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix.T) 82 | return row_ind, col_ind 83 | 84 | 85 | def main(args): 86 | args.out_folder.mkdir(exist_ok=True) 87 | 88 | if args.rotation_rep == "auto": 89 | if args.pred_path.stem.startswith("clevr"): 90 | args.rotation_rep = "yaw" 91 | elif "-euler-" in args.pred_path.stem: 92 | args.rotation_rep = "euler" 93 | elif "-6d-" in args.pred_path.stem: 94 | args.rotation_rep = "6d" 95 | elif "-aa-" in args.pred_path.stem: 96 | args.rotation_rep = "aa" 97 | elif "-euler_int-" in args.pred_path.stem: 98 | args.rotation_rep = "euler_int" 99 | else: 100 | raise ValueError( 101 | f"Unknown rotation representation in {args.pred_path.stem}" 102 | ) 103 | 104 | n_rot = { 105 | "yaw": 1, 106 | "euler": 3, 107 | "euler_int": 3, 108 | "6d": 6, 109 | "aa": 3, 110 | }[args.rotation_rep] 111 | print(f'{args.rotation_rep=}, {n_rot=}') 112 | 113 | gt = orjson.loads(args.gt_path.read_bytes()) 114 | if isinstance(gt, dict): 115 | gt = gt["scenes"] 116 | 117 | pred = orjson.loads(args.pred_path.read_bytes())["answers"] 118 | 119 | mse = [] 120 | l2 = [] 121 | count_acc = [] 122 | shape_acc = [] 123 | color_acc = [] 124 | material_acc = [] 125 | size_acc = [] 126 | shape_cls_acc = [] 127 | so3_relative_angles = [] 128 | count_diff = [] 129 | 130 | key_to_gt = {scene_gt["image_filename"]: scene_gt for scene_gt in gt} 131 | 132 | for k, v in pred.items(): 133 | v["scene_gt"] = key_to_gt[Path(k).name] 134 | 135 | for k, v in tqdm.tqdm(pred.items()): 136 | scene_gt = v["scene_gt"] 137 | scene_pred = parse(v["outputs"], n_rot=n_rot) 138 | 139 | if len(scene_pred["objects"]) != len(scene_gt["objects"]): 140 | count_acc.append(0) 141 | count_diff.append(len(scene_pred["objects"]) - len(scene_gt["objects"])) 142 | else: 143 | count_acc.append(1) 144 | count_diff.append(0) 145 | 146 | for obj in scene_pred["objects"]: 147 | for k, v in obj.items(): 148 | if k != "3d_coords" and k != "rotation": 149 | obj[k] = synonyms_inv.get(v, v) 150 | 151 | try: 152 | obj["3d_coords"] = [to_num(n) for n in obj["3d_coords"]] 153 | except (ValueError, TypeError): 154 | obj["3d_coords"] = [0, 0, 0] 155 | 156 | try: 157 | obj["rotation"] = [to_num(n) for n in obj["rotation"]] 158 | except (ValueError, TypeError): 159 | obj["rotation"] = 
[0.0] * n_rot 160 | 161 | if args.rotation_rep == "6d": 162 | r = np.array([to_num(n) for n in obj["rotation"]]).reshape(2, 3) 163 | obj["rotation"] = R.from_matrix(np.vstack([r, np.cross(r[0], r[1])[None]])).as_euler('xyz').tolist() 164 | elif args.rotation_rep == "aa": 165 | obj["rotation"] = ( 166 | R.from_rotvec(obj["rotation"]).as_euler("xyz").tolist() 167 | ) 168 | elif args.rotation_rep == "euler_int": 169 | obj["rotation"] = R.from_euler("XYZ", obj["rotation"]).as_euler("xyz").tolist() 170 | 171 | row_ind, col_ind = match_euclidean(scene_gt["objects"], scene_pred["objects"]) 172 | objects_pred = [scene_pred["objects"][i] for i in row_ind] 173 | objects_gt = [scene_gt["objects"][i] for i in col_ind] 174 | 175 | try: 176 | shape_acc.append( 177 | np.mean( 178 | [a["shape"] == b["shape"] for a, b in zip(objects_gt, objects_pred) if a['shape'].lower() != 'sphere'] 179 | ) 180 | ) 181 | except KeyError: 182 | shape_acc.append(np.nan) 183 | 184 | try: 185 | color_acc.append( 186 | np.mean( 187 | [a["color"] == b["color"] for a, b in zip(objects_gt, objects_pred)] 188 | ) 189 | ) 190 | except KeyError: 191 | color_acc.append(np.nan) 192 | 193 | try: 194 | material_acc.append( 195 | np.mean( 196 | [ 197 | a["material"] == b["material"] 198 | for a, b in zip(objects_gt, objects_pred) 199 | ] 200 | ) 201 | ) 202 | except KeyError: 203 | material_acc.append(np.nan) 204 | 205 | try: 206 | size_acc.append( 207 | np.mean( 208 | [a["size"] == b["size"] for a, b in zip(objects_gt, objects_pred)] 209 | ) 210 | ) 211 | except KeyError: 212 | size_acc.append(np.nan) 213 | 214 | try: 215 | shape_cls_acc.append( 216 | np.mean( 217 | [ 218 | a["shape"].split("_")[0] == str(b["shape"]).split("_")[0] 219 | for a, b in zip(objects_gt, objects_pred) 220 | ] 221 | ) 222 | ) 223 | except KeyError: 224 | shape_cls_acc.append(np.nan) 225 | 226 | mse.append( 227 | torch.nn.functional.mse_loss( 228 | torch.tensor([n["3d_coords"] for n in objects_gt]), 229 | torch.tensor([n["3d_coords"] for n in objects_pred]), 230 | ) 231 | ) 232 | l2.append( 233 | torch.nn.functional.pairwise_distance( 234 | torch.tensor([n["3d_coords"] for n in objects_gt]), 235 | torch.tensor([n["3d_coords"] for n in objects_pred]), 236 | p=2, 237 | ).mean() 238 | ) 239 | 240 | if n_rot > 1: 241 | so3_relative_angles.append( 242 | ( 243 | compute_geodesic_distance_from_two_matrices( 244 | np.array( 245 | [ 246 | R.from_euler("xyz", n["rotation"]).as_matrix() 247 | for n in objects_gt 248 | ] 249 | ), 250 | np.array( 251 | [ 252 | R.from_euler("xyz", n["rotation"]).as_matrix() 253 | for n in objects_pred 254 | ] 255 | ), 256 | ) 257 | .mean() 258 | .item() 259 | / np.pi 260 | * 180 261 | ) 262 | ) 263 | 264 | print( 265 | *(f'{k}={v:.2f}' for k, v in { 266 | "l2": torch.tensor([n for n in l2 if n]).mean().item(), 267 | "geod": torch.tensor(so3_relative_angles).mean().item(), 268 | "ace": np.abs(count_diff).mean(), 269 | "size": 100*torch.tensor(size_acc).to(torch.float).mean().item(), 270 | "color": 100*torch.tensor(color_acc).to(torch.float).mean().item(), 271 | "mat.": 100*torch.tensor(material_acc).to(torch.float).mean().item(), 272 | "shape": 100*torch.tensor(shape_acc).to(torch.float).nanmean().item(), 273 | "shp_cls": 100*torch.tensor(shape_cls_acc).to(torch.float).mean().item(), 274 | }.items()) 275 | ) 276 | 277 | 278 | if __name__ == "__main__": 279 | parser = argparse.ArgumentParser() 280 | parser.add_argument("--gt_path", type=Path, required=True) 281 | parser.add_argument("--pred_path", type=Path, required=True) 282 | 
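# Example invocation (illustrative; assumes this script is llava/eval/evaluate.py and the paths are placeholders):
#   python llava/eval/evaluate.py --gt_path data/scenes.json --pred_path output/preds-euler-val.json
# With --rotation_rep left at "auto", the representation is inferred from the prediction
# filename stem (e.g. "-euler-", "-6d-", "-aa-", "-euler_int-", or a "clevr*" stem for yaw-only),
# as handled in main() above.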
parser.add_argument("--out_folder", type=Path, default=Path("./eval")) 283 | parser.add_argument("--rotation_rep", type=str, default="auto") 284 | args = parser.parse_args() 285 | 286 | assert args.gt_path.exists(), f"File {args.gt_path} does not exist." 287 | assert args.pred_path.exists(), f"File {args.pred_path} does not exist." 288 | 289 | main(args) 290 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from dataclasses import dataclass 17 | import sys 18 | from typing import Dict, List, Optional, Tuple, Union 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from transformers import AutoConfig, AutoModelForCausalLM, \ 25 | LlamaConfig, LlamaModel, LlamaForCausalLM 26 | 27 | from transformers.modeling_outputs import CausalLMOutputWithPast 28 | from transformers.generation.utils import GenerateOutput 29 | 30 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 31 | from llava.constants import FLOAT_TOKEN_ID, FLOAT_TOKEN 32 | from llava.model.language_model.float_utils import get_float_head 33 | 34 | 35 | class LlavaConfig(LlamaConfig): 36 | model_type = "llava_llama" 37 | float_head_type: Optional[str] = None 38 | float_w: float = 1.0 39 | 40 | 41 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config: LlamaConfig): 45 | super(LlavaLlamaModel, self).__init__(config) 46 | 47 | 48 | @dataclass 49 | class FloatsCausalLMOutputWithPast(CausalLMOutputWithPast): 50 | floats_pred: Optional[Tuple[torch.FloatTensor]] = None 51 | logs: Optional[Dict[str, float]] = None 52 | 53 | 54 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 55 | config_class = LlavaConfig 56 | 57 | def __init__(self, config): 58 | super(LlamaForCausalLM, self).__init__(config) 59 | self.model = LlavaLlamaModel(config) 60 | self.pretraining_tp = config.pretraining_tp 61 | self.vocab_size = config.vocab_size 62 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 63 | self.float_head = get_float_head(config) 64 | 65 | # Initialize weights and apply final processing 66 | self.post_init() 67 | 68 | def get_model(self): 69 | return self.model 70 | 71 | def forward( 72 | self, 73 | input_ids: torch.LongTensor = None, 74 | attention_mask: Optional[torch.Tensor] = None, 75 | position_ids: Optional[torch.LongTensor] = None, 76 | past_key_values: Optional[List[torch.FloatTensor]] = None, 77 | inputs_embeds: Optional[torch.FloatTensor] = None, 78 | labels: Optional[torch.LongTensor] = None, 79 | use_cache: Optional[bool] = None, 80 | output_attentions: Optional[bool] = None, 81 | output_hidden_states: Optional[bool] = None, 82 | images: Optional[torch.FloatTensor] = None, 83 | 
image_sizes: Optional[List[List[int]]] = None, 84 | return_dict: Optional[bool] = None, 85 | floats: Optional[torch.FloatTensor] = None, 86 | ) -> Union[Tuple, FloatsCausalLMOutputWithPast]: 87 | 88 | if inputs_embeds is None: 89 | ( 90 | input_ids, 91 | position_ids, 92 | attention_mask, 93 | past_key_values, 94 | inputs_embeds, 95 | labels 96 | ) = self.prepare_inputs_labels_for_multimodal( 97 | input_ids, 98 | position_ids, 99 | attention_mask, 100 | past_key_values, 101 | labels, 102 | images, 103 | image_sizes 104 | ) 105 | 106 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 107 | output_hidden_states = ( 108 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 109 | ) 110 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 111 | 112 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 113 | outputs = self.model( 114 | input_ids=input_ids, 115 | attention_mask=attention_mask, 116 | position_ids=position_ids, 117 | past_key_values=past_key_values, 118 | inputs_embeds=inputs_embeds, 119 | use_cache=use_cache, 120 | output_attentions=output_attentions, 121 | output_hidden_states=output_hidden_states, 122 | return_dict=return_dict, 123 | ) 124 | 125 | hidden_states = outputs[0] 126 | if self.config.pretraining_tp > 1: 127 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 128 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 129 | logits = torch.cat(logits, dim=-1) 130 | else: 131 | logits = self.lm_head(hidden_states) 132 | logits = logits.float() 133 | 134 | logs = {} 135 | loss = None 136 | if labels is not None: 137 | # Shift so that tokens < n predict n 138 | shift_logits = logits[..., :-1, :].contiguous() 139 | shift_labels = labels[..., 1:].contiguous() 140 | # Flatten the tokens 141 | loss_fct = nn.CrossEntropyLoss() 142 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 143 | shift_labels = shift_labels.view(-1) 144 | # Enable model parallelism 145 | shift_labels = shift_labels.to(shift_logits.device) 146 | loss = loss_fct(shift_logits, shift_labels) 147 | logs["next_token_loss"] = loss.detach().item() 148 | 149 | floats_pred = None 150 | if self.float_head is not None and floats is not None: 151 | floats_pred_all = self.float_head(hidden_states)[:, :-1, 0] 152 | floats_mask = labels[:, 1:] == FLOAT_TOKEN_ID # 数 153 | floats_pred = floats_pred_all[floats_mask.to(floats_pred_all.device)] 154 | 155 | if torch.is_tensor(floats) and torch.numel(floats): 156 | float_mse = torch.nn.functional.mse_loss(floats_pred, floats) 157 | loss += float_mse * self.config.float_w 158 | logs["float_mse_loss"] = float_mse.detach().item() 159 | 160 | for k, v in self.float_head.named_parameters(): # Sanity check to watch weights update 161 | abs_sum = v.detach().abs().float().sum().item() 162 | logs[f"abs_sum_{k}"] = abs_sum 163 | logs[f"hash_{k}"] = hash(str(abs_sum)) / sys.maxsize 164 | 165 | if not return_dict: 166 | output = (logits,) + outputs[1:] + (floats_pred,) 167 | return (loss,) + output if loss is not None else output 168 | 169 | return FloatsCausalLMOutputWithPast( 170 | loss=loss, 171 | logits=logits, 172 | past_key_values=outputs.past_key_values, 173 | hidden_states=outputs.hidden_states, 174 | attentions=outputs.attentions, 175 | floats_pred=floats_pred, 176 | logs=logs, 177 | ) 178 | 179 | @torch.no_grad() 180 | 
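# `generate` below first builds the multimodal input embeddings and runs standard
# autoregressive decoding. If the float placeholder token (FLOAT_TOKEN, 数) appears in the
# sampled ids, the decoded text is post-processed: each placeholder becomes a "{:.3f}" slot,
# the values are taken from `floats_pred` of an extra forward pass over prompt + generated
# ids, and the formatted string is re-tokenized into the returned `output_ids`.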
def generate( 181 | self, 182 | inputs: Optional[torch.Tensor] = None, 183 | images: Optional[torch.Tensor] = None, 184 | image_sizes: Optional[torch.Tensor] = None, 185 | *, 186 | tokenizer: Optional["PreTrainedTokenizerBase"] = None, 187 | **kwargs, 188 | ) -> Union[GenerateOutput, torch.LongTensor]: 189 | position_ids = kwargs.pop("position_ids", None) 190 | attention_mask = kwargs.pop("attention_mask", None) 191 | if "inputs_embeds" in kwargs: 192 | raise NotImplementedError("`inputs_embeds` is not supported") 193 | 194 | inputs_copy = inputs.clone() 195 | 196 | if images is not None: 197 | ( 198 | inputs, 199 | position_ids, 200 | attention_mask, 201 | _, 202 | inputs_embeds, 203 | _ 204 | ) = self.prepare_inputs_labels_for_multimodal( 205 | inputs, 206 | position_ids, 207 | attention_mask, 208 | None, 209 | None, 210 | images, 211 | image_sizes=image_sizes 212 | ) 213 | else: 214 | inputs_embeds = self.get_model().embed_tokens(inputs) 215 | 216 | output_ids = super().generate( 217 | position_ids=position_ids, 218 | attention_mask=attention_mask, 219 | inputs_embeds=inputs_embeds, 220 | **kwargs 221 | ) 222 | 223 | if (output_ids == FLOAT_TOKEN_ID).any(): # 数 224 | assert ( 225 | tokenizer is not None 226 | ), "Need `tokenizer` in `generate` if substituting floats" 227 | 228 | inputs = torch.cat([inputs_copy, output_ids[:, 1:]], dim=1) 229 | 230 | output_ids = torch.tensor( 231 | [ 232 | tokenizer( 233 | tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 234 | .replace(FLOAT_TOKEN, "{:.3f}") 235 | .format( 236 | *self( 237 | input_ids=inputs, 238 | labels=inputs, 239 | attention_mask=None, 240 | images=images, 241 | floats=True, 242 | ) 243 | .floats_pred.cpu() 244 | .tolist() 245 | ) 246 | ).input_ids 247 | ], 248 | device=output_ids.device, 249 | ) 250 | 251 | return output_ids 252 | 253 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 254 | inputs_embeds=None, **kwargs): 255 | images = kwargs.pop("images", None) 256 | image_sizes = kwargs.pop("image_sizes", None) 257 | inputs = super().prepare_inputs_for_generation( 258 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 259 | ) 260 | if images is not None: 261 | inputs['images'] = images 262 | if image_sizes is not None: 263 | inputs['image_sizes'] = image_sizes 264 | return inputs 265 | 266 | AutoConfig.register("llava_llama", LlavaConfig) 267 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 268 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.utils.data import Sampler 7 | 8 | from transformers import Trainer 9 | from transformers.trainer import ( 10 | is_sagemaker_mp_enabled, 11 | get_parameter_names, 12 | has_length, 13 | ALL_LAYERNORM_LAYERS, 14 | logger, 15 | ) 16 | from typing import List, Optional 17 | 18 | 19 | def maybe_zero_3(param, ignore_status=False, name=None): 20 | from deepspeed import zero 21 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 22 | if hasattr(param, "ds_id"): 23 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 24 | if not ignore_status: 25 | print(name, 'no ignore status') 26 | with zero.GatheredParameters([param]): 27 | param = param.data.detach().cpu().clone() 28 | else: 29 | param = param.detach().cpu().clone() 30 | return param 31 
| 32 | 33 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 34 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 35 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 36 | return to_return 37 | 38 | 39 | def split_to_even_chunks(indices, lengths, num_chunks): 40 | """ 41 | Split a list of indices into `chunks` chunks of roughly equal lengths. 42 | """ 43 | 44 | if len(indices) % num_chunks != 0: 45 | return [indices[i::num_chunks] for i in range(num_chunks)] 46 | 47 | num_indices_per_chunk = len(indices) // num_chunks 48 | 49 | chunks = [[] for _ in range(num_chunks)] 50 | chunks_lengths = [0 for _ in range(num_chunks)] 51 | for index in indices: 52 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 53 | chunks[shortest_chunk].append(index) 54 | chunks_lengths[shortest_chunk] += lengths[index] 55 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 56 | chunks_lengths[shortest_chunk] = float("inf") 57 | 58 | return chunks 59 | 60 | 61 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 62 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 63 | assert all(l != 0 for l in lengths), "Should not have zero length." 64 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 65 | # all samples are in the same modality 66 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 67 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 68 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 69 | 70 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 71 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 72 | megabatch_size = world_size * batch_size 73 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 74 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 75 | 76 | last_mm = mm_megabatches[-1] 77 | last_lang = lang_megabatches[-1] 78 | additional_batch = last_mm + last_lang 79 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 80 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 81 | megabatches = [megabatches[i] for i in megabatch_indices] 82 | 83 | if len(additional_batch) > 0: 84 | megabatches.append(sorted(additional_batch)) 85 | 86 | return [i for megabatch in megabatches for i in megabatch] 87 | 88 | 89 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 90 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
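# Rough sketch of the grouping below (illustrative numbers, not from the source): with
# lengths=[4, 9, 2, 7], batch_size=2, world_size=1, the indices are first shuffled, then
# grouped into megabatches of world_size * batch_size = 2 items, each megabatch is sorted
# longest-first, and split_to_even_chunks() divides it into world_size chunks of roughly
# equal total length before flattening.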
91 | indices = torch.randperm(len(lengths), generator=generator) 92 | megabatch_size = world_size * batch_size 93 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 94 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 95 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 96 | 97 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 98 | 99 | 100 | class LengthGroupedSampler(Sampler): 101 | r""" 102 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 103 | keeping a bit of randomness. 104 | """ 105 | 106 | def __init__( 107 | self, 108 | batch_size: int, 109 | world_size: int, 110 | lengths: Optional[List[int]] = None, 111 | generator=None, 112 | group_by_modality: bool = False, 113 | ): 114 | if lengths is None: 115 | raise ValueError("Lengths must be provided.") 116 | 117 | self.batch_size = batch_size 118 | self.world_size = world_size 119 | self.lengths = lengths 120 | self.generator = generator 121 | self.group_by_modality = group_by_modality 122 | 123 | def __len__(self): 124 | return len(self.lengths) 125 | 126 | def __iter__(self): 127 | if self.group_by_modality: 128 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 129 | else: 130 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 131 | return iter(indices) 132 | 133 | 134 | class LLaVATrainer(Trainer): 135 | 136 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 137 | if self.train_dataset is None or not has_length(self.train_dataset): 138 | return None 139 | 140 | if self.args.group_by_modality_length: 141 | lengths = self.train_dataset.modality_lengths 142 | return LengthGroupedSampler( 143 | self.args.train_batch_size, 144 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 145 | lengths=lengths, 146 | group_by_modality=True, 147 | ) 148 | else: 149 | return super()._get_train_sampler() 150 | 151 | def create_optimizer(self): 152 | """ 153 | Setup the optimizer. 154 | 155 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 156 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
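In this repository, when `mm_projector_lr` is set, parameters of the `mm_projector` (and, if present, the `float_head`) are placed into separate parameter groups so they can use their own learning rates (`mm_projector_lr` / `float_head_lr`) instead of the base optimizer learning rate.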
157 | """ 158 | if is_sagemaker_mp_enabled(): 159 | return super().create_optimizer() 160 | 161 | opt_model = self.model 162 | 163 | if self.optimizer is None: 164 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 165 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 166 | if self.args.mm_projector_lr is not None: 167 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 168 | float_head_parameters = [name for name, _ in opt_model.named_parameters() if "float_head" in name] 169 | optimizer_grouped_parameters = [ 170 | { 171 | "params": [ 172 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n not in float_head_parameters and p.requires_grad) 173 | ], 174 | "weight_decay": self.args.weight_decay, 175 | }, 176 | { 177 | "params": [ 178 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n not in float_head_parameters and p.requires_grad) 179 | ], 180 | "weight_decay": 0.0, 181 | }, 182 | { 183 | "params": [ 184 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 185 | ], 186 | "weight_decay": self.args.weight_decay, 187 | "lr": self.args.mm_projector_lr, 188 | }, 189 | { 190 | "params": [ 191 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 192 | ], 193 | "weight_decay": 0.0, 194 | "lr": self.args.mm_projector_lr, 195 | }, 196 | ] 197 | if float_head_parameters: 198 | optimizer_grouped_parameters += [ 199 | { 200 | "params": [ 201 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in float_head_parameters and p.requires_grad) 202 | ], 203 | "weight_decay": self.args.weight_decay, 204 | "lr": self.args.float_head_lr, 205 | }, 206 | { 207 | "params": [ 208 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in float_head_parameters and p.requires_grad) 209 | ], 210 | "weight_decay": 0.0, 211 | "lr": self.args.float_head_lr, 212 | }, 213 | ] 214 | else: 215 | optimizer_grouped_parameters = [ 216 | { 217 | "params": [ 218 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 219 | ], 220 | "weight_decay": self.args.weight_decay, 221 | }, 222 | { 223 | "params": [ 224 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 225 | ], 226 | "weight_decay": 0.0, 227 | }, 228 | ] 229 | 230 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 231 | 232 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 233 | if optimizer_cls.__name__ == "Adam8bit": 234 | import bitsandbytes 235 | 236 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 237 | 238 | skipped = 0 239 | for module in opt_model.modules(): 240 | if isinstance(module, nn.Embedding): 241 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 242 | logger.info(f"skipped {module}: {skipped/2**20}M params") 243 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 244 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 245 | logger.info(f"skipped: {skipped/2**20}M params") 246 | 247 | return self.optimizer 248 | 249 | def _save_checkpoint(self, model, trial, metrics=None): 250 | if getattr(self.args, 'tune_mm_mlp_adapter', 
False): 251 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 252 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 253 | 254 | run_dir = self._get_output_dir(trial=trial) 255 | output_dir = os.path.join(run_dir, checkpoint_folder) 256 | 257 | # Only save Adapter 258 | keys_to_match = ['mm_projector', 'vision_resampler'] 259 | if getattr(self.args, "use_im_start_end", False): 260 | keys_to_match.extend(['embed_tokens', 'embed_in']) 261 | 262 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 263 | 264 | if self.args.local_rank == 0 or self.args.local_rank == -1: 265 | self.model.config.save_pretrained(output_dir) 266 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 267 | else: 268 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 269 | 270 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 271 | super()._save(output_dir=output_dir, state_dict=state_dict) 272 | 273 | from llava.train.train import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3 274 | 275 | output_dir = Path(output_dir) 276 | output_dir = str(output_dir.parent.joinpath("extra-"+output_dir.name)) 277 | 278 | state_dict_2 = get_peft_state_maybe_zero_3( 279 | self.model.named_parameters(), "none" 280 | ) 281 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( 282 | self.model.named_parameters() 283 | ) 284 | self.model.config.save_pretrained(output_dir) 285 | self.model.save_pretrained(output_dir, state_dict=state_dict_2) 286 | torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin')) 287 | 288 | def compute_loss(self, model, inputs, return_outputs=False): 289 | loss, outputs = super().compute_loss(model, inputs, return_outputs=True) 290 | if logs := outputs.logs: 291 | if not model.training: 292 | logs = {f'eval_{k}': v for k, v in logs.items()} 293 | self.log(logs) 294 | 295 | return (loss, outputs) if return_outputs else loss 296 | -------------------------------------------------------------------------------- /llava/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | import base64 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | 9 | class SeparatorStyle(Enum): 10 | """Different separator style.""" 11 | SINGLE = auto() 12 | TWO = auto() 13 | MPT = auto() 14 | PLAIN = auto() 15 | LLAMA_2 = auto() 16 | 17 | 18 | @dataclasses.dataclass 19 | class Conversation: 20 | """A class that keeps all conversation history.""" 21 | system: str 22 | roles: List[str] 23 | messages: List[List[str]] 24 | offset: int 25 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 26 | sep: str = "###" 27 | sep2: str = None 28 | version: str = "Unknown" 29 | 30 | skip_next: bool = False 31 | 32 | def get_prompt(self): 33 | messages = self.messages 34 | if len(messages) > 0 and type(messages[0][1]) is tuple: 35 | messages = self.messages.copy() 36 | init_role, init_msg = messages[0].copy() 37 | init_msg = init_msg[0].replace("", "").strip() 38 | if 'mmtag' in self.version: 39 | messages[0] = (init_role, init_msg) 40 | messages.insert(0, (self.roles[0], "")) 41 | messages.insert(1, (self.roles[1], "Received.")) 42 | else: 43 | messages[0] = (init_role, "\n" + init_msg) 44 | 45 | if self.sep_style == SeparatorStyle.SINGLE: 46 | ret = self.system + self.sep 47 | for role, message in messages: 48 | if message: 
49 | if type(message) is tuple: 50 | message, _, _ = message 51 | ret += role + ": " + message + self.sep 52 | else: 53 | ret += role + ":" 54 | elif self.sep_style == SeparatorStyle.TWO: 55 | seps = [self.sep, self.sep2] 56 | ret = self.system + seps[0] 57 | for i, (role, message) in enumerate(messages): 58 | if message: 59 | if type(message) is tuple: 60 | message, _, _ = message 61 | ret += role + ": " + message + seps[i % 2] 62 | else: 63 | ret += role + ":" 64 | elif self.sep_style == SeparatorStyle.MPT: 65 | ret = self.system + self.sep 66 | for role, message in messages: 67 | if message: 68 | if type(message) is tuple: 69 | message, _, _ = message 70 | ret += role + message + self.sep 71 | else: 72 | ret += role 73 | elif self.sep_style == SeparatorStyle.LLAMA_2: 74 | wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg 75 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" 76 | ret = "" 77 | 78 | for i, (role, message) in enumerate(messages): 79 | if i == 0: 80 | assert message, "first message should not be none" 81 | assert role == self.roles[0], "first message should come from user" 82 | if message: 83 | if type(message) is tuple: 84 | message, _, _ = message 85 | if i == 0: message = wrap_sys(self.system) + message 86 | if i % 2 == 0: 87 | message = wrap_inst(message) 88 | ret += self.sep + message 89 | else: 90 | ret += " " + message + " " + self.sep2 91 | else: 92 | ret += "" 93 | ret = ret.lstrip(self.sep) 94 | elif self.sep_style == SeparatorStyle.PLAIN: 95 | seps = [self.sep, self.sep2] 96 | ret = self.system 97 | for i, (role, message) in enumerate(messages): 98 | if message: 99 | if type(message) is tuple: 100 | message, _, _ = message 101 | ret += message + seps[i % 2] 102 | else: 103 | ret += "" 104 | else: 105 | raise ValueError(f"Invalid style: {self.sep_style}") 106 | 107 | return ret 108 | 109 | def append_message(self, role, message): 110 | self.messages.append([role, message]) 111 | 112 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): 113 | if image_process_mode == "Pad": 114 | def expand2square(pil_img, background_color=(122, 116, 104)): 115 | width, height = pil_img.size 116 | if width == height: 117 | return pil_img 118 | elif width > height: 119 | result = Image.new(pil_img.mode, (width, width), background_color) 120 | result.paste(pil_img, (0, (width - height) // 2)) 121 | return result 122 | else: 123 | result = Image.new(pil_img.mode, (height, height), background_color) 124 | result.paste(pil_img, ((height - width) // 2, 0)) 125 | return result 126 | image = expand2square(image) 127 | elif image_process_mode in ["Default", "Crop"]: 128 | pass 129 | elif image_process_mode == "Resize": 130 | image = image.resize((336, 336)) 131 | else: 132 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 133 | if max(image.size) > max_len: 134 | max_hw, min_hw = max(image.size), min(image.size) 135 | aspect_ratio = max_hw / min_hw 136 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 137 | longest_edge = int(shortest_edge * aspect_ratio) 138 | W, H = image.size 139 | if H > W: 140 | H, W = longest_edge, shortest_edge 141 | else: 142 | H, W = shortest_edge, longest_edge 143 | image = image.resize((W, H)) 144 | if return_pil: 145 | return image 146 | else: 147 | buffered = BytesIO() 148 | image.save(buffered, format=image_format) 149 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 150 | return img_b64_str 151 | 152 | def get_images(self, 
return_pil=False): 153 | images = [] 154 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 155 | if i % 2 == 0: 156 | if type(msg) is tuple: 157 | msg, image, image_process_mode = msg 158 | image = self.process_image(image, image_process_mode, return_pil=return_pil) 159 | images.append(image) 160 | return images 161 | 162 | def to_gradio_chatbot(self): 163 | ret = [] 164 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 165 | if i % 2 == 0: 166 | if type(msg) is tuple: 167 | msg, image, image_process_mode = msg 168 | img_b64_str = self.process_image( 169 | image, "Default", return_pil=False, 170 | image_format='JPEG') 171 | img_str = f'user upload image' 172 | msg = img_str + msg.replace('', '').strip() 173 | ret.append([msg, None]) 174 | else: 175 | ret.append([msg, None]) 176 | else: 177 | ret[-1][-1] = msg 178 | return ret 179 | 180 | def copy(self): 181 | return Conversation( 182 | system=self.system, 183 | roles=self.roles, 184 | messages=[[x, y] for x, y in self.messages], 185 | offset=self.offset, 186 | sep_style=self.sep_style, 187 | sep=self.sep, 188 | sep2=self.sep2, 189 | version=self.version) 190 | 191 | def dict(self): 192 | if len(self.get_images()) > 0: 193 | return { 194 | "system": self.system, 195 | "roles": self.roles, 196 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 197 | "offset": self.offset, 198 | "sep": self.sep, 199 | "sep2": self.sep2, 200 | } 201 | return { 202 | "system": self.system, 203 | "roles": self.roles, 204 | "messages": self.messages, 205 | "offset": self.offset, 206 | "sep": self.sep, 207 | "sep2": self.sep2, 208 | } 209 | 210 | 211 | conv_vicuna_v0 = Conversation( 212 | system="A chat between a curious human and an artificial intelligence assistant. " 213 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 214 | roles=("Human", "Assistant"), 215 | messages=( 216 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 217 | ("Assistant", 218 | "Renewable energy sources are those that can be replenished naturally in a relatively " 219 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 220 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 221 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 222 | "renewable and non-renewable energy sources:\n" 223 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 224 | "energy sources are finite and will eventually run out.\n" 225 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 226 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 227 | "and other negative effects.\n" 228 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 229 | "have lower operational costs than non-renewable sources.\n" 230 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 231 | "locations than non-renewable sources.\n" 232 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 233 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 234 | "6. 
Sustainability: Renewable energy sources are more sustainable over the long term, while " 235 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 236 | ), 237 | offset=2, 238 | sep_style=SeparatorStyle.SINGLE, 239 | sep="###", 240 | ) 241 | 242 | conv_vicuna_v1 = Conversation( 243 | system="A chat between a curious user and an artificial intelligence assistant. " 244 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 245 | roles=("USER", "ASSISTANT"), 246 | version="v1", 247 | messages=(), 248 | offset=0, 249 | sep_style=SeparatorStyle.TWO, 250 | sep=" ", 251 | sep2="", 252 | ) 253 | 254 | conv_llama_2 = Conversation( 255 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 256 | 257 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 258 | roles=("USER", "ASSISTANT"), 259 | version="llama_v2", 260 | messages=(), 261 | offset=0, 262 | sep_style=SeparatorStyle.LLAMA_2, 263 | sep="", 264 | sep2="", 265 | ) 266 | 267 | conv_llava_llama_2 = Conversation( 268 | system="You are a helpful language and vision assistant. " 269 | "You are able to understand the visual content that the user provides, " 270 | "and assist the user with a variety of tasks using natural language.", 271 | roles=("USER", "ASSISTANT"), 272 | version="llama_v2", 273 | messages=(), 274 | offset=0, 275 | sep_style=SeparatorStyle.LLAMA_2, 276 | sep="", 277 | sep2="", 278 | ) 279 | 280 | conv_mpt = Conversation( 281 | system="""<|im_start|>system 282 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 283 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 284 | version="mpt", 285 | messages=(), 286 | offset=0, 287 | sep_style=SeparatorStyle.MPT, 288 | sep="<|im_end|>", 289 | ) 290 | 291 | conv_llava_plain = Conversation( 292 | system="", 293 | roles=("", ""), 294 | messages=( 295 | ), 296 | offset=0, 297 | sep_style=SeparatorStyle.PLAIN, 298 | sep="\n", 299 | ) 300 | 301 | conv_llava_v0 = Conversation( 302 | system="A chat between a curious human and an artificial intelligence assistant. " 303 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 304 | roles=("Human", "Assistant"), 305 | messages=( 306 | ), 307 | offset=0, 308 | sep_style=SeparatorStyle.SINGLE, 309 | sep="###", 310 | ) 311 | 312 | conv_llava_v0_mmtag = Conversation( 313 | system="A chat between a curious user and an artificial intelligence assistant. " 314 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 315 | "The visual content will be provided with the following format: visual content.", 316 | roles=("Human", "Assistant"), 317 | messages=( 318 | ), 319 | offset=0, 320 | sep_style=SeparatorStyle.SINGLE, 321 | sep="###", 322 | version="v0_mmtag", 323 | ) 324 | 325 | conv_llava_v1 = Conversation( 326 | system="A chat between a curious human and an artificial intelligence assistant. 
" 327 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 328 | roles=("USER", "ASSISTANT"), 329 | version="v1", 330 | messages=(), 331 | offset=0, 332 | sep_style=SeparatorStyle.TWO, 333 | sep=" ", 334 | sep2="", 335 | ) 336 | 337 | conv_llava_v1_mmtag = Conversation( 338 | system="A chat between a curious user and an artificial intelligence assistant. " 339 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 340 | "The visual content will be provided with the following format: visual content.", 341 | roles=("USER", "ASSISTANT"), 342 | messages=(), 343 | offset=0, 344 | sep_style=SeparatorStyle.TWO, 345 | sep=" ", 346 | sep2="", 347 | version="v1_mmtag", 348 | ) 349 | 350 | conv_mistral_instruct = Conversation( 351 | system="", 352 | roles=("USER", "ASSISTANT"), 353 | version="llama_v2", 354 | messages=(), 355 | offset=0, 356 | sep_style=SeparatorStyle.LLAMA_2, 357 | sep="", 358 | sep2="", 359 | ) 360 | 361 | conv_chatml_direct = Conversation( 362 | system="""<|im_start|>system 363 | Answer the questions.""", 364 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 365 | version="mpt", 366 | messages=(), 367 | offset=0, 368 | sep_style=SeparatorStyle.MPT, 369 | sep="<|im_end|>", 370 | ) 371 | 372 | default_conversation = conv_vicuna_v1 373 | conv_templates = { 374 | "default": conv_vicuna_v0, 375 | "v0": conv_vicuna_v0, 376 | "v1": conv_vicuna_v1, 377 | "vicuna_v1": conv_vicuna_v1, 378 | "llama_2": conv_llama_2, 379 | "mistral_instruct": conv_mistral_instruct, 380 | "chatml_direct": conv_chatml_direct, 381 | "mistral_direct": conv_chatml_direct, 382 | 383 | "plain": conv_llava_plain, 384 | "v0_plain": conv_llava_plain, 385 | "llava_v0": conv_llava_v0, 386 | "v0_mmtag": conv_llava_v0_mmtag, 387 | "llava_v1": conv_llava_v1, 388 | "v1_mmtag": conv_llava_v1_mmtag, 389 | "llava_llama_2": conv_llava_llama_2, 390 | 391 | "mpt": conv_mpt, 392 | } 393 | 394 | 395 | if __name__ == "__main__": 396 | print(default_conversation.get_prompt()) 397 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # Data & Software Copyright License for non-commercial scientific research purposes 2 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use IG-LLM data, models, and software, (the "Data & Software"), including synthetic images and videos, SMPL and SMPL-X parameters, 3D body and clothing meshes, 2D textures, and scripts. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding code repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License. 3 | 4 | # Ownership / Licensees 5 | The Data & Software and the associated materials have been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). 6 | 7 | Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(hereinafter "MPG"; MPI and MPG hereinafter collectively "Max-Planck") hereinafter the "Licensor". 8 | 9 | # License Grant 10 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: 11 | 12 | + To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization; 13 | + To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects 14 | 15 | Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. 16 | 17 | The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Data & Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it. 18 | 19 | # No Distribution 20 | The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. 21 | 22 | # Disclaimer of Representations and Warranties 23 | You expressly acknowledge and agree that the Data & Software results from basic research, is provided "AS IS", may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party. 24 | 25 | # Limitation of Liability 26 | Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. 27 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. 
28 | Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders. 29 | The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause. 30 | 31 | # No Maintenance Services 32 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time. 33 | 34 | Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification, or publication. 35 | 36 | # Publications using the Data & Software 37 | 38 | You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software. 39 | 40 | Citation: 41 | ``` 42 | @article{ 43 | kulits2024rethinking, 44 | title={Re-Thinking Inverse Graphics With Large Language Models}, 45 | author={Peter Kulits and Haiwen Feng and Weiyang Liu and Victoria Fernandez Abrevaya and Michael J. Black}, 46 | journal={Transactions on Machine Learning Research}, 47 | issn={2835-8856}, 48 | year={2024}, 49 | url={https://openreview.net/forum?id=u0eiu1MTS7}, 50 | } 51 | ``` 52 | 53 | # Commercial licensing opportunities 54 | For commercial use of the Data & Software, please send emails to ps-license@tue.mpg.de 55 | 56 | --- 57 | 58 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. 59 | 60 | -------------------------------------------------------------------------------- 61 | 62 | Apache License 63 | Version 2.0, January 2004 64 | http://www.apache.org/licenses/ 65 | 66 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 67 | 68 | 1. Definitions. 69 | 70 | "License" shall mean the terms and conditions for use, reproduction, 71 | and distribution as defined by Sections 1 through 9 of this document. 72 | 73 | "Licensor" shall mean the copyright owner or entity authorized by 74 | the copyright owner that is granting the License. 75 | 76 | "Legal Entity" shall mean the union of the acting entity and all 77 | other entities that control, are controlled by, or are under common 78 | control with that entity. For the purposes of this definition, 79 | "control" means (i) the power, direct or indirect, to cause the 80 | direction or management of such entity, whether by contract or 81 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 82 | outstanding shares, or (iii) beneficial ownership of such entity. 83 | 84 | "You" (or "Your") shall mean an individual or Legal Entity 85 | exercising permissions granted by this License. 86 | 87 | "Source" form shall mean the preferred form for making modifications, 88 | including but not limited to software source code, documentation 89 | source, and configuration files. 
90 | 91 | "Object" form shall mean any form resulting from mechanical 92 | transformation or translation of a Source form, including but 93 | not limited to compiled object code, generated documentation, 94 | and conversions to other media types. 95 | 96 | "Work" shall mean the work of authorship, whether in Source or 97 | Object form, made available under the License, as indicated by a 98 | copyright notice that is included in or attached to the work 99 | (an example is provided in the Appendix below). 100 | 101 | "Derivative Works" shall mean any work, whether in Source or Object 102 | form, that is based on (or derived from) the Work and for which the 103 | editorial revisions, annotations, elaborations, or other modifications 104 | represent, as a whole, an original work of authorship. For the purposes 105 | of this License, Derivative Works shall not include works that remain 106 | separable from, or merely link (or bind by name) to the interfaces of, 107 | the Work and Derivative Works thereof. 108 | 109 | "Contribution" shall mean any work of authorship, including 110 | the original version of the Work and any modifications or additions 111 | to that Work or Derivative Works thereof, that is intentionally 112 | submitted to Licensor for inclusion in the Work by the copyright owner 113 | or by an individual or Legal Entity authorized to submit on behalf of 114 | the copyright owner. For the purposes of this definition, "submitted" 115 | means any form of electronic, verbal, or written communication sent 116 | to the Licensor or its representatives, including but not limited to 117 | communication on electronic mailing lists, source code control systems, 118 | and issue tracking systems that are managed by, or on behalf of, the 119 | Licensor for the purpose of discussing and improving the Work, but 120 | excluding communication that is conspicuously marked or otherwise 121 | designated in writing by the copyright owner as "Not a Contribution." 122 | 123 | "Contributor" shall mean Licensor and any individual or Legal Entity 124 | on behalf of whom a Contribution has been received by Licensor and 125 | subsequently incorporated within the Work. 126 | 127 | 2. Grant of Copyright License. Subject to the terms and conditions of 128 | this License, each Contributor hereby grants to You a perpetual, 129 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 130 | copyright license to reproduce, prepare Derivative Works of, 131 | publicly display, publicly perform, sublicense, and distribute the 132 | Work and such Derivative Works in Source or Object form. 133 | 134 | 3. Grant of Patent License. Subject to the terms and conditions of 135 | this License, each Contributor hereby grants to You a perpetual, 136 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 137 | (except as stated in this section) patent license to make, have made, 138 | use, offer to sell, sell, import, and otherwise transfer the Work, 139 | where such license applies only to those patent claims licensable 140 | by such Contributor that are necessarily infringed by their 141 | Contribution(s) alone or by combination of their Contribution(s) 142 | with the Work to which such Contribution(s) was submitted. 
If You 143 | institute patent litigation against any entity (including a 144 | cross-claim or counterclaim in a lawsuit) alleging that the Work 145 | or a Contribution incorporated within the Work constitutes direct 146 | or contributory patent infringement, then any patent licenses 147 | granted to You under this License for that Work shall terminate 148 | as of the date such litigation is filed. 149 | 150 | 4. Redistribution. You may reproduce and distribute copies of the 151 | Work or Derivative Works thereof in any medium, with or without 152 | modifications, and in Source or Object form, provided that You 153 | meet the following conditions: 154 | 155 | (a) You must give any other recipients of the Work or 156 | Derivative Works a copy of this License; and 157 | 158 | (b) You must cause any modified files to carry prominent notices 159 | stating that You changed the files; and 160 | 161 | (c) You must retain, in the Source form of any Derivative Works 162 | that You distribute, all copyright, patent, trademark, and 163 | attribution notices from the Source form of the Work, 164 | excluding those notices that do not pertain to any part of 165 | the Derivative Works; and 166 | 167 | (d) If the Work includes a "NOTICE" text file as part of its 168 | distribution, then any Derivative Works that You distribute must 169 | include a readable copy of the attribution notices contained 170 | within such NOTICE file, excluding those notices that do not 171 | pertain to any part of the Derivative Works, in at least one 172 | of the following places: within a NOTICE text file distributed 173 | as part of the Derivative Works; within the Source form or 174 | documentation, if provided along with the Derivative Works; or, 175 | within a display generated by the Derivative Works, if and 176 | wherever such third-party notices normally appear. The contents 177 | of the NOTICE file are for informational purposes only and 178 | do not modify the License. You may add Your own attribution 179 | notices within Derivative Works that You distribute, alongside 180 | or as an addendum to the NOTICE text from the Work, provided 181 | that such additional attribution notices cannot be construed 182 | as modifying the License. 183 | 184 | You may add Your own copyright statement to Your modifications and 185 | may provide additional or different license terms and conditions 186 | for use, reproduction, or distribution of Your modifications, or 187 | for any such Derivative Works as a whole, provided Your use, 188 | reproduction, and distribution of the Work otherwise complies with 189 | the conditions stated in this License. 190 | 191 | 5. Submission of Contributions. Unless You explicitly state otherwise, 192 | any Contribution intentionally submitted for inclusion in the Work 193 | by You to the Licensor shall be under the terms and conditions of 194 | this License, without any additional terms or conditions. 195 | Notwithstanding the above, nothing herein shall supersede or modify 196 | the terms of any separate license agreement you may have executed 197 | with Licensor regarding such Contributions. 198 | 199 | 6. Trademarks. This License does not grant permission to use the trade 200 | names, trademarks, service marks, or product names of the Licensor, 201 | except as required for reasonable and customary use in describing the 202 | origin of the Work and reproducing the content of the NOTICE file. 203 | 204 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 205 | agreed to in writing, Licensor provides the Work (and each 206 | Contributor provides its Contributions) on an "AS IS" BASIS, 207 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 208 | implied, including, without limitation, any warranties or conditions 209 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 210 | PARTICULAR PURPOSE. You are solely responsible for determining the 211 | appropriateness of using or redistributing the Work and assume any 212 | risks associated with Your exercise of permissions under this License. 213 | 214 | 8. Limitation of Liability. In no event and under no legal theory, 215 | whether in tort (including negligence), contract, or otherwise, 216 | unless required by applicable law (such as deliberate and grossly 217 | negligent acts) or agreed to in writing, shall any Contributor be 218 | liable to You for damages, including any direct, indirect, special, 219 | incidental, or consequential damages of any character arising as a 220 | result of this License or out of the use or inability to use the 221 | Work (including but not limited to damages for loss of goodwill, 222 | work stoppage, computer failure or malfunction, or any and all 223 | other commercial damages or losses), even if such Contributor 224 | has been advised of the possibility of such damages. 225 | 226 | 9. Accepting Warranty or Additional Liability. While redistributing 227 | the Work or Derivative Works thereof, You may choose to offer, 228 | and charge a fee for, acceptance of support, warranty, indemnity, 229 | or other liability obligations and/or rights consistent with this 230 | License. However, in accepting such obligations, You may act only 231 | on Your own behalf and on Your sole responsibility, not on behalf 232 | of any other Contributor, and only if You agree to indemnify, 233 | defend, and hold each Contributor harmless for any liability 234 | incurred by, or claims asserted against, such Contributor by reason 235 | of your accepting any such warranty or additional liability. 236 | 237 | END OF TERMS AND CONDITIONS 238 | 239 | APPENDIX: How to apply the Apache License to your work. 240 | 241 | To apply the Apache License to your work, attach the following 242 | boilerplate notice, with the fields enclosed by brackets "[]" 243 | replaced with your own identifying information. (Don't include 244 | the brackets!) The text should be enclosed in the appropriate 245 | comment syntax for the file format. We also recommend that a 246 | file or class name and description of purpose be included on the 247 | same "printed page" as the copyright notice for easier 248 | identification within third-party archives. 249 | 250 | Copyright [yyyy] [name of copyright owner] 251 | 252 | Licensed under the Apache License, Version 2.0 (the "License"); 253 | you may not use this file except in compliance with the License. 254 | You may obtain a copy of the License at 255 | 256 | http://www.apache.org/licenses/LICENSE-2.0 257 | 258 | Unless required by applicable law or agreed to in writing, software 259 | distributed under the License is distributed on an "AS IS" BASIS, 260 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 261 | See the License for the specific language governing permissions and 262 | limitations under the License. 
263 | -------------------------------------------------------------------------------- /llava/model/llava_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from .multimodal_encoder.builder import build_vision_tower 22 | from .multimodal_projector.builder import build_vision_projector 23 | 24 | from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | from llava.mm_utils import get_anyres_image_grid_shape 27 | 28 | 29 | class LlavaMetaModel: 30 | 31 | def __init__(self, config): 32 | super(LlavaMetaModel, self).__init__(config) 33 | 34 | if hasattr(config, "mm_vision_tower"): 35 | self.vision_tower = build_vision_tower(config, delay_load=True) 36 | self.mm_projector = build_vision_projector(config) 37 | 38 | if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): 39 | self.image_newline = nn.Parameter( 40 | torch.empty(config.hidden_size, dtype=self.dtype) 41 | ) 42 | 43 | def get_vision_tower(self): 44 | vision_tower = getattr(self, 'vision_tower', None) 45 | if type(vision_tower) is list: 46 | vision_tower = vision_tower[0] 47 | return vision_tower 48 | 49 | def initialize_vision_modules(self, model_args, fsdp=None): 50 | vision_tower = model_args.vision_tower 51 | mm_vision_select_layer = model_args.mm_vision_select_layer 52 | mm_vision_select_feature = model_args.mm_vision_select_feature 53 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 54 | mm_patch_merge_type = model_args.mm_patch_merge_type 55 | 56 | self.config.mm_vision_tower = vision_tower 57 | 58 | if self.get_vision_tower() is None: 59 | vision_tower = build_vision_tower(model_args) 60 | 61 | if fsdp is not None and len(fsdp) > 0: 62 | self.vision_tower = [vision_tower] 63 | else: 64 | self.vision_tower = vision_tower 65 | else: 66 | if fsdp is not None and len(fsdp) > 0: 67 | vision_tower = self.vision_tower[0] 68 | else: 69 | vision_tower = self.vision_tower 70 | vision_tower.load_model() 71 | 72 | self.config.use_mm_proj = True 73 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') 74 | self.config.mm_hidden_size = vision_tower.hidden_size 75 | self.config.mm_vision_select_layer = mm_vision_select_layer 76 | self.config.mm_vision_select_feature = mm_vision_select_feature 77 | self.config.mm_patch_merge_type = mm_patch_merge_type 78 | 79 | if getattr(self, 'mm_projector', None) is None: 80 | self.mm_projector = build_vision_projector(self.config) 81 | 82 | if 'unpad' in mm_patch_merge_type: 83 | embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) 84 | self.image_newline = nn.Parameter( 85 | torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std 86 | ) 87 | else: 88 | # In case it is 
frozen by LoRA 89 | for p in self.mm_projector.parameters(): 90 | p.requires_grad = True 91 | 92 | if pretrain_mm_mlp_adapter is not None: 93 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 94 | def get_w(weights, keyword): 95 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 96 | 97 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 98 | 99 | 100 | def unpad_image(tensor, original_size): 101 | """ 102 | Unpads a PyTorch tensor of a padded and resized image. 103 | 104 | Args: 105 | tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. 106 | original_size (tuple): The original size of the image (height, width). 107 | 108 | Returns: 109 | torch.Tensor: The unpadded image tensor. 110 | """ 111 | original_width, original_height = original_size 112 | current_height, current_width = tensor.shape[1:] 113 | 114 | original_aspect_ratio = original_width / original_height 115 | current_aspect_ratio = current_width / current_height 116 | 117 | if original_aspect_ratio > current_aspect_ratio: 118 | scale_factor = current_width / original_width 119 | new_height = int(original_height * scale_factor) 120 | padding = (current_height - new_height) // 2 121 | unpadded_tensor = tensor[:, padding:current_height - padding, :] 122 | else: 123 | scale_factor = current_height / original_height 124 | new_width = int(original_width * scale_factor) 125 | padding = (current_width - new_width) // 2 126 | unpadded_tensor = tensor[:, :, padding:current_width - padding] 127 | 128 | return unpadded_tensor 129 | 130 | 131 | class LlavaMetaForCausalLM(ABC): 132 | 133 | @abstractmethod 134 | def get_model(self): 135 | pass 136 | 137 | def get_vision_tower(self): 138 | return self.get_model().get_vision_tower() 139 | 140 | def encode_images(self, images): 141 | image_features = self.get_model().get_vision_tower()(images) 142 | image_features = self.get_model().mm_projector(image_features) 143 | return image_features 144 | 145 | def prepare_inputs_labels_for_multimodal( 146 | self, input_ids, position_ids, attention_mask, past_key_values, labels, 147 | images, image_sizes=None 148 | ): 149 | vision_tower = self.get_vision_tower() 150 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 151 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 152 | 153 | if type(images) is list or images.ndim == 5: 154 | if type(images) is list: 155 | images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] 156 | concat_images = torch.cat([image for image in images], dim=0) 157 | image_features = self.encode_images(concat_images) 158 | split_sizes = [image.shape[0] for image in images] 159 | image_features = torch.split(image_features, split_sizes, dim=0) 160 | mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') 161 | image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') 162 | if mm_patch_merge_type == 'flat': 163 | image_features = [x.flatten(0, 1) for x in image_features] 164 | elif mm_patch_merge_type.startswith('spatial'): 165 | new_image_features = [] 166 | for image_idx, image_feature in enumerate(image_features): 167 | if image_feature.shape[0] > 1: 168 | base_image_feature = image_feature[0] 169 | image_feature = image_feature[1:] 170 | height = width = self.get_vision_tower().num_patches_per_side 171 | assert height * width == base_image_feature.shape[0] 172 | if image_aspect_ratio == 'anyres': 173 | num_patch_width, 
num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) 174 | image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) 175 | else: 176 | raise NotImplementedError 177 | if 'unpad' in mm_patch_merge_type: 178 | image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() 179 | image_feature = image_feature.flatten(1, 2).flatten(2, 3) 180 | image_feature = unpad_image(image_feature, image_sizes[image_idx]) 181 | image_feature = torch.cat(( 182 | image_feature, 183 | self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) 184 | ), dim=-1) 185 | image_feature = image_feature.flatten(1, 2).transpose(0, 1) 186 | else: 187 | image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() 188 | image_feature = image_feature.flatten(0, 3) 189 | image_feature = torch.cat((base_image_feature, image_feature), dim=0) 190 | else: 191 | image_feature = image_feature[0] 192 | if 'unpad' in mm_patch_merge_type: 193 | image_feature = torch.cat(( 194 | image_feature, 195 | self.model.image_newline[None].to(image_feature.device) 196 | ), dim=0) 197 | new_image_features.append(image_feature) 198 | image_features = new_image_features 199 | else: 200 | raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") 201 | else: 202 | image_features = self.encode_images(images) 203 | 204 | # TODO: image start / end is not implemented here to support pretraining. 205 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 206 | raise NotImplementedError 207 | 208 | # Let's just add dummy tensors if they do not exist, 209 | # it is a headache to deal with None all the time. 210 | # But it is not ideal, and if you have a better idea, 211 | # please open an issue / submit a PR, thanks. 
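# The block below keeps references to the caller's (possibly None) labels,
# position_ids and attention_mask, then substitutes working defaults: an all-True
# attention mask, fresh 0..seq_len-1 position ids, and IGNORE_INDEX labels. The
# saved originals are checked again at the end of this function so that any field
# the caller passed as None is also returned as None.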
212 | _labels = labels 213 | _position_ids = position_ids 214 | _attention_mask = attention_mask 215 | if attention_mask is None: 216 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 217 | else: 218 | attention_mask = attention_mask.bool() 219 | if position_ids is None: 220 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 221 | if labels is None: 222 | labels = torch.full_like(input_ids, IGNORE_INDEX) 223 | 224 | # remove the padding using attention_mask -- FIXME 225 | _input_ids = input_ids 226 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] 227 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 228 | 229 | new_input_embeds = [] 230 | new_labels = [] 231 | cur_image_idx = 0 232 | for batch_idx, cur_input_ids in enumerate(input_ids): 233 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() 234 | if num_images == 0: 235 | cur_image_features = image_features[cur_image_idx] 236 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 237 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) 238 | new_input_embeds.append(cur_input_embeds) 239 | new_labels.append(labels[batch_idx]) 240 | cur_image_idx += 1 241 | continue 242 | 243 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] 244 | cur_input_ids_noim = [] 245 | cur_labels = labels[batch_idx] 246 | cur_labels_noim = [] 247 | for i in range(len(image_token_indices) - 1): 248 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) 249 | cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) 250 | split_sizes = [x.shape[0] for x in cur_labels_noim] 251 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) 252 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) 253 | cur_new_input_embeds = [] 254 | cur_new_labels = [] 255 | 256 | for i in range(num_images + 1): 257 | cur_new_input_embeds.append(cur_input_embeds_no_im[i]) 258 | cur_new_labels.append(cur_labels_noim[i]) 259 | if i < num_images: 260 | cur_image_features = image_features[cur_image_idx] 261 | cur_image_idx += 1 262 | cur_new_input_embeds.append(cur_image_features) 263 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) 264 | 265 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] 266 | 267 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) 268 | cur_new_labels = torch.cat(cur_new_labels) 269 | 270 | new_input_embeds.append(cur_new_input_embeds) 271 | new_labels.append(cur_new_labels) 272 | 273 | # Truncate sequences to max length as image embeddings can make the sequence longer 274 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 275 | if tokenizer_model_max_length is not None: 276 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 277 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels] 278 | 279 | # Combine them 280 | max_len = max(x.shape[0] for x in new_input_embeds) 281 | batch_size = len(new_input_embeds) 282 | 283 | new_input_embeds_padded = [] 284 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, 
device=new_labels[0].device) 285 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 286 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 287 | 288 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 289 | cur_len = cur_new_embed.shape[0] 290 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 291 | new_input_embeds_padded.append(torch.cat(( 292 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), 293 | cur_new_embed 294 | ), dim=0)) 295 | if cur_len > 0: 296 | new_labels_padded[i, -cur_len:] = cur_new_labels 297 | attention_mask[i, -cur_len:] = True 298 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 299 | else: 300 | new_input_embeds_padded.append(torch.cat(( 301 | cur_new_embed, 302 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) 303 | ), dim=0)) 304 | if cur_len > 0: 305 | new_labels_padded[i, :cur_len] = cur_new_labels 306 | attention_mask[i, :cur_len] = True 307 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 308 | 309 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 310 | 311 | if _labels is None: 312 | new_labels = None 313 | else: 314 | new_labels = new_labels_padded 315 | 316 | if _attention_mask is None: 317 | attention_mask = None 318 | else: 319 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 320 | 321 | if _position_ids is None: 322 | position_ids = None 323 | 324 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels 325 | 326 | def initialize_vision_tokenizer(self, model_args, tokenizer): 327 | if model_args.mm_use_im_patch_token: 328 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 329 | self.resize_token_embeddings(len(tokenizer)) 330 | 331 | if model_args.mm_use_im_start_end: 332 | num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 333 | self.resize_token_embeddings(len(tokenizer)) 334 | 335 | if num_new_tokens > 0: 336 | input_embeddings = self.get_input_embeddings().weight.data 337 | output_embeddings = self.get_output_embeddings().weight.data 338 | 339 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 340 | dim=0, keepdim=True) 341 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 342 | dim=0, keepdim=True) 343 | 344 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 345 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 346 | 347 | if model_args.tune_mm_mlp_adapter: 348 | for p in self.get_input_embeddings().parameters(): 349 | p.requires_grad = True 350 | for p in self.get_output_embeddings().parameters(): 351 | p.requires_grad = False 352 | 353 | if model_args.pretrain_mm_mlp_adapter: 354 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 355 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] 356 | assert num_new_tokens == 2 357 | if input_embeddings.shape == embed_tokens_weight.shape: 358 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 359 | elif embed_tokens_weight.shape[0] == num_new_tokens: 360 | input_embeddings[-num_new_tokens:] = 
embed_tokens_weight 361 | else: 362 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") 363 | elif model_args.mm_use_im_patch_token: 364 | if model_args.tune_mm_mlp_adapter: 365 | for p in self.get_input_embeddings().parameters(): 366 | p.requires_grad = False 367 | for p in self.get_output_embeddings().parameters(): 368 | p.requires_grad = False 369 | -------------------------------------------------------------------------------- /llava/train/train.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import copy 19 | from dataclasses import dataclass, field 20 | import itertools 21 | import json 22 | import logging 23 | import math 24 | import random 25 | from pathlib import Path 26 | from typing import Dict, Optional, Sequence, List 27 | 28 | import torch 29 | 30 | import transformers 31 | import tokenizers 32 | 33 | from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, FLOAT_TOKEN 34 | from torch.utils.data import Dataset 35 | from llava.train.llava_trainer import LLaVATrainer 36 | 37 | from llava import conversation as conversation_lib 38 | from llava.model import * 39 | from llava.mm_utils import tokenizer_image_token 40 | 41 | from PIL import Image, ImageDraw 42 | from scipy.spatial.transform import Rotation as R 43 | import numpy as np 44 | 45 | 46 | local_rank = None 47 | 48 | 49 | def rank0_print(*args): 50 | if local_rank == 0: 51 | print(*args) 52 | 53 | 54 | from packaging import version 55 | IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') 56 | 57 | 58 | @dataclass 59 | class ModelArguments: 60 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 61 | version: Optional[str] = field(default="v0") 62 | freeze_backbone: bool = field(default=False) 63 | tune_mm_mlp_adapter: bool = field(default=False) 64 | vision_tower: Optional[str] = field(default=None) 65 | mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer 66 | pretrain_mm_mlp_adapter: Optional[str] = field(default=None) 67 | mm_projector_type: Optional[str] = field(default='linear') 68 | mm_use_im_start_end: bool = field(default=False) 69 | mm_use_im_patch_token: bool = field(default=True) 70 | mm_patch_merge_type: Optional[str] = field(default='flat') 71 | mm_vision_select_feature: Optional[str] = field(default="patch") 72 | float_head_type: Optional[str] = field(default=None) 73 | float_w: float = field(default=1.0) 74 | 75 | 76 | 
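# Note: passing --float_head_type with any value other than "none" switches the
# data pipeline into float mode: numeric attributes are emitted as FLOAT_TOKEN
# placeholders and their raw values are collected into batch["floats"] by the
# collator. float_w is forwarded to the model config and presumably weights the
# float-head regression loss (the head's loss is implemented outside this file).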
@dataclass 77 | class DataArguments: 78 | data_path: str = field(default=None, 79 | metadata={"help": "Path to the training data."}) 80 | lazy_preprocess: bool = False 81 | is_multimodal: bool = False 82 | image_folder: Optional[str] = field(default=None) 83 | image_aspect_ratio: str = 'square' 84 | data_path_val: Optional[str] = field(default=None) 85 | image_folder_val: Optional[str] = field(default=None) 86 | use_synonyms: bool = field(default=True) 87 | shuffle_attributes: bool = field(default=True) 88 | num_samples: Optional[int] = field(default=None) 89 | is_2d: bool = field(default=False) 90 | rotation_rep: Optional[str] = field(default=None) 91 | 92 | 93 | @dataclass 94 | class TrainingArguments(transformers.TrainingArguments): 95 | cache_dir: Optional[str] = field(default=None) 96 | optim: str = field(default="adamw_torch") 97 | remove_unused_columns: bool = field(default=False) 98 | freeze_mm_mlp_adapter: bool = field(default=False) 99 | mpt_attn_impl: Optional[str] = field(default="triton") 100 | model_max_length: int = field( 101 | default=512, 102 | metadata={ 103 | "help": 104 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 105 | }, 106 | ) 107 | double_quant: bool = field( 108 | default=True, 109 | metadata={"help": "Compress the quantization statistics through double quantization."} 110 | ) 111 | quant_type: str = field( 112 | default="nf4", 113 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} 114 | ) 115 | bits: int = field( 116 | default=16, 117 | metadata={"help": "How many bits to use."} 118 | ) 119 | lora_enable: bool = False 120 | lora_r: int = 64 121 | lora_alpha: int = 16 122 | lora_dropout: float = 0.05 123 | lora_weight_path: str = "" 124 | lora_bias: str = "none" 125 | mm_projector_lr: Optional[float] = None 126 | group_by_modality_length: bool = field(default=False) 127 | float_head_lr: Optional[float] = field(default=None) 128 | 129 | 130 | def maybe_zero_3(param, ignore_status=False, name=None): 131 | from deepspeed import zero 132 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 133 | if hasattr(param, "ds_id"): 134 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 135 | if not ignore_status: 136 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 137 | with zero.GatheredParameters([param]): 138 | param = param.data.detach().cpu().clone() 139 | else: 140 | param = param.detach().cpu().clone() 141 | return param 142 | 143 | 144 | # Borrowed from peft.utils.get_peft_model_state_dict 145 | def get_peft_state_maybe_zero_3(named_params, bias): 146 | if bias == "none": 147 | to_return = {k: t for k, t in named_params if "lora_" in k} 148 | elif bias == "all": 149 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 150 | elif bias == "lora_only": 151 | to_return = {} 152 | maybe_lora_bias = {} 153 | lora_bias_names = set() 154 | for k, t in named_params: 155 | if "lora_" in k: 156 | to_return[k] = t 157 | bias_name = k.split("lora_")[0] + "bias" 158 | lora_bias_names.add(bias_name) 159 | elif "bias" in k: 160 | maybe_lora_bias[k] = t 161 | for k, t in maybe_lora_bias: 162 | if bias_name in lora_bias_names: 163 | to_return[bias_name] = t 164 | else: 165 | raise NotImplementedError 166 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} 167 | return to_return 168 | 169 | 170 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): 171 | 
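# Collects the trainable parameters that are not LoRA weights (for example the
# mm_projector and the float head); when LoRA is enabled, train() saves this
# state separately as non_lora_trainables.bin.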
to_return = {k: t for k, t in named_params if "lora_" not in k} 172 | if require_grad_only: 173 | to_return = {k: t for k, t in to_return.items() if t.requires_grad} 174 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 175 | return to_return 176 | 177 | 178 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 179 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 180 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 181 | return to_return 182 | 183 | 184 | def find_all_linear_names(model): 185 | cls = torch.nn.Linear 186 | lora_module_names = set() 187 | multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] 188 | for name, module in model.named_modules(): 189 | if any(mm_keyword in name for mm_keyword in multimodal_keywords): 190 | continue 191 | if isinstance(module, cls) and 'float_head' not in name: 192 | names = name.split('.') 193 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 194 | 195 | if 'lm_head' in lora_module_names: # needed for 16-bit 196 | lora_module_names.remove('lm_head') 197 | return list(lora_module_names) 198 | 199 | 200 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, 201 | output_dir: str): 202 | """Collects the state dict and dump to disk.""" 203 | 204 | if getattr(trainer.args, "tune_mm_mlp_adapter", False): 205 | # Only save Adapter 206 | keys_to_match = ['mm_projector'] 207 | if getattr(trainer.args, "use_im_start_end", False): 208 | keys_to_match.extend(['embed_tokens', 'embed_in']) 209 | 210 | weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) 211 | trainer.model.config.save_pretrained(output_dir) 212 | 213 | current_folder = output_dir.split('/')[-1] 214 | parent_folder = os.path.dirname(output_dir) 215 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: 216 | if current_folder.startswith('checkpoint-'): 217 | mm_projector_folder = os.path.join(parent_folder, "mm_projector") 218 | os.makedirs(mm_projector_folder, exist_ok=True) 219 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) 220 | else: 221 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 222 | return 223 | 224 | if trainer.deepspeed: 225 | torch.cuda.synchronize() 226 | trainer.save_model(output_dir) 227 | return 228 | 229 | state_dict = trainer.model.state_dict() 230 | if trainer.args.should_save: 231 | cpu_state_dict = { 232 | key: value.cpu() 233 | for key, value in state_dict.items() 234 | } 235 | del state_dict 236 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 237 | 238 | 239 | def smart_tokenizer_and_embedding_resize( 240 | special_tokens_dict: Dict, 241 | tokenizer: transformers.PreTrainedTokenizer, 242 | model: transformers.PreTrainedModel, 243 | ): 244 | """Resize tokenizer and embedding. 245 | 246 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
247 | """ 248 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 249 | model.resize_token_embeddings(len(tokenizer)) 250 | 251 | if num_new_tokens > 0: 252 | input_embeddings = model.get_input_embeddings().weight.data 253 | output_embeddings = model.get_output_embeddings().weight.data 254 | 255 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 256 | dim=0, keepdim=True) 257 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 258 | dim=0, keepdim=True) 259 | 260 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 261 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 262 | 263 | 264 | def _tokenize_fn(strings: Sequence[str], 265 | tokenizer: transformers.PreTrainedTokenizer) -> Dict: 266 | """Tokenize a list of strings.""" 267 | tokenized_list = [ 268 | tokenizer( 269 | text, 270 | return_tensors="pt", 271 | padding="longest", 272 | max_length=tokenizer.model_max_length, 273 | truncation=True, 274 | ) for text in strings 275 | ] 276 | input_ids = labels = [ 277 | tokenized.input_ids[0] for tokenized in tokenized_list 278 | ] 279 | input_ids_lens = labels_lens = [ 280 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 281 | for tokenized in tokenized_list 282 | ] 283 | return dict( 284 | input_ids=input_ids, 285 | labels=labels, 286 | input_ids_lens=input_ids_lens, 287 | labels_lens=labels_lens, 288 | ) 289 | 290 | 291 | def _mask_targets(target, tokenized_lens, speakers): 292 | # cur_idx = 0 293 | cur_idx = tokenized_lens[0] 294 | tokenized_lens = tokenized_lens[1:] 295 | target[:cur_idx] = IGNORE_INDEX 296 | for tokenized_len, speaker in zip(tokenized_lens, speakers): 297 | if speaker == "human": 298 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX 299 | cur_idx += tokenized_len 300 | 301 | 302 | def _add_speaker_and_signal(header, source, get_conversation=True): 303 | """Add speaker and start/end signal on each round.""" 304 | BEGIN_SIGNAL = "### " 305 | END_SIGNAL = "\n" 306 | conversation = header 307 | for sentence in source: 308 | from_str = sentence["from"] 309 | if from_str.lower() == "human": 310 | from_str = conversation_lib.default_conversation.roles[0] 311 | elif from_str.lower() == "gpt": 312 | from_str = conversation_lib.default_conversation.roles[1] 313 | else: 314 | from_str = 'unknown' 315 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + 316 | sentence["value"] + END_SIGNAL) 317 | if get_conversation: 318 | conversation += sentence["value"] 319 | conversation += BEGIN_SIGNAL 320 | return conversation 321 | 322 | 323 | def preprocess_multimodal( 324 | sources: Sequence[str], 325 | data_args: DataArguments 326 | ) -> Dict: 327 | is_multimodal = data_args.is_multimodal 328 | if not is_multimodal: 329 | return sources 330 | 331 | for source in sources: 332 | for sentence in source: 333 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 334 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 335 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 336 | sentence['value'] = sentence['value'].strip() 337 | if "mmtag" in conversation_lib.default_conversation.version: 338 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') 339 | replace_token = DEFAULT_IMAGE_TOKEN 340 | if data_args.mm_use_im_start_end: 341 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 342 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 343 | 344 | return sources 
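# Illustrative effect of preprocess_multimodal (tokens shown symbolically): a
# human turn ending in DEFAULT_IMAGE_TOKEN is rewritten so the image token leads
# the turn on its own line, i.e.
#   "<question text>\n" + DEFAULT_IMAGE_TOKEN  ->  DEFAULT_IMAGE_TOKEN + "\n" + "<question text>"
# and, when mm_use_im_start_end is set, DEFAULT_IMAGE_TOKEN is additionally
# wrapped in DEFAULT_IM_START_TOKEN ... DEFAULT_IM_END_TOKEN.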
345 | 346 | 347 | def preprocess_llama_2( 348 | sources, 349 | tokenizer: transformers.PreTrainedTokenizer, 350 | has_image: bool = False 351 | ) -> Dict: 352 | conv = conversation_lib.default_conversation.copy() 353 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 354 | 355 | # Apply prompt templates 356 | conversations = [] 357 | for i, source in enumerate(sources): 358 | if roles[source[0]["from"]] != conv.roles[0]: 359 | # Skip the first one if it is not from human 360 | source = source[1:] 361 | 362 | conv.messages = [] 363 | for j, sentence in enumerate(source): 364 | role = roles[sentence["from"]] 365 | assert role == conv.roles[j % 2], f"{i}" 366 | conv.append_message(role, sentence["value"]) 367 | conversations.append(conv.get_prompt()) 368 | 369 | # Tokenize conversations 370 | 371 | if has_image: 372 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 373 | else: 374 | input_ids = tokenizer( 375 | conversations, 376 | return_tensors="pt", 377 | padding="longest", 378 | max_length=tokenizer.model_max_length, 379 | truncation=True, 380 | ).input_ids 381 | 382 | targets = input_ids.clone() 383 | 384 | assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 385 | 386 | # Mask targets 387 | sep = "[/INST] " 388 | for conversation, target in zip(conversations, targets): 389 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 390 | 391 | rounds = conversation.split(conv.sep2) 392 | cur_len = 1 393 | target[:cur_len] = IGNORE_INDEX 394 | for i, rou in enumerate(rounds): 395 | if rou == "": 396 | break 397 | 398 | parts = rou.split(sep) 399 | if len(parts) != 2: 400 | break 401 | parts[0] += sep 402 | 403 | if has_image: 404 | round_len = len(tokenizer_image_token(rou, tokenizer)) 405 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 406 | else: 407 | round_len = len(tokenizer(rou).input_ids) 408 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 409 | 410 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 411 | 412 | cur_len += round_len 413 | target[cur_len:] = IGNORE_INDEX 414 | 415 | if cur_len < tokenizer.model_max_length: 416 | if cur_len != total_len: 417 | target[:] = IGNORE_INDEX 418 | print( 419 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
420 | f" (ignored)" 421 | ) 422 | 423 | return dict( 424 | input_ids=input_ids, 425 | labels=targets, 426 | ) 427 | 428 | 429 | def preprocess_v1( 430 | sources, 431 | tokenizer: transformers.PreTrainedTokenizer, 432 | has_image: bool = False 433 | ) -> Dict: 434 | conv = conversation_lib.default_conversation.copy() 435 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 436 | 437 | # Apply prompt templates 438 | conversations = [] 439 | for i, source in enumerate(sources): 440 | if roles[source[0]["from"]] != conv.roles[0]: 441 | # Skip the first one if it is not from human 442 | source = source[1:] 443 | 444 | conv.messages = [] 445 | for j, sentence in enumerate(source): 446 | role = roles[sentence["from"]] 447 | assert role == conv.roles[j % 2], f"{i}" 448 | conv.append_message(role, sentence["value"]) 449 | conversations.append(conv.get_prompt()) 450 | 451 | # Tokenize conversations 452 | 453 | if has_image: 454 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 455 | else: 456 | input_ids = tokenizer( 457 | conversations, 458 | return_tensors="pt", 459 | padding="longest", 460 | max_length=tokenizer.model_max_length, 461 | truncation=True, 462 | ).input_ids 463 | 464 | targets = input_ids.clone() 465 | 466 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 467 | 468 | # Mask targets 469 | sep = conv.sep + conv.roles[1] + ": " 470 | for conversation, target in zip(conversations, targets): 471 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 472 | 473 | rounds = conversation.split(conv.sep2) 474 | cur_len = 1 475 | target[:cur_len] = IGNORE_INDEX 476 | for i, rou in enumerate(rounds): 477 | if rou == "": 478 | break 479 | 480 | parts = rou.split(sep) 481 | if len(parts) != 2: 482 | break 483 | parts[0] += sep 484 | 485 | if has_image: 486 | round_len = len(tokenizer_image_token(rou, tokenizer)) 487 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 488 | else: 489 | round_len = len(tokenizer(rou).input_ids) 490 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 491 | 492 | if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: 493 | round_len -= 1 494 | instruction_len -= 1 495 | 496 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 497 | 498 | cur_len += round_len 499 | target[cur_len:] = IGNORE_INDEX 500 | 501 | if cur_len < tokenizer.model_max_length: 502 | if cur_len != total_len: 503 | target[:] = IGNORE_INDEX 504 | print( 505 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
506 | f" (ignored)" 507 | ) 508 | 509 | return dict( 510 | input_ids=input_ids, 511 | labels=targets, 512 | ) 513 | 514 | 515 | def preprocess_mpt( 516 | sources, 517 | tokenizer: transformers.PreTrainedTokenizer, 518 | has_image: bool = False 519 | ) -> Dict: 520 | conv = conversation_lib.default_conversation.copy() 521 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 522 | 523 | # Apply prompt templates 524 | conversations = [] 525 | for i, source in enumerate(sources): 526 | if roles[source[0]["from"]] != conv.roles[0]: 527 | # Skip the first one if it is not from human 528 | source = source[1:] 529 | 530 | conv.messages = [] 531 | for j, sentence in enumerate(source): 532 | role = roles[sentence["from"]] 533 | assert role == conv.roles[j % 2], f"{i}" 534 | conv.append_message(role, sentence["value"]) 535 | conversations.append(conv.get_prompt()) 536 | 537 | # Tokenize conversations 538 | 539 | if has_image: 540 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 541 | else: 542 | input_ids = tokenizer( 543 | conversations, 544 | return_tensors="pt", 545 | padding="longest", 546 | max_length=tokenizer.model_max_length, 547 | truncation=True, 548 | ).input_ids 549 | 550 | targets = input_ids.clone() 551 | assert conv.sep_style == conversation_lib.SeparatorStyle.MPT 552 | 553 | # Mask targets 554 | sep = conv.sep + conv.roles[1] 555 | for conversation, target in zip(conversations, targets): 556 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 557 | 558 | rounds = conversation.split(conv.sep) 559 | re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt 560 | for conv_idx in range(3, len(rounds), 2): 561 | re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt 562 | cur_len = 0 563 | target[:cur_len] = IGNORE_INDEX 564 | for i, rou in enumerate(re_rounds): 565 | if rou == "": 566 | break 567 | 568 | parts = rou.split(sep) 569 | if len(parts) != 2: 570 | break 571 | parts[0] += sep 572 | 573 | if has_image: 574 | round_len = len(tokenizer_image_token(rou, tokenizer)) 575 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 576 | else: 577 | round_len = len(tokenizer(rou).input_ids) 578 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1 579 | 580 | if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14: 581 | round_len += 1 582 | instruction_len += 1 583 | 584 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 585 | 586 | cur_len += round_len 587 | target[cur_len:] = IGNORE_INDEX 588 | 589 | if cur_len < tokenizer.model_max_length: 590 | if cur_len != total_len: 591 | target[:] = IGNORE_INDEX 592 | print( 593 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
594 | f" (ignored)" 595 | ) 596 | 597 | return dict( 598 | input_ids=input_ids, 599 | labels=targets, 600 | ) 601 | 602 | 603 | def preprocess_plain( 604 | sources: Sequence[str], 605 | tokenizer: transformers.PreTrainedTokenizer, 606 | ) -> Dict: 607 | # add end signal and concatenate together 608 | conversations = [] 609 | for source in sources: 610 | assert len(source) == 2 611 | assert DEFAULT_IMAGE_TOKEN in source[0]['value'] 612 | source[0]['value'] = DEFAULT_IMAGE_TOKEN 613 | conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep 614 | conversations.append(conversation) 615 | # tokenize conversations 616 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 617 | targets = copy.deepcopy(input_ids) 618 | for target, source in zip(targets, sources): 619 | tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) 620 | target[:tokenized_len] = IGNORE_INDEX 621 | 622 | return dict(input_ids=input_ids, labels=targets) 623 | 624 | 625 | def preprocess( 626 | sources: Sequence[str], 627 | tokenizer: transformers.PreTrainedTokenizer, 628 | has_image: bool = False 629 | ) -> Dict: 630 | """ 631 | Given a list of sources, each is a conversation list. This transform: 632 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 633 | 2. Concatenate conversations together; 634 | 3. Tokenize the concatenated conversation; 635 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 636 | """ 637 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: 638 | return preprocess_plain(sources, tokenizer) 639 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: 640 | return preprocess_llama_2(sources, tokenizer, has_image=has_image) 641 | if conversation_lib.default_conversation.version.startswith("v1"): 642 | return preprocess_v1(sources, tokenizer, has_image=has_image) 643 | if conversation_lib.default_conversation.version == "mpt": 644 | return preprocess_mpt(sources, tokenizer, has_image=has_image) 645 | # add end signal and concatenate together 646 | conversations = [] 647 | for source in sources: 648 | header = f"{conversation_lib.default_conversation.system}\n\n" 649 | conversation = _add_speaker_and_signal(header, source) 650 | conversations.append(conversation) 651 | # tokenize conversations 652 | def get_tokenize_len(prompts): 653 | return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] 654 | 655 | if has_image: 656 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 657 | else: 658 | conversations_tokenized = _tokenize_fn(conversations, tokenizer) 659 | input_ids = conversations_tokenized["input_ids"] 660 | 661 | targets = copy.deepcopy(input_ids) 662 | for target, source in zip(targets, sources): 663 | if has_image: 664 | tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) 665 | else: 666 | tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] 667 | speakers = [sentence["from"] for sentence in source] 668 | _mask_targets(target, tokenized_lens, speakers) 669 | 670 | return dict(input_ids=input_ids, labels=targets) 671 | 672 | 673 | def get_synonym_clevr(n): # From Johnson et al. 
(2017) 674 | return random.choice( 675 | { 676 | "sphere": ["sphere", "ball"], 677 | "cube": ["cube", "block"], 678 | "large": ["large", "big"], 679 | "small": ["small", "tiny"], 680 | "metal": ["metallic", "metal", "shiny"], 681 | "rubber": ["rubber", "matte"], 682 | }.get(n, [n]) 683 | ) 684 | 685 | 686 | def draw_dot(x, y, width=336, height=336, r=5) -> Image.Image: 687 | x *= width 688 | y *= height 689 | image = Image.new("RGB", (width, height), color=(255, 255, 255)) 690 | ImageDraw.Draw(image).ellipse([(x - r, y - r), (x + r, y + r)], fill="red") 691 | return image 692 | 693 | 694 | class LazySupervisedDataset(Dataset): 695 | """Dataset for supervised fine-tuning.""" 696 | 697 | def __init__(self, 698 | tokenizer: transformers.PreTrainedTokenizer, 699 | data_args: DataArguments, 700 | train: bool = True, 701 | ): 702 | super(LazySupervisedDataset, self).__init__() 703 | 704 | if train: 705 | data_path = data_args.data_path 706 | self.image_folder = data_args.image_folder 707 | else: 708 | data_path = data_args.data_path_val 709 | self.image_folder = data_args.image_folder_val 710 | 711 | if data_args.is_2d: 712 | list_data_dict = np.load('data/2d.npz')[data_path] 713 | else: 714 | list_data_dict = json.loads(Path(data_path).read_bytes()) 715 | if 'scenes' in list_data_dict: 716 | list_data_dict = list_data_dict['scenes'] 717 | 718 | if data_args.num_samples: 719 | list_data_dict = list_data_dict[:data_args.num_samples] 720 | 721 | rank0_print("Formatting inputs...Skip in lazy mode") 722 | self.tokenizer = tokenizer 723 | self.list_data_dict = list_data_dict 724 | self.data_args = data_args 725 | 726 | if data_args.use_synonyms: 727 | self.synonym_f = get_synonym_clevr 728 | else: 729 | self.synonym_f = lambda x: x 730 | 731 | def __len__(self): 732 | return len(self.list_data_dict) 733 | 734 | @property 735 | def lengths(self): 736 | length_list = [] 737 | for sample in self.list_data_dict: 738 | img_tokens = 128 if 'image' in sample else 0 739 | length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) 740 | return length_list 741 | 742 | @property 743 | def modality_lengths(self): 744 | length_list = [] 745 | for sample in self.list_data_dict: 746 | cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) 747 | cur_len = cur_len if 'image' in sample else -cur_len 748 | length_list.append(cur_len) 749 | return length_list 750 | 751 | def prep_code(self, entry): 752 | code = "" 753 | floats = [] 754 | 755 | for obj in sorted(entry["objects"], key=lambda x: x["pixel_coords"][0]): 756 | attrs = [] 757 | if "size" in obj: 758 | attrs.append(("size='{}'".format(self.synonym_f(obj["size"])), [])) 759 | if "color" in obj: 760 | attrs.append(("color='{}'".format(self.synonym_f(obj["color"])), [])) 761 | if "material" in obj: 762 | attrs.append( 763 | ("material='{}'".format(self.synonym_f(obj["material"])), []) 764 | ) 765 | if "shape" in obj: 766 | attrs.append(("shape='{}'".format(self.synonym_f(obj["shape"])), [])) 767 | if "3d_coords" in obj: 768 | attrs.append( 769 | ( 770 | f"loc=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN})", 771 | obj["3d_coords"][:], 772 | ) 773 | ) 774 | if "rotation" in obj and obj.get("shape") not in {"cylinder", "sphere"}: 775 | rotation = obj["rotation"] 776 | if obj.get("shape") == "cube": # Convert degrees to radians and modulo 777 | rotation = 2 * math.pi * (rotation % 90 - 45) / 360 778 | if self.data_args.rotation_rep != None: 779 | if self.data_args.rotation_rep == "6d": 780 | attrs.append( 
781 | ( 782 | f"rotation=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN})", 783 | R.from_euler("xyz", rotation).as_matrix()[:2].flatten().tolist() 784 | ) 785 | ) 786 | elif self.data_args.rotation_rep == "euler": 787 | attrs.append( 788 | ( 789 | f"rotation=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN})", 790 | rotation, 791 | ) 792 | ) 793 | elif self.data_args.rotation_rep == "euler_int": 794 | attrs.append( 795 | ( 796 | f"rotation=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN})", 797 | R.from_euler("xyz", rotation).as_euler("XYZ").tolist(), 798 | ) 799 | ) 800 | elif self.data_args.rotation_rep == "aa": 801 | attrs.append( 802 | ( 803 | f"rotation=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN})", 804 | R.from_euler("xyz", rotation).as_rotvec().tolist(), 805 | ) 806 | ) 807 | elif self.data_args.rotation_rep == "quat": 808 | attrs.append( 809 | ( 810 | f"rotation=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN})", 811 | R.from_euler("xyz", rotation).as_quat().tolist(), 812 | ) 813 | ) 814 | else: 815 | raise NotImplementedError( 816 | f"Unknown rotation type {self.data_args.rotation_rep}" 817 | ) 818 | else: 819 | attrs.append((f"rotation={FLOAT_TOKEN}", [rotation])) 820 | if self.data_args.shuffle_attributes: 821 | random.shuffle(attrs) 822 | floats += list(itertools.chain(*[n[1] for n in attrs])) 823 | code += "add(" + ", ".join([n[0] for n in attrs]) + ")\n" 824 | 825 | if self.data_args.is_float: 826 | return code, floats 827 | 828 | code = code.replace(FLOAT_TOKEN, "{:+.3f}").format(*floats) 829 | 830 | return code, [] 831 | 832 | def __getitem__(self, i, return_raw=False) -> Dict[str, torch.Tensor]: 833 | sources = self.list_data_dict[i] 834 | has_image = True 835 | if self.data_args.is_2d: 836 | x, y = sources 837 | if self.data_args.is_float: 838 | if not self.data_args.shuffle_attributes or random.random() < 0.5: 839 | floats = [x, y] 840 | code = f'add(x={FLOAT_TOKEN}, y={FLOAT_TOKEN})\n' 841 | else: 842 | floats = [y, x] 843 | code = f'add(y={FLOAT_TOKEN}, x={FLOAT_TOKEN})\n' 844 | else: 845 | floats = [] 846 | if not self.data_args.shuffle_attributes or random.random() < 0.5: 847 | code = f'add(x={x:.3f}, y={y:.3f})\n' 848 | else: 849 | code = f'add(y={y:.3f}, x={x:.3f})\n' 850 | image_name = f'{i}.png' 851 | else: 852 | code, floats = self.prep_code(sources) 853 | image_name = sources['image_filename'] 854 | sources = { 855 | "id": i, 856 | "image": image_name, 857 | "conversations": [ 858 | { 859 | "from": "human", 860 | "value": "\nWhat Python Blender code could be used to produce the scene?", 861 | }, 862 | { 863 | "from": "gpt", 864 | "value": f"The scene can be produced with:\n```python\n{code}```", 865 | "floats": floats 866 | } 867 | ] 868 | } 869 | conversations = sources['conversations'] 870 | if isinstance(i, int): 871 | sources = [sources] 872 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 873 | if has_image: 874 | image_file = sources[0]['image'] 875 | image_folder = self.image_folder 876 | processor = self.data_args.image_processor 877 | if self.data_args.is_2d: 878 | image = draw_dot(x, y).convert('RGB') 879 | else: 880 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 881 | if self.data_args.image_aspect_ratio == 'pad': 882 | def expand2square(pil_img, background_color): 883 | width, height = pil_img.size 884 | if width == height: 885 | return pil_img 886 | elif width > height: 887 | result = Image.new(pil_img.mode, (width, width), background_color) 888 | 
result.paste(pil_img, (0, (width - height) // 2)) 889 | return result 890 | else: 891 | result = Image.new(pil_img.mode, (height, height), background_color) 892 | result.paste(pil_img, ((height - width) // 2, 0)) 893 | return result 894 | image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) 895 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 896 | else: 897 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 898 | sources = preprocess_multimodal( 899 | copy.deepcopy([e["conversations"] for e in sources]), 900 | self.data_args) 901 | else: 902 | sources = copy.deepcopy([e["conversations"] for e in sources]) 903 | data_dict = preprocess( 904 | sources, 905 | self.tokenizer, 906 | has_image=has_image) 907 | if isinstance(i, int): 908 | data_dict = dict(input_ids=data_dict["input_ids"][0], 909 | labels=data_dict["labels"][0]) 910 | 911 | # image exist in the data 912 | if has_image: 913 | data_dict['image'] = image 914 | elif self.data_args.is_multimodal: 915 | # image does not exist in the data, but the model is multimodal 916 | crop_size = self.data_args.image_processor.crop_size 917 | data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 918 | 919 | if 'floats' in sources[0][-1]: 920 | data_dict['floats'] = sources[0][-1]['floats'] 921 | 922 | if return_raw: 923 | data_dict['conversations'] = conversations 924 | 925 | return data_dict 926 | 927 | 928 | @dataclass 929 | class DataCollatorForSupervisedDataset(object): 930 | """Collate examples for supervised fine-tuning.""" 931 | 932 | tokenizer: transformers.PreTrainedTokenizer 933 | 934 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 935 | input_ids, labels = tuple([instance[key] for instance in instances] 936 | for key in ("input_ids", "labels")) 937 | input_ids = torch.nn.utils.rnn.pad_sequence( 938 | input_ids, 939 | batch_first=True, 940 | padding_value=self.tokenizer.pad_token_id) 941 | labels = torch.nn.utils.rnn.pad_sequence(labels, 942 | batch_first=True, 943 | padding_value=IGNORE_INDEX) 944 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 945 | labels = labels[:, :self.tokenizer.model_max_length] 946 | batch = dict( 947 | input_ids=input_ids, 948 | labels=labels, 949 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 950 | ) 951 | 952 | if 'image' in instances[0]: 953 | images = [instance['image'] for instance in instances] 954 | if all(x is not None and x.shape == images[0].shape for x in images): 955 | batch['images'] = torch.stack(images) 956 | else: 957 | batch['images'] = images 958 | 959 | if "floats" in instances[0]: 960 | batch["floats"] = torch.tensor([n for instance in instances for n in instance["floats"]]) 961 | 962 | return batch 963 | 964 | 965 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, 966 | data_args) -> Dict: 967 | """Make dataset and collator for supervised fine-tuning.""" 968 | train_dataset = LazySupervisedDataset(tokenizer=tokenizer, 969 | data_args=data_args, 970 | train=True) 971 | if data_args.data_path_val: 972 | eval_dataset = LazySupervisedDataset(tokenizer=tokenizer, 973 | data_args=data_args, 974 | train=False) 975 | else: 976 | eval_dataset = None 977 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 978 | return dict(train_dataset=train_dataset, 979 | eval_dataset=eval_dataset, 980 | data_collator=data_collator) 981 | 982 | 983 | def train(attn_implementation=None): 984 | global local_rank 985 
| 986 | parser = transformers.HfArgumentParser( 987 | (ModelArguments, DataArguments, TrainingArguments)) 988 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 989 | local_rank = training_args.local_rank 990 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 991 | data_args.is_float = model_args.float_head_type.lower() != "none" 992 | 993 | bnb_model_from_pretrained_args = { 994 | 'device_map': {"": training_args.device}, 995 | 'low_cpu_mem_usage': True, 996 | 'offload_state_dict': True, 997 | } 998 | if training_args.bits in [4, 8]: 999 | from transformers import BitsAndBytesConfig 1000 | bnb_model_from_pretrained_args.update(dict( 1001 | device_map={"": training_args.device}, 1002 | load_in_4bit=training_args.bits == 4, 1003 | load_in_8bit=training_args.bits == 8, 1004 | quantization_config=BitsAndBytesConfig( 1005 | load_in_4bit=training_args.bits == 4, 1006 | load_in_8bit=training_args.bits == 8, 1007 | llm_int8_skip_modules=["mm_projector"], 1008 | llm_int8_threshold=6.0, 1009 | llm_int8_has_fp16_weight=False, 1010 | bnb_4bit_compute_dtype=compute_dtype, 1011 | bnb_4bit_use_double_quant=training_args.double_quant, 1012 | bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} 1013 | ) 1014 | )) 1015 | 1016 | if model_args.vision_tower is not None: 1017 | if 'mpt' in model_args.model_name_or_path: 1018 | config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 1019 | config.attn_config['attn_impl'] = training_args.mpt_attn_impl 1020 | model = LlavaMptForCausalLM.from_pretrained( 1021 | model_args.model_name_or_path, 1022 | config=config, 1023 | cache_dir=training_args.cache_dir, 1024 | **bnb_model_from_pretrained_args 1025 | ) 1026 | else: 1027 | config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 1028 | config.float_head_type = model_args.float_head_type 1029 | config.float_w = model_args.float_w 1030 | 1031 | model = LlavaLlamaForCausalLM.from_pretrained( 1032 | model_args.model_name_or_path, 1033 | config=config, 1034 | cache_dir=training_args.cache_dir, 1035 | # attn_implementation=attn_implementation, 1036 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 1037 | **bnb_model_from_pretrained_args 1038 | ) 1039 | else: 1040 | model = transformers.LlamaForCausalLM.from_pretrained( 1041 | model_args.model_name_or_path, 1042 | cache_dir=training_args.cache_dir, 1043 | attn_implementation=attn_implementation, 1044 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 1045 | **bnb_model_from_pretrained_args 1046 | ) 1047 | model.config.use_cache = False 1048 | 1049 | if model_args.freeze_backbone: 1050 | model.model.requires_grad_(False) 1051 | 1052 | if training_args.bits in [4, 8]: 1053 | from peft import prepare_model_for_kbit_training 1054 | model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 1055 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) 1056 | 1057 | if training_args.gradient_checkpointing: 1058 | if hasattr(model, "enable_input_require_grads"): 1059 | model.enable_input_require_grads() 1060 | else: 1061 | def make_inputs_require_grad(module, input, output): 1062 | output.requires_grad_(True) 1063 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 1064 | 1065 | if training_args.lora_enable: 
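# LoRA adapts only the backbone Linear layers: find_all_linear_names() (defined
# above) skips mm_projector / vision_tower / vision_resampler modules, any
# 'float_head' Linear, and drops lm_head, so those modules are left out of the
# LoRA adapters.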
1066 | target_modules = find_all_linear_names(model) 1067 | if local_rank in {-1, 0}: 1068 | print(f'{target_modules=}') 1069 | from peft import LoraConfig, get_peft_model 1070 | lora_config = LoraConfig( 1071 | r=training_args.lora_r, 1072 | lora_alpha=training_args.lora_alpha, 1073 | target_modules=find_all_linear_names(model), 1074 | lora_dropout=training_args.lora_dropout, 1075 | bias=training_args.lora_bias, 1076 | task_type="CAUSAL_LM", 1077 | ) 1078 | if training_args.bits == 16: 1079 | if training_args.bf16: 1080 | model.to(torch.bfloat16) 1081 | if training_args.fp16: 1082 | model.to(torch.float16) 1083 | rank0_print("Adding LoRA adapters...") 1084 | model = get_peft_model(model, lora_config) 1085 | 1086 | if 'mpt' in model_args.model_name_or_path: 1087 | tokenizer = transformers.AutoTokenizer.from_pretrained( 1088 | model_args.model_name_or_path, 1089 | cache_dir=training_args.cache_dir, 1090 | model_max_length=training_args.model_max_length, 1091 | padding_side="right" 1092 | ) 1093 | else: 1094 | tokenizer = transformers.AutoTokenizer.from_pretrained( 1095 | model_args.model_name_or_path, 1096 | cache_dir=training_args.cache_dir, 1097 | model_max_length=training_args.model_max_length, 1098 | padding_side="right", 1099 | use_fast=False, 1100 | ) 1101 | 1102 | if model_args.version == "v0": 1103 | if tokenizer.pad_token is None: 1104 | smart_tokenizer_and_embedding_resize( 1105 | special_tokens_dict=dict(pad_token="[PAD]"), 1106 | tokenizer=tokenizer, 1107 | model=model, 1108 | ) 1109 | elif model_args.version == "v0.5": 1110 | tokenizer.pad_token = tokenizer.unk_token 1111 | else: 1112 | tokenizer.pad_token = tokenizer.unk_token 1113 | if model_args.version in conversation_lib.conv_templates: 1114 | conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] 1115 | else: 1116 | conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] 1117 | 1118 | if model_args.vision_tower is not None: 1119 | model.get_model().initialize_vision_modules( 1120 | model_args=model_args, 1121 | fsdp=training_args.fsdp 1122 | ) 1123 | 1124 | vision_tower = model.get_vision_tower() 1125 | vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) 1126 | 1127 | data_args.image_processor = vision_tower.image_processor 1128 | data_args.is_multimodal = True 1129 | 1130 | model.config.image_aspect_ratio = data_args.image_aspect_ratio 1131 | model.config.tokenizer_padding_side = tokenizer.padding_side 1132 | model.config.tokenizer_model_max_length = tokenizer.model_max_length 1133 | 1134 | model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter 1135 | if model_args.tune_mm_mlp_adapter: 1136 | model.requires_grad_(False) 1137 | for p in model.get_model().mm_projector.parameters(): 1138 | p.requires_grad = True 1139 | 1140 | model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter 1141 | if training_args.freeze_mm_mlp_adapter: 1142 | for p in model.get_model().mm_projector.parameters(): 1143 | p.requires_grad = False 1144 | 1145 | if training_args.bits in [4, 8]: 1146 | model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) 1147 | 1148 | model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end 1149 | model.config.mm_projector_lr = training_args.mm_projector_lr 1150 | training_args.use_im_start_end = model_args.mm_use_im_start_end 1151 | model.config.mm_use_im_patch_token = 
model_args.mm_use_im_patch_token 1152 | model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) 1153 | 1154 | if training_args.bits in [4, 8]: 1155 | from peft.tuners.lora import LoraLayer 1156 | for name, module in model.named_modules(): 1157 | if isinstance(module, LoraLayer): 1158 | if training_args.bf16: 1159 | module = module.to(torch.bfloat16) 1160 | if 'norm' in name: 1161 | module = module.to(torch.float32) 1162 | if 'lm_head' in name or 'embed_tokens' in name: 1163 | if hasattr(module, 'weight'): 1164 | if training_args.bf16 and module.weight.dtype == torch.float32: 1165 | module = module.to(torch.bfloat16) 1166 | 1167 | data_module = make_supervised_data_module(tokenizer=tokenizer, 1168 | data_args=data_args) 1169 | 1170 | if model.float_head is not None: 1171 | for p in model.float_head.parameters(): 1172 | p.requires_grad = True 1173 | if local_rank in {-1, 0}: 1174 | print('train_dataset', data_module['train_dataset'].__getitem__(0, return_raw=True)) 1175 | if data_module['eval_dataset']: 1176 | print('eval_dataset', data_module['eval_dataset'].__getitem__(0, return_raw=True)) 1177 | print(f'{get_peft_state_non_lora_maybe_zero_3(model.named_parameters()).keys()=}') 1178 | 1179 | trainer = LLaVATrainer(model=model, 1180 | tokenizer=tokenizer, 1181 | args=training_args, 1182 | **data_module) 1183 | 1184 | if list(Path(training_args.output_dir).glob("checkpoint-*")): 1185 | trainer.train(resume_from_checkpoint=True) 1186 | else: 1187 | trainer.train() 1188 | trainer.save_state() 1189 | 1190 | model.config.use_cache = True 1191 | 1192 | if training_args.lora_enable: 1193 | state_dict = get_peft_state_maybe_zero_3( 1194 | model.named_parameters(), training_args.lora_bias 1195 | ) 1196 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( 1197 | model.named_parameters() 1198 | ) 1199 | if training_args.local_rank == 0 or training_args.local_rank == -1: 1200 | model.config.save_pretrained(training_args.output_dir) 1201 | model.save_pretrained(training_args.output_dir, state_dict=state_dict) 1202 | torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) 1203 | else: 1204 | safe_save_model_for_hf_trainer(trainer=trainer, 1205 | output_dir=training_args.output_dir) 1206 | 1207 | 1208 | if __name__ == "__main__": 1209 | train() 1210 | --------------------------------------------------------------------------------
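The dataset code above is the part most specific to this project, so the following is a minimal, self-contained sketch (not part of the repository) of how the FLOAT_TOKEN placeholder mechanism behaves. The helper name render_code and the example values are made up for illustration; the substitution mirrors LazySupervisedDataset.prep_code, and the returned float list corresponds to the "floats" entry that DataCollatorForSupervisedDataset packs into batch["floats"].

# Placeholder character for numeric values, as defined in llava/constants.py.
FLOAT_TOKEN = "\u6570"

def render_code(floats, use_float_head):
    """Illustrative only: build one add(...) call the way prep_code does for an
    object's location, either keeping FLOAT_TOKEN placeholders (float-head
    training, raw values returned separately) or formatting the numbers as text."""
    code = f"add(loc=({FLOAT_TOKEN}, {FLOAT_TOKEN}, {FLOAT_TOKEN}))\n"
    if use_float_head:
        # Placeholders stay in the prompt; the raw floats become separate targets.
        return code, list(floats)
    # Otherwise the numbers are rendered as ordinary text tokens.
    return code.replace(FLOAT_TOKEN, "{:+.3f}").format(*floats), []

if __name__ == "__main__":
    print(render_code([0.1, -1.25, 3.0], use_float_head=True))
    print(render_code([0.1, -1.25, 3.0], use_float_head=False))

During training the placeholder tokens are embedded like any other token, while the collected floats presumably serve as regression targets for the float head, weighted by float_w.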