├── Evaluation_Data_JSON ├── LOLv2_test_data.json ├── MultiGen_canny_data.json ├── MultiGen_hed_data.json ├── NYU_test_data.json ├── Rain100L_test_data.json ├── SOTS_test_data.json ├── UDC_test_data.json ├── UDC_val_data.json ├── inpainting_data.json └── outpainting_data.json ├── Explanatory_Instructions_Tuning ├── ckpts │ └── chameleon │ │ └── tokenizer │ │ ├── text_tokenizer.json │ │ └── vqgan.yaml ├── configs │ └── data │ │ ├── 448exp.yaml │ │ └── sample.yaml ├── data │ ├── __init__.py │ ├── convertsation.py │ └── item_processor.py ├── demo_inference.py ├── demo_input_imgs │ ├── hed_input_6.jpg │ └── rgb_00015.jpg ├── exps │ └── 7B.sh ├── finetune_solver.py ├── inference_solver.py ├── json │ ├── allweather_data.json │ └── edit_resolution_448 │ │ └── Allweather │ │ ├── files │ │ ├── 0_0.pkl │ │ ├── 0_1.pkl │ │ ├── 1_0.pkl │ │ ├── 1_1.pkl │ │ ├── 2_0.pkl │ │ └── 2_1.pkl │ │ └── record.json ├── model │ ├── __init__.py │ ├── chameleon │ │ ├── __init__.py │ │ ├── configuration_chameleon.py │ │ ├── convert_chameleon_weights_to_hf.py │ │ ├── image_processing_chameleon.py │ │ ├── modeling_chameleon.py │ │ └── processing_chameleon.py │ ├── chameleon_vae_ori │ │ ├── __init__.py │ │ ├── image_tokenizer.py │ │ ├── vocab.py │ │ └── vqgan.py │ ├── configuration_xllmx_chameleon.py │ └── modeling_xllmx_chameleon.py ├── pre_tokenize │ ├── Adobe_5k_pre_tokenize.py │ ├── allweather_pre_tokenize.py │ ├── concat_record.py │ └── seed_multi_turn_pre_tokenize.py └── xllmx │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── conversation │ │ ├── __init__.py │ │ └── template.py │ ├── data_reader.py │ ├── dataset.py │ ├── item_processor.py │ └── sampler.py │ ├── model │ ├── __init__.py │ ├── components.py │ └── tokenizer.py │ ├── solvers │ ├── __init__.py │ └── finetune │ │ ├── __init__.py │ │ └── finetune.py │ └── util │ ├── __init__.py │ ├── ckpt.py │ ├── dist.py │ ├── lr_sched.py │ ├── misc.py │ └── tensor_type.py ├── README.md ├── assets ├── 0016_0.8_0.08.jpg ├── 0016_gt.jpg ├── 0016_output.jpg ├── 0017_0.9_0.08.jpg ├── 0017_gt.jpg ├── 0017_output.jpg ├── 0018_0.9_0.2.jpg ├── 0018_gt.jpg ├── 0018_output.jpg ├── canny_edge_source_1.png ├── deblur_input.jpg ├── deblur_output_1.jpg ├── deblur_output_2.jpg ├── deblur_output_3.jpg ├── deblur_output_4.jpg ├── dehazing_input.jpg ├── dehazing_output_1.jpg ├── dehazing_output_2.jpg ├── dehazing_output_3.jpg ├── dehazing_output_4.jpg ├── depth_gt_15.png ├── depth_gt_18.png ├── depth_gt_21.png ├── depth_output_15.jpg ├── depth_output_18.jpg ├── depth_output_21.jpg ├── deraining_gt_5.jpg ├── deraining_gt_6.jpg ├── deraining_gt_9.jpg ├── deraining_input.jpg ├── deraining_output_1.jpg ├── deraining_output_2.jpg ├── deraining_output_3.jpg ├── deraining_output_4.jpg ├── deraining_output_5.jpg ├── deraining_output_6.jpg ├── deraining_output_9.jpg ├── desnow_input.jpg ├── desnow_output_1.jpg ├── desnow_output_2.jpg ├── desnow_output_3.jpg ├── desnow_output_4.jpg ├── hed_gt_10.png ├── hed_gt_14.png ├── hed_gt_6.png ├── hed_input_10.jpg ├── hed_input_14.jpg ├── hed_input_6.png ├── hed_output_10.png ├── hed_output_14.png ├── hed_output_6.png ├── norain-5x2.jpg ├── norain-6x2.jpg ├── norain-9x2.jpg ├── rgb_00015.jpg ├── rgb_00018.jpg ├── rgb_00021.jpg ├── seg_gt_1.jpg ├── seg_gt_2.jpg ├── seg_gt_3.jpg ├── seg_input_1.jpg ├── seg_input_2.jpg ├── seg_input_3.jpg ├── seg_output_1.jpg ├── seg_output_2.jpg ├── seg_output_3.jpg ├── surface_gt_15.jpg ├── surface_gt_18.jpg ├── surface_gt_21.jpg ├── surface_output_15.jpg ├── surface_output_18.jpg ├── surface_output_21.jpg ├── 
zero_shot_canny_edge_source_1_out_1.jpg ├── zero_shot_canny_edge_source_1_out_2.jpg ├── zero_shot_canny_edge_source_1_out_3.jpg ├── zero_shot_canny_edge_source_1_out_4.jpg ├── zs_low_light_1.jpg ├── zs_low_light_2.jpg ├── zs_low_light_3.jpg ├── zs_low_light_4.jpg └── zs_low_light_input.jpg ├── paper.pdf └── requirements.txt /Explanatory_Instructions_Tuning/ckpts/chameleon/tokenizer/vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 8192 7 | ddconfig: 8 | double_z: false 9 | z_channels: 256 10 | resolution: 512 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | # lossconfig: 24 | # target: taming.modules.losses.vqperceptual_vit_vqgan.VQLPIPSWithDiscriminator 25 | # params: 26 | # disc_start: 100001 27 | # perceptual_weight: 1.0 28 | # adversarial_weight: 0.5 29 | # disc_params: 30 | # size: 512 31 | ckpt_path: manifold://fair_onellm_checkpoints/tree/v2/tokenizer/vqgan_wm_0209.ckpt 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 4 36 | num_workers: 10 37 | image_size: 512 38 | filter_image_size: 512 39 | dataset: coco 40 | aesthetics_th: 0 41 | clipsim_th: 0 42 | --distributed-world-size: null 43 | '32': null 44 | --distributed-port: null 45 | '17338': null 46 | --save-dir: null 47 | /checkpoint/shellysheynin/shutterstock/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN: 48 | log_every-500: 49 | ngpu32: null 50 | --tensorboard-logdir: null 51 | /checkpoint/shellysheynin/tensorboard_logs/2023-03-30/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN: 52 | log_every-500: 53 | ngpu32: null 54 | '14561': null 55 | /checkpoint/shellysheynin/tensorboard_logs/2023-04-02/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN: 56 | log_every-500: 57 | ngpu32: null 58 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/configs/data/448exp.yaml: -------------------------------------------------------------------------------- 1 | META: 2 | - path: './json/edit_resolution_448/Adobe_5k/record.json' 3 | ratio: 2.0 4 | - path: './json/edit_resolution_448/Allweather/record.json' 5 | ratio: 2.0 6 | - path: './json/edit_resolution_448/SEED/Real_Editing/record.json' 7 | ratio: 2.0 8 | - path: './json/edit_resolution_448/SEED/Multi_Turn/record.json' 9 | ratio: 2.0 10 | 11 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/configs/data/sample.yaml: -------------------------------------------------------------------------------- 1 | META: 2 | - path: './json/edit_resolution_448/Allweather/record.json' 3 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/data/__init__.py -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/data/convertsation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class 
Conversation: 5 | sep_token = "" 6 | roles = ["Human", "Assistant"] 7 | 8 | def __init__(self, messages=None): 9 | self.messages = messages or [] 10 | 11 | def process(self): 12 | ret = "" 13 | pieces = [] 14 | for i, (role, message) in enumerate(self.messages): 15 | if message is not None: 16 | turn = message + self.sep_token 17 | ret += turn 18 | if role == self.roles[1]: 19 | pieces.append({"data": turn, "predict": True}) 20 | else: 21 | pieces.append({"data": turn, "predict": False}) 22 | else: 23 | # generation prompt 24 | assert i == len(self.messages) - 1 and role == self.roles[1], "only last assistant message can be None" 25 | 26 | result = { 27 | "conv": ret, # text involving the complete conversation 28 | "pieces": pieces, # list to help correctly mark the labels 29 | } 30 | return result 31 | 32 | def get_prompt(self): 33 | return self.process()["conv"] 34 | 35 | def append_message(self, role, message): 36 | self.messages.append([role, message]) 37 | 38 | def copy(self): 39 | return Conversation( 40 | messages=[[x, y] for x, y in self.messages], 41 | ) 42 | 43 | def load_qas(self, qas: List[List[str]]): 44 | """ 45 | convert the list of question-answer pairs to a string, which contains the conversation involving all 46 | the questions and answers. When the last answer is None, the returned string is the prompt which 47 | can be used by the model to generate the last answer. 48 | :param qas: [[question1, answer1], [question2, answer2], ..., [questionX, answerX]] 49 | note that the last answer, i.e. answerX, can be None 50 | :return: the prompt 51 | """ 52 | self.messages = [] 53 | for q, a in qas: 54 | self.append_message(self.roles[0], q) 55 | self.append_message(self.roles[1], a) 56 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/data/item_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import random 4 | from typing import Dict, List 5 | 6 | from PIL import Image 7 | import torch 8 | 9 | from data.convertsation import Conversation 10 | import model.chameleon_vae_ori as chameleon_vae_ori 11 | from xllmx.data.data_reader import read_general 12 | from xllmx.data.item_processor import MMConvItemProcessor 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def center_crop(pil_image, crop_size): 18 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]: 19 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 20 | 21 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1]) 22 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 23 | 24 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0]) 25 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1]) 26 | crop_right = crop_left + crop_size[0] 27 | crop_lower = crop_upper + crop_size[1] 28 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower)) 29 | 30 | 31 | def var_center_crop(pil_image, crop_size_list, random_top_k=1): 32 | w, h = pil_image.size 33 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list] 34 | crop_size = random.choice( 35 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k] 36 | )[1] 37 | return center_crop(pil_image, crop_size) 38 | 39 | 40 | def generate_crop_size_list(num_patches, patch_size, 
max_ratio=4.0): 41 | assert max_ratio >= 1.0 42 | crop_size_list = [] 43 | wp, hp = num_patches, 1 44 | while wp > 0: 45 | if max(wp, hp) / min(wp, hp) <= max_ratio: 46 | crop_size_list.append((wp * patch_size, hp * patch_size)) 47 | if (hp + 1) * wp <= num_patches: 48 | hp += 1 49 | else: 50 | wp -= 1 51 | return crop_size_list 52 | 53 | 54 | class FlexARItemProcessor(MMConvItemProcessor): 55 | image_start_token = "" # fixed tokens for start and end, so can hardcode 56 | image_end_token = "" 57 | full_sub_sep_token = "" 58 | sub_sub_sep_token = "" 59 | sub_skip_token = "" 60 | new_line_token = "" 61 | 62 | def __init__( 63 | self, 64 | tokenizer="Alpha-VLLM/Lumina-mGPT-7B-768", 65 | conv_template=Conversation, 66 | target_size=512, 67 | with_decoder=False, 68 | ): 69 | 70 | super().__init__( 71 | { 72 | "<|image|>": self.process_image, 73 | }, 74 | ["<|image|>"], 75 | tokenizer, 76 | conv_template, 77 | ) 78 | 79 | self.patch_size = 32 80 | self.crop_size_list = generate_crop_size_list((target_size // self.patch_size) ** 2, self.patch_size) 81 | logger.info("List of crop sizes:") 82 | for i in range(0, len(self.crop_size_list), 6): 83 | logger.info(" " + "".join([f"{f'{w} x {h}':14s}" for w, h in self.crop_size_list[i : i + 6]])) 84 | 85 | # todo 86 | # currently still use the original image tokenizer provided by Meta rather than transformers 87 | # because the transformers implementation does not contain the vae decoder 88 | self.chameleon_ori_vocab = chameleon_vae_ori.VocabInfo( 89 | json.load(open("./ckpts/chameleon/tokenizer/text_tokenizer.json", encoding="utf8"))["model"]["vocab"] 90 | ) 91 | self.chameleon_ori_translation = chameleon_vae_ori.VocabTranslation(self.chameleon_ori_vocab, device="cuda") 92 | self.chameleon_ori_image_tokenizer = chameleon_vae_ori.ImageTokenizer( 93 | cfg_path="./ckpts/chameleon/tokenizer/vqgan.yaml", 94 | ckpt_path="./ckpts/chameleon/tokenizer/vqgan.ckpt", 95 | device="cuda", 96 | ) 97 | 98 | @staticmethod 99 | def get_n_grids_token(n_grids): 100 | return f"" 101 | 102 | def token2id(self, token: str) -> int: 103 | return self.tokenizer.tokenizer.vocab[token] 104 | 105 | @torch.no_grad() 106 | def process_image(self, image) -> Dict: 107 | if isinstance(image, Image.Image): 108 | pass 109 | else: 110 | image = Image.open(read_general(image)) 111 | 112 | image = var_center_crop(image, crop_size_list=self.crop_size_list) 113 | 114 | w_grids, h_grids = image.size[0] // self.patch_size, image.size[1] // self.patch_size 115 | 116 | image_toks = self.chameleon_ori_translation.convert_img2bp2( 117 | self.chameleon_ori_image_tokenizer.img_tokens_from_pil(image) 118 | ).view(-1) 119 | 120 | full_image_toks = image_toks.reshape(image.size[1] // 16, image.size[0] // 16) 121 | new_line_id = self.token2id(self.new_line_token) 122 | 123 | full_image_toks = torch.cat( 124 | ( 125 | full_image_toks, 126 | torch.ones(image.size[1] // 16, 1, device=full_image_toks.device, dtype=full_image_toks.dtype) 127 | * new_line_id, 128 | ), 129 | dim=1, 130 | ).flatten() 131 | 132 | result_toks = [ 133 | self.token2id(self.image_start_token), 134 | self.token2id(self.get_n_grids_token(h_grids)), 135 | self.token2id(self.get_n_grids_token(w_grids)), 136 | *full_image_toks.tolist(), 137 | self.token2id(self.image_end_token), 138 | ] 139 | 140 | return {"input_ids": result_toks, "labels": result_toks} 141 | 142 | def process_item(self, item, training_mode=False, out_flatten=True): 143 | if not out_flatten: 144 | return super().process_item(item, training_mode=training_mode) 145 | 146 | 
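# Note on the flattening performed below: each media placeholder produced by
# process_image() (a dict with "input_ids"/"labels") is expanded into the
# token stream, and media that sit in the prompt part (original label <= 0)
# have their labels replaced with -100 so they do not contribute to the loss.
# process_image() itself serialises an image as:
#   <image_start> <h_grids token> <w_grids token>
#   <one latent row of VQ tokens> <new-line token> ... (repeated) ... <image_end>
# A minimal sanity-check sketch for the crop-size enumeration, assuming the
# 448-px target / 32-px patch setup used elsewhere in this repo:
#   sizes = generate_crop_size_list((448 // 32) ** 2, 32)
#   assert (448, 448) in sizes
#   assert all(max(w, h) / min(w, h) <= 4.0 for w, h in sizes)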
if training_mode: 147 | tokens, labels = super().process_item(item, training_mode=training_mode) 148 | input_tokens_item = [] 149 | modified_labels_item = [] 150 | for i, (token_or_media, ori_label) in enumerate(zip(tokens, labels)): 151 | if isinstance(token_or_media, int): 152 | token = token_or_media 153 | input_tokens_item.append(token) 154 | modified_labels_item.append(ori_label) 155 | else: 156 | input_tokens_item += token_or_media["input_ids"] 157 | if ori_label <= 0: # in the prompt part 158 | modified_labels_item += [-100] * len(token_or_media["input_ids"]) 159 | else: 160 | modified_labels_item += token_or_media["labels"] 161 | 162 | return input_tokens_item, modified_labels_item 163 | else: 164 | tokens = super().process_item(item, training_mode=training_mode) 165 | input_tokens_item = [] 166 | for i, token_or_media in enumerate(tokens): 167 | if isinstance(token_or_media, int): 168 | input_tokens_item.append(token_or_media) 169 | else: 170 | input_tokens_item += token_or_media["input_ids"] 171 | 172 | return input_tokens_item 173 | 174 | def decode_image(self, tokens: List[int]) -> Image.Image: 175 | if tokens[0] == self.token2id(self.image_start_token): 176 | tokens = tokens[1:] 177 | if tokens[-1] == self.token2id(self.image_end_token): 178 | tokens = tokens[:-1] 179 | 180 | h_grids, w_grids = tokens[0] - 8804, tokens[1] - 8804 181 | tokens = tokens[2:] 182 | h, w = h_grids * self.patch_size, w_grids * self.patch_size 183 | h_latent_dim, w_latent_dim = h_grids * 2, w_grids * 2 184 | 185 | for i in range(len(tokens)): 186 | if (i + 1) % (w_latent_dim + 1) != 0: 187 | tokens[i] = self.chameleon_ori_translation.bpe2img[tokens[i]] 188 | 189 | assert len(tokens) == h_latent_dim * (w_latent_dim + 1) 190 | tokens = torch.tensor(tokens, dtype=torch.int64).cuda() 191 | 192 | tokens = tokens.view(h_latent_dim, w_latent_dim + 1)[:, :-1].flatten() 193 | 194 | return self.chameleon_ori_image_tokenizer.pil_from_img_toks(tokens, h_latent_dim, w_latent_dim) 195 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/demo_inference.py: -------------------------------------------------------------------------------- 1 | from inference_solver import FlexARInferenceSolver 2 | from PIL import Image 3 | import os 4 | import torch 5 | import random 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | 8 | def set_seed(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | 17 | inference_solver = FlexARInferenceSolver( 18 | model_path = "UVT_7B_448", #path to your model 19 | precision="fp16", #bf16 20 | target_size=448, #fixed 448 21 | ) 22 | 23 | max_out = 1 24 | for i in range(max_out): 25 | set_seed(i) 26 | 27 | ### Fixed format: Instruction + input image --> output image 28 | qas = [["Acknowledge the spatial structure and identify variations in light intensity, translating these into a gradient scale representing distances. Accentuate regions where light diminishes gradually, enhancing the perception of depth by dimming peripheral areas. Adjust the distribution of luminance to highlight the central vanishing point, converting detailed textures into smooth transitions of grayscale." 
+ " <|image|>", None]] 29 | images = [Image.open("./demo_input_imgs/rgb_00015.jpg")] 30 | 31 | # qas = [["Translate the visible structures into a range of bright colors reflecting orientation angles, enhancing variations across surfaces." + " <|image|>", None]] 32 | # images = [Image.open("./demo_input_imgs/rgb_00015.jpg")] 33 | 34 | # qas = [["Capture the outline and prominent edges of the cylindrical object and its surroundings, simplify everything by removing textures and detailed surfaces, and emphasize only the contours and distinct features while rendering a higher contrast between light and dark regions with sharp shifts in tones." + " <|image|>", None]] 35 | # images = [Image.open("./demo_input_imgs/hed_input_6.jpg")] 36 | 37 | generated = inference_solver.generate( 38 | images=images, 39 | qas=qas, 40 | max_gen_len=4096, 41 | temperature=1.0, 42 | logits_processor=inference_solver.create_logits_processor(cfg=1., image_top_k=2048), 43 | ) 44 | new_image = generated[1][0] 45 | new_image.save(f'./demo_output_{i}.png', format='PNG') -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/demo_input_imgs/hed_input_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/demo_input_imgs/hed_input_6.jpg -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/demo_input_imgs/rgb_00015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/demo_input_imgs/rgb_00015.jpg -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/exps/7B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lr=2e-5 4 | min_lr=2e-6 5 | wd=0.1 6 | dropout=0.05 7 | z_loss_weight=1e-5 8 | 9 | data_config=configs/data/sample.yaml 10 | 11 | exp_name=7B_task_A_to_B_lr${lr}_min_lr${min_lr} 12 | mkdir -p output/"$exp_name" 13 | 14 | 15 | # python -u finetune_solver.py \ 16 | # python finetune_solver.py 17 | torchrun --nproc_per_node=8 --master_port=25641 finetune_solver.py \ 18 | --model_size 7B \ 19 | --batch_size 4 \ 20 | --accum_iter 1 \ 21 | --epochs 3 \ 22 | --warmup_epochs 0.01 \ 23 | --lr ${lr} \ 24 | --min_lr ${min_lr} \ 25 | --wd ${wd} \ 26 | --clip_grad 4 \ 27 | --data_config $data_config \ 28 | --cache_ann_on_disk \ 29 | --num_workers 8 \ 30 | --output_dir output/"$exp_name" \ 31 | --save_iteration_interval 4000 \ 32 | --checkpointing \ 33 | --max_seq_len 2048 \ 34 | --unmask_image_logits \ 35 | --dropout ${dropout} \ 36 | --z_loss_weight ${z_loss_weight} \ 37 | 2>&1 | tee -a output/"$exp_name"/output.log 38 | 39 | echo "exp name: $exp_name" 40 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/finetune_solver.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import List, Tuple 3 | 4 | from accelerate import init_empty_weights 5 | import torch 6 | 7 | from model import ChameleonXLLMXConfig, ChameleonXLLMXForConditionalGeneration 8 | from xllmx.data.item_processor import ItemProcessorBase 9 | from 
xllmx.solvers.finetune import FinetuneSolverBase 10 | 11 | 12 | class ItemProcessor(ItemProcessorBase): 13 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]: 14 | assert training_mode 15 | 16 | if "token" in data_item and "label" in data_item: 17 | data_item = data_item 18 | else: 19 | assert "file" in data_item 20 | with open(data_item["file"], "rb") as f: 21 | data_item = pickle.load(f) 22 | 23 | tokens = data_item["token"] 24 | labels = data_item["label"] 25 | assert len(tokens) == len(labels) 26 | 27 | return tokens, labels 28 | 29 | def predict_item_token_length(self, data_item: dict) -> int: 30 | if "token" in data_item: 31 | return len(data_item["token"]) 32 | elif "len" in data_item: 33 | return data_item["len"] 34 | else: 35 | raise ValueError() 36 | 37 | 38 | class Solver(FinetuneSolverBase): 39 | @classmethod 40 | def get_args_parser(cls): 41 | parser = super().get_args_parser() 42 | # task-specific parameters 43 | parser.add_argument("--max_seq_len", default=4096, type=int, help="max token length") 44 | parser.add_argument("--mask_image_logits", default=True) 45 | parser.add_argument("--unmask_image_logits", action="store_false", dest="mask_image_logits") 46 | parser.add_argument("--dropout", type=float, default=0.0) 47 | parser.add_argument("--z_loss_weight", type=float, default=0.0) 48 | parser.add_argument("--model_size", type=str, default="7B", choices=["7B", "34B"]) 49 | return parser 50 | 51 | def _model_func( 52 | self, 53 | init_from: str, 54 | ) -> (ChameleonXLLMXForConditionalGeneration, None): 55 | 56 | # Only instantiate the model on rank0 57 | # Other ranks will receive the model weights from rank0 during FSDP wrapping (through `sync_module_states`) 58 | # See https://github.com/pytorch/pytorch/issues/105840 59 | if self.dp_rank == 0: 60 | model = ChameleonXLLMXForConditionalGeneration.from_pretrained( 61 | init_from, 62 | max_position_embeddings=self.args.max_seq_len, 63 | mask_image_logits=self.args.mask_image_logits, 64 | dropout=self.args.dropout, 65 | z_loss_weight=self.args.z_loss_weight, 66 | torch_dtype=torch.bfloat16, 67 | device_map="cpu", 68 | ) 69 | else: 70 | with init_empty_weights(): 71 | config = ChameleonXLLMXConfig.from_pretrained( 72 | init_from, 73 | max_position_embeddings=self.args.max_seq_len, 74 | mask_image_logits=self.args.mask_image_logits, 75 | dropout=self.args.dropout, 76 | z_loss_weight=self.args.z_loss_weight, 77 | torch_dtype=torch.bfloat16, 78 | ) 79 | model = ChameleonXLLMXForConditionalGeneration(config) 80 | 81 | del model.model.vqmodel 82 | 83 | return model, None 84 | 85 | def _item_processor_func(self) -> ItemProcessorBase: 86 | return ItemProcessor() 87 | 88 | def _make_and_save_starting_point(self, save_path: str) -> None: 89 | 90 | pretrained_name = { 91 | "7B": "Alpha-VLLM/Chameleon_7B_mGPT", 92 | "34B": "Alpha-VLLM/Chameleon_34B_mGPT", 93 | }[self.args.model_size] 94 | 95 | model = ChameleonXLLMXForConditionalGeneration.from_pretrained( 96 | pretrained_name, 97 | max_position_embeddings=self.args.max_seq_len, 98 | mask_image_logits=self.args.mask_image_logits, 99 | dropout=self.args.dropout, 100 | z_loss_weight=self.args.z_loss_weight, 101 | torch_dtype=torch.bfloat16, 102 | device_map="cpu", 103 | ) 104 | 105 | image_tokens = model.model.vocabulary_mapping.image_tokens 106 | model.lm_head.weight.data[image_tokens] = torch.zeros_like(model.lm_head.weight.data[image_tokens]) 107 | 108 | model.save_pretrained(save_path, max_shard_size="10GB") 109 | 110 | 111 | if __name__ == "__main__": 
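# The entry point below is normally launched through exps/7B.sh; an abridged
# form of that command (using only flags that appear in the script) is:
#   torchrun --nproc_per_node=8 --master_port=25641 finetune_solver.py \
#       --model_size 7B --batch_size 4 --epochs 3 --lr 2e-5 --min_lr 2e-6 \
#       --data_config configs/data/sample.yaml --max_seq_len 2048 \
#       --unmask_image_logits --output_dir output/<exp_name>
# The data config lists record.json files whose entries point at the
# pre-tokenized .pkl items read by ItemProcessor.process_item() above.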
112 | args = Solver.get_args_parser().parse_args() 113 | solver = Solver(args) 114 | solver.run() 115 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/inference_solver.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | from typing import List, Optional, Union 5 | 6 | from PIL import Image 7 | import torch 8 | import transformers 9 | from transformers import GenerationConfig, TextStreamer 10 | from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList, LogitsWarper 11 | 12 | from data.item_processor import FlexARItemProcessor 13 | from model.chameleon import ChameleonForConditionalGeneration 14 | 15 | 16 | class LLMImageStartTriggeredUnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): 17 | r""" 18 | Logits processor for Classifier-Free Guidance (CFG). The processors computes a weighted average across scores 19 | from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`. 20 | The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch. 21 | 22 | See [the paper](https://arxiv.org/abs/2306.17806) for more information. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | guidance_scale: float, 28 | model, 29 | image_start_token_id, 30 | image_end_token_id, 31 | image_next_line_token_id, 32 | patch_size, 33 | unconditional_ids: Optional[torch.LongTensor] = None, 34 | unconditional_attention_mask: Optional[torch.LongTensor] = None, 35 | use_cache: Optional[bool] = True, 36 | ): 37 | self.guidance_scale = guidance_scale 38 | self.model = model 39 | self.unconditional_context_backup = { 40 | "input_ids": unconditional_ids, 41 | "attention_mask": unconditional_attention_mask, 42 | "use_cache": use_cache, 43 | "past_key_values": transformers.DynamicCache() if use_cache else None, 44 | "first_pass": True, 45 | } 46 | self.unconditional_context = None 47 | 48 | self.nums_image_start_tokens = None 49 | 50 | self.image_start_token_id = image_start_token_id 51 | self.image_end_token_id = image_end_token_id 52 | self.image_next_line_token_id = image_next_line_token_id 53 | self.image_start_token_id_index = None 54 | self.patch_size = patch_size 55 | self.h_latent_dim = None 56 | self.w_latent_dim = None 57 | 58 | def get_unconditional_logits(self, input_ids, image_start_token_id_index): 59 | 60 | if self.unconditional_context["first_pass"]: 61 | if self.unconditional_context["input_ids"] is None: 62 | self.unconditional_context["input_ids"] = input_ids[:, image_start_token_id_index:] 63 | if self.unconditional_context["attention_mask"] is None: 64 | self.unconditional_context["attention_mask"] = torch.ones_like( 65 | self.unconditional_context["input_ids"], dtype=torch.long 66 | ) 67 | input_ids = self.unconditional_context["input_ids"] 68 | attention_mask = self.unconditional_context["attention_mask"] 69 | self.unconditional_context["first_pass"] = False 70 | else: 71 | attention_mask = torch.cat( 72 | [ 73 | self.unconditional_context["attention_mask"], 74 | torch.ones_like(input_ids[:, -1:], dtype=torch.long), 75 | ], 76 | dim=1, 77 | ) 78 | if not self.unconditional_context["use_cache"]: 79 | input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1) 80 | else: 81 | input_ids = input_ids[:, -1:] 82 | self.unconditional_context["input_ids"] = input_ids 83 | 
self.unconditional_context["attention_mask"] = attention_mask 84 | 85 | out = self.model( 86 | input_ids, 87 | attention_mask=attention_mask, 88 | use_cache=self.unconditional_context["use_cache"], 89 | past_key_values=self.unconditional_context["past_key_values"], 90 | ) 91 | self.unconditional_context["past_key_values"] = out.get("past_key_values", None) 92 | 93 | return out.logits 94 | 95 | def __call__(self, input_ids, scores): 96 | num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum() 97 | num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum() 98 | 99 | if num_image_start_tokens == num_image_end_tokens: 100 | self.h_latent_dim, self.w_latent_dim = None, None 101 | self.image_start_token_id_index = None 102 | self.unconditional_context = None 103 | return scores 104 | 105 | elif num_image_start_tokens == num_image_end_tokens + 1: 106 | if self.image_start_token_id_index is None: 107 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0][-1].item() 108 | new_token_num = len(input_ids[0][self.image_start_token_id_index + 1 :]) 109 | if new_token_num >= 2: 110 | if self.h_latent_dim is None or self.w_latent_dim is None: 111 | h_grids, w_grids = ( 112 | input_ids[0][self.image_start_token_id_index + 1] - 8804, 113 | input_ids[0][self.image_start_token_id_index + 2] - 8804, 114 | ) 115 | self.h_latent_dim, self.w_latent_dim = h_grids * 2, w_grids * 2 116 | 117 | if self.unconditional_context is None: 118 | self.unconditional_context = copy.deepcopy(self.unconditional_context_backup) 119 | 120 | if self.guidance_scale == 1.0: 121 | return scores 122 | 123 | unconditional_logits = self.get_unconditional_logits(input_ids, self.image_start_token_id_index)[:, -1] 124 | 125 | scores_processed = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits 126 | return scores_processed 127 | 128 | else: 129 | print("Something wrong in the decoding process.") 130 | 131 | return scores 132 | 133 | 134 | class MultiModalLogitsProcessor(LogitsProcessor): 135 | 136 | def __init__( 137 | self, 138 | image_start_token_id=None, 139 | image_end_token_id=None, 140 | image_next_line_token_id=None, 141 | patch_size=None, 142 | voc_size=None, 143 | ): 144 | self.image_start_token_id = image_start_token_id 145 | self.image_end_token_id = image_end_token_id 146 | self.image_next_line_token_id = image_next_line_token_id 147 | self.image_start_token_id_index = None 148 | self.patch_size = patch_size 149 | self.h_latent_dim = None 150 | self.w_latent_dim = None 151 | 152 | self.vocab_list = [i for i in range(voc_size)] 153 | self.image_token_list = [i for i in range(4, 8195 + 1)] 154 | self.suppress_tokens = torch.tensor( 155 | [x for x in self.vocab_list if x not in self.image_token_list], device="cuda" 156 | ) 157 | 158 | self.vocab_tensor = torch.arange(voc_size, device="cuda") 159 | self.suppress_token_mask = torch.isin(self.vocab_tensor, self.suppress_tokens) 160 | self.new_line_force_token_mask = torch.isin( 161 | self.vocab_tensor, torch.tensor([self.image_next_line_token_id], device="cuda") 162 | ) 163 | self.eos_image_force_token_mask = torch.isin( 164 | self.vocab_tensor, torch.tensor([self.image_end_token_id], device="cuda") 165 | ) 166 | 167 | self.flag = False 168 | self.num_image_start_tokens = None 169 | self.num_image_end_tokens = None 170 | 171 | # @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 172 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 173 
| 174 | self.num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum() 175 | self.num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum() 176 | 177 | # print(self.num_image_start_tokens, self.num_image_end_tokens) 178 | 179 | if self.num_image_start_tokens == self.num_image_end_tokens: 180 | self.h_latent_dim, self.w_latent_dim = None, None 181 | self.image_start_token_id_index = None 182 | return scores 183 | 184 | elif self.num_image_start_tokens == self.num_image_end_tokens + 1: 185 | if self.image_start_token_id_index is None: 186 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0] 187 | print(self.image_start_token_id_index) 188 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0][-1].item() 189 | 190 | new_token_num = len(input_ids[0][self.image_start_token_id_index + 1 :]) 191 | # print(f"num new tokens: {new_token_num}") 192 | if new_token_num >= 2: 193 | if self.h_latent_dim is None or self.w_latent_dim is None: 194 | h_grids, w_grids = ( 195 | input_ids[0][self.image_start_token_id_index + 1] - 8804, 196 | input_ids[0][self.image_start_token_id_index + 2] - 8804, 197 | ) 198 | # print(f"h_grids: {h_grids}, w_grids: {w_grids}") 199 | self.h_latent_dim, self.w_latent_dim = h_grids * 2, w_grids * 2 200 | print(f"h_latent_dim: {self.h_latent_dim}, w_latent_dim: {self.w_latent_dim}") 201 | 202 | tokens = input_ids[0][self.image_start_token_id_index + 3 :] 203 | if (len(tokens) + 1) % (self.w_latent_dim + 1) == 0: 204 | new_line_constrained_scores = torch.full_like(scores, -math.inf) 205 | new_line_constrained_scores[:, self.image_next_line_token_id] = 0 206 | # print(f"new line: {len(tokens)+1}") 207 | return new_line_constrained_scores 208 | elif (len(tokens) + 1) == (self.w_latent_dim + 1) * self.h_latent_dim + 1: 209 | eos_image_constrained_scores = torch.full_like(scores, -math.inf) 210 | eos_image_constrained_scores[:, self.image_end_token_id] = 0 211 | print(f"eos image: {len(tokens)+1}") 212 | return eos_image_constrained_scores 213 | elif (len(tokens) + 1) % (self.w_latent_dim + 1) != 0: 214 | image_constrained_scores = torch.where(self.suppress_token_mask, -float("inf"), scores) 215 | return image_constrained_scores 216 | else: 217 | print("Something wrong in the decoding process.") 218 | 219 | return scores 220 | 221 | 222 | class InterleavedTopKLogitsWarper(LogitsWarper): 223 | r""" 224 | [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together 225 | with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. 
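    In this repository the warper keeps two budgets: `image_top_k` is used while an image block is open
    (an image-start token has been emitted without a matching image-end token), and `text_top_k` is used elsewhere.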
226 | """ 227 | 228 | def __init__( 229 | self, 230 | image_top_k: int, 231 | text_top_k: int, 232 | image_start_token_id=None, 233 | image_end_token_id=None, 234 | filter_value: float = -float("Inf"), 235 | min_tokens_to_keep: int = 1, 236 | ): 237 | if not isinstance(text_top_k, int) or text_top_k <= 0: 238 | raise ValueError(f"`text_top_k` has to be a strictly positive integer, but is {text_top_k}") 239 | if not isinstance(image_top_k, int) or text_top_k <= 0: 240 | raise ValueError(f"`image_top_k` has to be a strictly positive integer, but is {image_top_k}") 241 | 242 | self.image_top_k = max(image_top_k, min_tokens_to_keep) 243 | self.text_top_k = max(text_top_k, min_tokens_to_keep) 244 | self.filter_value = filter_value 245 | 246 | self.image_start_token_id = image_start_token_id 247 | self.image_end_token_id = image_end_token_id 248 | 249 | self.flag = False 250 | self.num_image_start_tokens = None 251 | self.num_image_end_tokens = None 252 | 253 | # @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 254 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 255 | 256 | self.num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum() 257 | self.num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum() 258 | 259 | if self.num_image_start_tokens == self.num_image_end_tokens + 1: 260 | top_k = min(self.image_top_k, scores.size(-1)) 261 | else: 262 | top_k = min(self.text_top_k, scores.size(-1)) # Safety check 263 | # Remove all tokens with a probability less than the last token of the top-k 264 | indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] 265 | scores_processed = scores.masked_fill(indices_to_remove, self.filter_value) 266 | return scores_processed 267 | 268 | 269 | class FlexARInferenceSolver: 270 | @classmethod 271 | def get_args_parser(cls): 272 | parser = argparse.ArgumentParser("xllmx Inference", add_help=False) 273 | parser.add_argument("--model_path", type=str) 274 | parser.add_argument("--precision", type=str, choices=["fp16", "bf16", "tf32"], default="bf16") 275 | 276 | return parser 277 | 278 | def __init__(self, model_path, precision, target_size=512): 279 | self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 280 | 281 | self.model = ChameleonForConditionalGeneration.from_pretrained( 282 | model_path, 283 | torch_dtype=self.dtype, 284 | device_map="cuda", 285 | ) 286 | self.item_processor = FlexARItemProcessor(target_size=target_size) 287 | 288 | def get_streamer(self): 289 | return TextStreamer(self.item_processor.tokenizer) 290 | 291 | @torch.no_grad() 292 | def generate( 293 | self, 294 | images: Image.Image | str | List[Union[Image.Image, str]], 295 | qas, 296 | max_gen_len, 297 | temperature, 298 | logits_processor=None, 299 | streamer=None, 300 | ): 301 | 302 | conversations = [] 303 | for q, a in qas: 304 | conversations.append( 305 | { 306 | "from": "human", 307 | "value": q, 308 | } 309 | ) 310 | conversations.append( 311 | { 312 | "from": "gpt", 313 | "value": a, 314 | } 315 | ) 316 | item = {"image": images, "conversations": conversations} 317 | 318 | _prompt = self.item_processor.process_item(item) 319 | prompt = [] 320 | for value in _prompt: 321 | if isinstance(value, int): 322 | prompt.append(value) 323 | else: 324 | prompt += value["input_ids"] 325 | prompt_len = len(prompt) 326 | prompt = torch.tensor(prompt, dtype=torch.int64, device=self.model.device).unsqueeze(0) 327 | 328 | generation_config = 
GenerationConfig( 329 | max_new_tokens=max_gen_len, 330 | max_length=self.model.config.max_position_embeddings, 331 | temperature=temperature, 332 | top_k=None, 333 | do_sample=True, 334 | eos_token_id=[8710], 335 | ) 336 | 337 | if logits_processor is None: 338 | logits_processor = self.create_logits_processor() 339 | 340 | with torch.cuda.amp.autocast(dtype=self.dtype): 341 | generation_result = self.model.generate( 342 | prompt, generation_config, logits_processor=logits_processor, streamer=streamer 343 | )[0][prompt_len:].tolist() 344 | if len(generation_result) > 0 and generation_result[-1] == 8710: 345 | generation_result = generation_result[:-1] 346 | 347 | return self.decode_ids(generation_result) 348 | 349 | def decode_ids(self, tokens: List[int]): 350 | generated_images = [] 351 | generation_result_processed = [] 352 | i = 0 353 | while i < len(tokens): 354 | token_id = tokens[i] 355 | if token_id == self.item_processor.token2id(self.item_processor.image_start_token): 356 | cache = [] 357 | for j in range(i + 1, len(tokens)): 358 | if tokens[j] != self.item_processor.token2id(self.item_processor.image_end_token): 359 | cache.append(tokens[j]) 360 | i = j + 1 361 | else: 362 | image = self.decode_image(cache) 363 | generated_images.append(image) 364 | generation_result_processed.append(self.item_processor.token2id("<|image|>")) 365 | i = j + 1 366 | break 367 | else: 368 | generation_result_processed.append(token_id) 369 | i += 1 370 | 371 | generated = self.item_processor.tokenizer.decode(generation_result_processed) 372 | 373 | return generated, generated_images 374 | 375 | def decode_image(self, tokens: List[int]): 376 | return self.item_processor.decode_image(tokens) 377 | 378 | @staticmethod 379 | def create_image_grid(images, rows, cols): 380 | width, height = images[0].size 381 | 382 | grid_img = Image.new("RGB", (cols * width, rows * height)) 383 | 384 | for i, img in enumerate(images): 385 | row = i // cols 386 | col = i % cols 387 | grid_img.paste(img, (col * width, row * height)) 388 | 389 | return grid_img 390 | 391 | def create_logits_processor(self, cfg=3.0, image_top_k=2000, text_top_k=10): 392 | logits_processor = LogitsProcessorList() 393 | 394 | cfg_processor = LLMImageStartTriggeredUnbatchedClassifierFreeGuidanceLogitsProcessor( 395 | guidance_scale=cfg, 396 | model=self.model, 397 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token), 398 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token), 399 | image_next_line_token_id=self.item_processor.token2id(self.item_processor.new_line_token), 400 | patch_size=32, 401 | ) 402 | 403 | candidate_processor = MultiModalLogitsProcessor( 404 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token), 405 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token), 406 | image_next_line_token_id=self.item_processor.token2id(self.item_processor.new_line_token), 407 | patch_size=32, 408 | voc_size=self.model.config.vocab_size, 409 | ) 410 | 411 | topk_processor = InterleavedTopKLogitsWarper( 412 | image_top_k=image_top_k, 413 | text_top_k=text_top_k, 414 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token), 415 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token), 416 | ) 417 | 418 | logits_processor.append(cfg_processor) 419 | logits_processor.append(candidate_processor) 420 | logits_processor.append(topk_processor) 421 | 422 | 
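        # Taken together, the processors appended above implement the image-aware
        # decoding used by generate():
        #   * cfg_processor applies classifier-free guidance only while an image
        #     block is open: scores' = g * (cond - uncond) + uncond, so the default
        #     g = 3.0 amounts to 3*cond - 2*uncond, and g = 1.0 is a no-op.
        #   * candidate_processor restricts in-image positions to VQ codebook ids,
        #     forces the new-line token after every w_latent_dim image tokens, and
        #     forces the image-end token once the h_latent_dim * (w_latent_dim + 1)
        #     grid has been emitted.
        #   * topk_processor samples with image_top_k inside an image block and
        #     with text_top_k elsewhere.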
return logits_processor 423 | 424 | 425 | if __name__ == "__main__": 426 | parser = FlexARInferenceSolver.get_args_parser() 427 | args = parser.parse_args() 428 | solver = FlexARInferenceSolver(**vars(args)) 429 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/0_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/0_0.pkl -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/0_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/0_1.pkl -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/1_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/1_0.pkl -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/1_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/1_1.pkl -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/2_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/2_0.pkl -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/2_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/json/edit_resolution_448/Allweather/files/2_1.pkl -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_xllmx_chameleon import ChameleonXLLMXConfig 2 | from .modeling_xllmx_chameleon import ChameleonXLLMXForConditionalGeneration 3 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available 17 | 18 | _import_structure = { 19 | "configuration_chameleon": ["ChameleonConfig", "ChameleonVQVAEConfig"], 20 | "processing_chameleon": ["ChameleonProcessor"], 21 | } 22 | 23 | 24 | try: 25 | if not is_torch_available(): 26 | raise OptionalDependencyNotAvailable() 27 | except OptionalDependencyNotAvailable: 28 | pass 29 | else: 30 | _import_structure["modeling_chameleon"] = [ 31 | "ChameleonForConditionalGeneration", 32 | "ChameleonModel", 33 | "ChameleonPreTrainedModel", 34 | "ChameleonVQVAE", 35 | ] 36 | 37 | try: 38 | if not is_vision_available(): 39 | raise OptionalDependencyNotAvailable() 40 | except OptionalDependencyNotAvailable: 41 | pass 42 | else: 43 | _import_structure["image_processing_chameleon"] = ["ChameleonImageProcessor"] 44 | 45 | 46 | if TYPE_CHECKING: 47 | from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig 48 | from .processing_chameleon import ChameleonProcessor 49 | 50 | try: 51 | if not is_torch_available(): 52 | raise OptionalDependencyNotAvailable() 53 | except OptionalDependencyNotAvailable: 54 | pass 55 | else: 56 | from .modeling_chameleon import ( 57 | ChameleonForConditionalGeneration, 58 | ChameleonModel, 59 | ChameleonPreTrainedModel, 60 | ChameleonVQVAE, 61 | ) 62 | 63 | try: 64 | if not is_vision_available(): 65 | raise OptionalDependencyNotAvailable() 66 | except OptionalDependencyNotAvailable: 67 | pass 68 | else: 69 | from .image_processing_chameleon import ChameleonImageProcessor 70 | 71 | 72 | else: 73 | import sys 74 | 75 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 76 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon/configuration_chameleon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
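# Orientation for this module (a sketch; the classes themselves follow below):
# ChameleonConfig embeds a ChameleonVQVAEConfig through its `vq_config`
# argument, so a minimal programmatic construction looks like
#   config = ChameleonConfig(vq_config={"resolution": 512})
#   assert isinstance(config.vq_config, ChameleonVQVAEConfig)
# and `rope_scaling`, when supplied, must be a dict with exactly the two
# fields {"type": "linear" | "dynamic", "factor": <float strictly > 1.0>};
# anything else is rejected by ChameleonConfig._rope_scaling_validation().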
15 | """chameleon model configuration""" 16 | 17 | from typing import List 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class ChameleonVQVAEConfig(PretrainedConfig): 26 | r""" 27 | This is the configuration class to store the configuration of a [`ChameleonVQModel`]. It is used to instantiate a 28 | `ChameleonVQModel` according to the specified arguments, defining the model architecture. 29 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 30 | documentation from [`PretrainedConfig`] for more information. Instantiating a 31 | configuration with the defaults will yield a similar configuration to the VQModel of the 32 | [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B). 33 | 34 | Args: 35 | embed_dim (`int`, *optional*, defaults to 256): 36 | Dimensionality of each embedding vector. 37 | num_embeddings (`int`, *optional*, defaults to 8192): 38 | Number of codebook embeddings. 39 | double_latent (`bool`, *optional*, defaults to `False`): 40 | Whether to use double z channels. 41 | latent_channels (`int`, *optional*, defaults to 256): 42 | Number of channels for the latent space. 43 | resolution (`int`, *optional*, defaults to 512): 44 | Resolution of the input images. 45 | in_channels (`int`, *optional*, defaults to 3): 46 | Number of input channels. 47 | base_channels (`int`, *optional*, defaults to 128): 48 | Base channel count. 49 | channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`): 50 | Channel multipliers for each resolution. 51 | num_res_blocks (`int`, *optional*, defaults to 2): 52 | Number of residual blocks. 53 | attn_resolutions (`List[int]`, *optional*): 54 | Resolutions to apply attention. 55 | dropout (`float`, *optional*, defaults to 0.0): 56 | Dropout rate. 57 | attn_type (`str`, *optional*, defaults to `"vanilla"`): 58 | Attention type used in VQ-GAN encoder. Can be "vanilla" or None. 59 | initializer_range (`float`, *optional*, defaults to 0.02): 60 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 61 | """ 62 | 63 | model_type = "chameleon_vqgan" 64 | 65 | def __init__( 66 | self, 67 | embed_dim: int = 256, 68 | num_embeddings: int = 8192, 69 | double_latent: bool = False, 70 | latent_channels: int = 256, 71 | resolution: int = 512, 72 | in_channels: int = 3, 73 | base_channels: int = 128, 74 | channel_multiplier: List[int] = [1, 1, 2, 2, 4], 75 | num_res_blocks: int = 2, 76 | attn_resolutions: List[int] = None, 77 | dropout: float = 0.0, 78 | attn_type: str = "vanilla", 79 | initializer_range=0.02, 80 | **kwargs, 81 | ): 82 | super().__init__(**kwargs) 83 | self.embed_dim = embed_dim 84 | self.num_embeddings = num_embeddings 85 | self.double_latent = double_latent 86 | self.latent_channels = latent_channels 87 | self.resolution = resolution 88 | self.in_channels = in_channels 89 | self.base_channels = base_channels 90 | self.channel_multiplier = channel_multiplier 91 | self.num_res_blocks = num_res_blocks 92 | self.attn_resolutions = attn_resolutions 93 | self.dropout = dropout 94 | self.attn_type = attn_type 95 | self.initializer_range = initializer_range 96 | 97 | 98 | class ChameleonConfig(PretrainedConfig): 99 | r""" 100 | This is the configuration class to store the configuration of a [`ChameleonModel`]. 
It is used to instantiate a 101 | chameleon model according to the specified arguments, defining the model architecture. Instantiating a 102 | configuration with the defaults will yield a similar configuration to that of the 103 | [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B). 104 | 105 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 106 | documentation from [`PretrainedConfig`] for more information. 107 | 108 | 109 | Args: 110 | vocab_size (`int`, *optional*, defaults to 65536): 111 | Vocabulary size of the chameleon model. Defines the number of different tokens that can be represented by the 112 | `inputs_ids` passed when calling [`ChameleonModel`]; this includes text and image tokens. 113 | hidden_size (`int`, *optional*, defaults to 4096): 114 | Dimension of the hidden representations. 115 | intermediate_size (`int`, *optional*, defaults to 11008): 116 | Dimension of the MLP representations. 117 | num_hidden_layers (`int`, *optional*, defaults to 32): 118 | Number of hidden layers in the Transformer decoder. 119 | num_attention_heads (`int`, *optional*, defaults to 32): 120 | Number of attention heads for each attention layer in the Transformer decoder. 121 | num_key_value_heads (`int`, *optional*, defaults to 32): 122 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 123 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 124 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 125 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 126 | by meanpooling all the original heads within that group. For more details checkout [this 127 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 128 | `num_attention_heads`. 129 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 130 | The non-linear activation function (function or string) in the decoder. 131 | max_position_embeddings (`int`, *optional*, defaults to 4096): 132 | The maximum sequence length that this model might ever be used with. Chameleon supports up to 4096 tokens. 133 | initializer_range (`float`, *optional*, defaults to 0.02): 134 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 135 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 136 | The epsilon used by the rms normalization layers. 137 | use_cache (`bool`, *optional*, defaults to `True`): 138 | Whether or not the model should return the last key/values attentions (not used by all models). Only 139 | relevant if `config.is_decoder=True`. 140 | pad_token_id (`int`, *optional*): 141 | Padding token id. 142 | bos_token_id (`int`, *optional*, defaults to 1): 143 | Beginning of stream token id. 144 | eos_token_id (`int`, *optional*, defaults to 2): 145 | End of stream token id. 146 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 147 | Whether to tie weight embeddings 148 | rope_theta (`float`, *optional*, defaults to 10000.0): 149 | The base period of the RoPE embeddings. 150 | rope_scaling (`Dict`, *optional*): 151 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 152 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. 
The expected format is 153 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 154 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 155 | these scaling strategies behave: 156 | https://www.reddit.com/r/Localchameleon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 157 | experimental feature, subject to breaking API changes in future versions. 158 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 159 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 160 | attention_dropout (`float`, *optional*, defaults to 0.0): 161 | The dropout ratio for the attention probabilities. 162 | model_parallel_size (`int`, *optional*, defaults to 1): 163 | Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference 164 | doesn't do reduction in those layers and each rank has its own biases. 165 | swin_norm (`bool`, *optional*, defaults to `False`): 166 | Use Swin Transformer normalization. 167 | vq_config (`dict`, *optional*): 168 | ChameleonVQConfig instance containing the configuration for the VQ-VAE model. 169 | vocabulary_map (`dict`, *optional*): 170 | A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. 171 | mlp_bias (`bool`, *optional*, defaults to `False`): 172 | Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 173 | 174 | 175 | ```python 176 | >>> from transformers import ChameleonModel, ChameleonConfig 177 | 178 | >>> # Initializing a chameleon chameleon-7b style configuration 179 | >>> configuration = ChameleonConfig() 180 | 181 | >>> # Initializing a model from the chameleon-7b style configuration 182 | >>> model = ChameleonModel(configuration) 183 | 184 | >>> # Accessing the model configuration 185 | >>> configuration = model.config 186 | ```""" 187 | 188 | model_type = "chameleon" 189 | keys_to_ignore_at_inference = ["past_key_values"] 190 | 191 | def __init__( 192 | self, 193 | vocab_size=65536, 194 | hidden_size=4096, 195 | intermediate_size=11008, 196 | num_hidden_layers=32, 197 | num_attention_heads=32, 198 | num_key_value_heads=32, 199 | hidden_act="silu", 200 | max_position_embeddings=4096, 201 | initializer_range=0.02, 202 | rms_norm_eps=1e-05, 203 | use_cache=True, 204 | pad_token_id=None, 205 | bos_token_id=1, 206 | eos_token_id=2, 207 | tie_word_embeddings=False, 208 | rope_theta=10000.0, 209 | rope_scaling=None, 210 | attention_bias=False, 211 | attention_dropout=0.0, 212 | model_parallel_size=1, 213 | swin_norm=False, 214 | vq_config=None, 215 | vocabulary_map=None, 216 | mlp_bias=False, 217 | mask_image_logits=True, 218 | dropout=0.0, 219 | **kwargs, 220 | ): 221 | self.vocab_size = vocab_size 222 | self.max_position_embeddings = max_position_embeddings 223 | self.hidden_size = hidden_size 224 | self.intermediate_size = intermediate_size 225 | self.num_hidden_layers = num_hidden_layers 226 | self.num_attention_heads = num_attention_heads 227 | self.mlp_bias = mlp_bias 228 | 229 | self.num_key_value_heads = num_key_value_heads 230 | self.hidden_act = hidden_act 231 | self.initializer_range = initializer_range 232 | self.rms_norm_eps = rms_norm_eps 233 | self.use_cache = use_cache 234 | self.rope_theta = rope_theta 235 | self.rope_scaling = rope_scaling 236 | self._rope_scaling_validation() 237 | self.attention_bias = attention_bias 238 
| self.attention_dropout = attention_dropout 239 | self.model_parallel_size = model_parallel_size 240 | self.swin_norm = swin_norm 241 | self.mask_image_logits = mask_image_logits 242 | 243 | if vq_config is None: 244 | vq_config = {} 245 | logger.info("vq_config is None. initializing the ChameleonVQConfig with default values.") 246 | 247 | self.vq_config = ChameleonVQVAEConfig(**vq_config) 248 | 249 | self.vocabulary_map = vocabulary_map 250 | 251 | self.dropout = dropout 252 | 253 | super().__init__( 254 | pad_token_id=pad_token_id, 255 | bos_token_id=bos_token_id, 256 | eos_token_id=eos_token_id, 257 | tie_word_embeddings=tie_word_embeddings, 258 | **kwargs, 259 | ) 260 | 261 | def _rope_scaling_validation(self): 262 | """ 263 | Validate the `rope_scaling` configuration. 264 | """ 265 | if self.rope_scaling is None: 266 | return 267 | 268 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 269 | raise ValueError( 270 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 271 | f"got {self.rope_scaling}" 272 | ) 273 | rope_scaling_type = self.rope_scaling.get("type", None) 274 | rope_scaling_factor = self.rope_scaling.get("factor", None) 275 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 276 | raise ValueError( 277 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 278 | ) 279 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 280 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") 281 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon/image_processing_chameleon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Image processor class for Chameleon.""" 16 | 17 | from typing import Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict 21 | from transformers.image_transforms import get_resize_output_image_size, resize, to_channel_dimension_format 22 | from transformers.image_utils import ( 23 | ChannelDimension, 24 | ImageInput, 25 | PILImageResampling, 26 | infer_channel_dimension_format, 27 | is_scaled_image, 28 | is_valid_image, 29 | to_numpy_array, 30 | valid_images, 31 | validate_kwargs, 32 | validate_preprocess_arguments, 33 | ) 34 | from transformers.utils import TensorType, is_vision_available, logging 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | if is_vision_available(): 39 | import PIL 40 | 41 | 42 | def make_batched_images(images) -> List[List[ImageInput]]: 43 | """ 44 | Accepts images in list or nested list format, and makes a list of images for preprocessing. 45 | 46 | Args: 47 | images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): 48 | The input image. 49 | 50 | Returns: 51 | list: A list of images. 52 | """ 53 | if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): 54 | return [img for img_list in images for img in img_list] 55 | 56 | elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): 57 | return images 58 | 59 | elif is_valid_image(images): 60 | return [images] 61 | 62 | raise ValueError(f"Could not make batched video from {images}") 63 | 64 | 65 | class ChameleonImageProcessor(BaseImageProcessor): 66 | r""" 67 | Constructs a Chameleon image processor. 68 | 69 | Args: 70 | do_resize (`bool`, *optional*, defaults to `True`): 71 | Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by 72 | `do_resize` in the `preprocess` method. 73 | size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 512}`): 74 | Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with 75 | the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` 76 | method. 77 | resample (`PILImageResampling`, *optional*, defaults to 1): 78 | Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. 79 | do_center_crop (`bool`, *optional*, defaults to `True`): 80 | Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the 81 | `preprocess` method. 82 | crop_size (`Dict[str, int]` *optional*, defaults to {"height": 512, "width": 512}): 83 | Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` 84 | method. 85 | do_rescale (`bool`, *optional*, defaults to `True`): 86 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in 87 | the `preprocess` method. 88 | rescale_factor (`int` or `float`, *optional*, defaults to 0.0078): 89 | Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` 90 | method. 91 | do_normalize (`bool`, *optional*, defaults to `True`): 92 | Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. 93 | image_mean (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`): 94 | Mean to use if normalizing the image. 
This is a float or list of floats the length of the number of 95 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 96 | image_std (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`): 97 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 98 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 99 | Can be overridden by the `image_std` parameter in the `preprocess` method. 100 | do_convert_rgb (`bool`, *optional*, defaults to `True`): 101 | Whether to convert the image to RGB. 102 | """ 103 | 104 | model_input_names = ["pixel_values"] 105 | 106 | def __init__( 107 | self, 108 | do_resize: bool = True, 109 | size: Dict[str, int] = None, 110 | resample: PILImageResampling = PIL.Image.LANCZOS, 111 | do_center_crop: bool = True, 112 | crop_size: Dict[str, int] = None, 113 | do_rescale: bool = True, 114 | rescale_factor: Union[int, float] = 0.0078, 115 | do_normalize: bool = True, 116 | image_mean: Optional[Union[float, List[float]]] = None, 117 | image_std: Optional[Union[float, List[float]]] = None, 118 | do_convert_rgb: bool = True, 119 | **kwargs, 120 | ) -> None: 121 | super().__init__(**kwargs) 122 | size = size if size is not None else {"shortest_edge": 512} 123 | size = get_size_dict(size, default_to_square=False) 124 | crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512} 125 | crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") 126 | 127 | self.do_resize = do_resize 128 | self.size = size 129 | self.resample = resample 130 | self.do_center_crop = do_center_crop 131 | self.crop_size = crop_size 132 | self.do_rescale = do_rescale 133 | self.rescale_factor = rescale_factor 134 | self.do_normalize = do_normalize 135 | self.image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0] 136 | self.image_std = image_std if image_std is not None else [1.0, 1.0, 1.0] 137 | self.do_convert_rgb = do_convert_rgb 138 | self._valid_processor_keys = [ 139 | "images", 140 | "do_resize", 141 | "size", 142 | "resample", 143 | "do_center_crop", 144 | "crop_size", 145 | "do_rescale", 146 | "rescale_factor", 147 | "do_normalize", 148 | "image_mean", 149 | "image_std", 150 | "do_convert_rgb", 151 | "return_tensors", 152 | "data_format", 153 | "input_data_format", 154 | ] 155 | 156 | # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize 157 | def resize( 158 | self, 159 | image: np.ndarray, 160 | size: Dict[str, int], 161 | resample: PILImageResampling = PILImageResampling.BICUBIC, 162 | data_format: Optional[Union[str, ChannelDimension]] = None, 163 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 164 | **kwargs, 165 | ) -> np.ndarray: 166 | """ 167 | Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge 168 | resized to keep the input aspect ratio. 169 | 170 | Args: 171 | image (`np.ndarray`): 172 | Image to resize. 173 | size (`Dict[str, int]`): 174 | Size of the output image. 175 | resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): 176 | Resampling filter to use when resiizing the image. 177 | data_format (`str` or `ChannelDimension`, *optional*): 178 | The channel dimension format of the image. If not provided, it will be the same as the input image. 
179 | input_data_format (`ChannelDimension` or `str`, *optional*): 180 | The channel dimension format of the input image. If not provided, it will be inferred. 181 | """ 182 | default_to_square = True 183 | if "shortest_edge" in size: 184 | size = size["shortest_edge"] 185 | default_to_square = False 186 | elif "height" in size and "width" in size: 187 | size = (size["height"], size["width"]) 188 | else: 189 | raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") 190 | 191 | output_size = get_resize_output_image_size( 192 | image, 193 | size=size, 194 | default_to_square=default_to_square, 195 | input_data_format=input_data_format, 196 | ) 197 | return resize( 198 | image, 199 | size=output_size, 200 | resample=resample, 201 | data_format=data_format, 202 | input_data_format=input_data_format, 203 | **kwargs, 204 | ) 205 | 206 | def preprocess( 207 | self, 208 | images: ImageInput, 209 | do_resize: bool = None, 210 | size: Dict[str, int] = None, 211 | resample: PILImageResampling = None, 212 | do_center_crop: bool = None, 213 | crop_size: int = None, 214 | do_rescale: bool = None, 215 | rescale_factor: float = None, 216 | do_normalize: bool = None, 217 | image_mean: Optional[Union[float, List[float]]] = None, 218 | image_std: Optional[Union[float, List[float]]] = None, 219 | do_convert_rgb: bool = None, 220 | return_tensors: Optional[Union[str, TensorType]] = None, 221 | data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, 222 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 223 | **kwargs, 224 | ) -> PIL.Image.Image: 225 | """ 226 | Preprocess an image or batch of images. 227 | 228 | Args: 229 | images (`ImageInput`): 230 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 231 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 232 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 233 | Whether to resize the image. 234 | size (`Dict[str, int]`, *optional*, defaults to `self.size`): 235 | Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with 236 | the longest edge resized to keep the input aspect ratio. 237 | resample (`int`, *optional*, defaults to `self.resample`): 238 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only 239 | has an effect if `do_resize` is set to `True`. 240 | do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): 241 | Whether to center crop the image. 242 | crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): 243 | Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. 244 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 245 | Whether to rescale the image. 246 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 247 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 248 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 249 | Whether to normalize the image. 250 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 251 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 252 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 253 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to 254 | `True`. 
255 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 256 | Whether to convert the image to RGB. 257 | return_tensors (`str` or `TensorType`, *optional*): 258 | The type of tensors to return. Can be one of: 259 | - Unset: Return a list of `np.ndarray`. 260 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 261 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 262 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 263 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 264 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 265 | The channel dimension format for the output image. Can be one of: 266 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 267 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 268 | - Unset: Use the channel dimension format of the input image. 269 | input_data_format (`ChannelDimension` or `str`, *optional*): 270 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 271 | from the input image. Can be one of: 272 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 273 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 274 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 275 | """ 276 | do_resize = do_resize if do_resize is not None else self.do_resize 277 | size = size if size is not None else self.size 278 | size = get_size_dict(size, param_name="size", default_to_square=False) 279 | resample = resample if resample is not None else self.resample 280 | do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop 281 | crop_size = crop_size if crop_size is not None else self.crop_size 282 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) 283 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 284 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 285 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 286 | image_mean = image_mean if image_mean is not None else self.image_mean 287 | image_std = image_std if image_std is not None else self.image_std 288 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 289 | 290 | validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) 291 | 292 | images = make_batched_images(images) 293 | 294 | if not valid_images(images): 295 | raise ValueError( 296 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 297 | "torch.Tensor, tf.Tensor or jax.ndarray." 298 | ) 299 | 300 | validate_preprocess_arguments( 301 | do_rescale=do_rescale, 302 | rescale_factor=rescale_factor, 303 | do_normalize=do_normalize, 304 | image_mean=image_mean, 305 | image_std=image_std, 306 | do_center_crop=do_center_crop, 307 | crop_size=crop_size, 308 | do_resize=do_resize, 309 | size=size, 310 | resample=resample, 311 | ) 312 | 313 | if do_convert_rgb: 314 | images = [self.blend_rgba(image) for image in images] 315 | 316 | # All transformations expect numpy arrays. 
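        # With the constructor defaults above -- rescale_factor=0.0078 (roughly 1/127.5),
        # image_mean=[1.0, 1.0, 1.0] and image_std=[1.0, 1.0, 1.0] -- the rescale and
        # normalize steps below map uint8 pixels from [0, 255] to approximately [-1, 1]
        # (x * 0.0078 - 1.0), the same range produced for the VQ-VAE path in
        # ImageTokenizer.img_tokens_from_pil via `np_img * 2 - 1`.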
317 | images = [to_numpy_array(image) for image in images] 318 | 319 | if is_scaled_image(images[0]) and do_rescale: 320 | logger.warning_once( 321 | "It looks like you are trying to rescale already rescaled images. If the input" 322 | " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 323 | ) 324 | 325 | if input_data_format is None: 326 | # We assume that all images have the same channel dimension format. 327 | input_data_format = infer_channel_dimension_format(images[0]) 328 | 329 | if do_resize: 330 | images = [ 331 | self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) 332 | for image in images 333 | ] 334 | 335 | if do_center_crop: 336 | images = [ 337 | self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images 338 | ] 339 | 340 | if do_rescale: 341 | images = [ 342 | self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) for image in images 343 | ] 344 | 345 | if do_normalize: 346 | images = [ 347 | self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 348 | for image in images 349 | ] 350 | 351 | images = [ 352 | to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images 353 | ] 354 | 355 | data = {"pixel_values": images} 356 | return BatchFeature(data=data, tensor_type=return_tensors) 357 | 358 | def blend_rgba(self, image: ImageInput) -> ImageInput: 359 | """ 360 | Convert image to RGB by blending the transparency layer if it's in RGBA format. 361 | If image is not `PIL.Image`, it si simply returned without modifications. 362 | 363 | Args: 364 | image (`ImageInput`): 365 | Image to convert. 366 | """ 367 | 368 | if not isinstance(image, PIL.Image.Image): 369 | return image 370 | elif image.mode == "RGB": 371 | return image 372 | 373 | img_rgba = np.array(image.convert("RGBA")) 374 | 375 | # If there is no transparency layer, simple convert and return. 376 | if not (img_rgba[:, :, 3] < 255).any(): 377 | return image.convert("RGB") 378 | 379 | # There is a transparency layer, blend it with a white background. 380 | # Calculate the alpha proportion for blending. 381 | alpha = img_rgba[:, :, 3] / 255.0 382 | img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3] 383 | return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB") 384 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon/processing_chameleon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Processor class for Chameleon. 
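Example (illustrative sketch only: the objects and placeholder below are hypothetical,
and the tokenizer is assumed to define a `sep_token`, which `__call__` appends in chat
mode):

```python
>>> processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
>>> inputs = processor(
...     text=f"Remove the rain from this photo. {processor.image_token}",
...     images=rainy_image,
...     return_tensors="pt",
... )
>>> list(inputs.keys())  # input_ids, attention_mask, pixel_values (when images is not None)
```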
17 | """ 18 | 19 | from typing import List, Optional, Union 20 | 21 | from transformers.feature_extraction_utils import BatchFeature 22 | from transformers.image_utils import ImageInput 23 | from transformers.processing_utils import ProcessorMixin 24 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 25 | from transformers.utils import TensorType 26 | 27 | 28 | class ChameleonProcessor(ProcessorMixin): 29 | r""" 30 | Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single 31 | processor. 32 | 33 | [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`]. 34 | See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information. 35 | 36 | Args: 37 | image_processor ([`ChameleonImageProcessor`]): 38 | The image processor is a required input. 39 | tokenizer ([`LlamaTokenizerFast`]): 40 | The tokenizer is a required input. 41 | image_seq_length (`int`, *optional*, defaults to 1024): 42 | Sequence length of one image embedding. 43 | image_token (`str`, *optional*, defaults to `""`): 44 | The special token used to indicate image in the text. 45 | """ 46 | 47 | attributes = ["image_processor", "tokenizer"] 48 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") 49 | image_processor_class = "ChameleonImageProcessor" 50 | 51 | def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): 52 | self.image_seq_length = image_seq_length 53 | self.image_token = image_token 54 | self.image_start_token = "" # fixed tokens for start and end, so can hardcode 55 | self.image_end_token = "" 56 | super().__init__(image_processor, tokenizer) 57 | 58 | def __call__( 59 | self, 60 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 61 | images: ImageInput = None, 62 | padding: Union[bool, str, PaddingStrategy] = False, 63 | truncation: Union[bool, str, TruncationStrategy] = None, 64 | max_length: int = None, 65 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 66 | return_for_text_completion: bool = False, 67 | ) -> BatchFeature: 68 | """ 69 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 70 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode 71 | the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to 72 | CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 73 | of the above two methods for more information. 74 | 75 | Args: 76 | text (`str`, `List[str]`, `List[List[str]]`): 77 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 78 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 79 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 80 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 81 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 82 | tensor. Both channels-first and channels-last formats are supported. 
83 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): 84 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 85 | index) among: 86 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 87 | sequence if provided). 88 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 89 | acceptable input length for the model if that argument is not provided. 90 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 91 | lengths). 92 | max_length (`int`, *optional*): 93 | Maximum length of the returned list and optionally padding length (see above). 94 | truncation (`bool`, *optional*): 95 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 96 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 97 | If set, will return tensors of a particular framework. Acceptable values are: 98 | 99 | - `'tf'`: Return TensorFlow `tf.constant` objects. 100 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 101 | - `'np'`: Return NumPy `np.ndarray` objects. 102 | - `'jax'`: Return JAX `jnp.ndarray` objects. 103 | 104 | Returns: 105 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 106 | 107 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 108 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 109 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 110 | `None`). 111 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 112 | """ 113 | if isinstance(text, str): 114 | text = [text] 115 | elif not isinstance(text, list) and not isinstance(text[0], str): 116 | raise TypeError("Invalid input text. Please provide a string, or a list of strings") 117 | 118 | # Replace the image token with the expanded image token sequence 119 | prompt_strings = [] 120 | one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token 121 | for sample in text: 122 | sample = sample.replace(self.image_token, one_img_tokens) 123 | if not return_for_text_completion: 124 | sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode 125 | prompt_strings.append(sample) 126 | 127 | data = self.tokenizer( 128 | prompt_strings, 129 | return_tensors=return_tensors, 130 | padding=padding, 131 | truncation=truncation, 132 | max_length=max_length, 133 | ) 134 | 135 | if images is not None: 136 | pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] 137 | data["pixel_values"] = pixel_values 138 | 139 | return BatchFeature(data=data, tensor_type=return_tensors) 140 | 141 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama 142 | def batch_decode(self, *args, **kwargs): 143 | """ 144 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 145 | refer to the docstring of this method for more information. 
146 | """ 147 | return self.tokenizer.batch_decode(*args, **kwargs) 148 | 149 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama 150 | def decode(self, *args, **kwargs): 151 | """ 152 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 153 | the docstring of this method for more information. 154 | """ 155 | return self.tokenizer.decode(*args, **kwargs) 156 | 157 | @property 158 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names 159 | def model_input_names(self): 160 | tokenizer_input_names = self.tokenizer.model_input_names 161 | image_processor_input_names = self.image_processor.model_input_names 162 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 163 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon_vae_ori/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_tokenizer import ImageTokenizer 2 | from .vocab import VocabInfo, VocabTranslation 3 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon_vae_ori/image_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # 3 | # This source code is licensed under the Chameleon License found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import PIL 7 | from PIL import Image 8 | import numpy as np 9 | import torch 10 | import yaml 11 | 12 | from .vqgan import VQModel 13 | 14 | 15 | class ImageTokenizer: 16 | def __init__( 17 | self, 18 | cfg_path: str, 19 | ckpt_path: str, 20 | device: str | torch.device | None = None, 21 | ): 22 | with open(cfg_path) as f: 23 | config = yaml.safe_load(f) 24 | 25 | params = config["model"]["params"] 26 | if "lossconfig" in params: 27 | del params["lossconfig"] 28 | params["ckpt_path"] = ckpt_path 29 | 30 | self._vq_model = VQModel(**params) 31 | self._vq_model.eval() 32 | 33 | if device is None: 34 | devices = {p.device for p in self._vq_model.parameters()} 35 | assert len(devices) == 1 36 | device = devices.pop() 37 | else: 38 | self._vq_model.to(device) 39 | self._device = device 40 | 41 | dtypes = {p.dtype for p in self._vq_model.parameters()} 42 | assert len(dtypes) == 1 43 | self._dtype = dtypes.pop() 44 | 45 | def _whiten_transparency(self, img: PIL.Image) -> PIL.Image: 46 | # Check if it's already in RGB format. 47 | if img.mode == "RGB": 48 | return img 49 | 50 | vals_rgba = np.array(img.convert("RGBA")) 51 | 52 | # If there is no transparency layer, simple convert and return. 53 | if not (vals_rgba[:, :, 3] < 255).any(): 54 | return img.convert("RGB") 55 | 56 | # There is a transparency layer, blend it with a white background. 57 | 58 | # Calculate the alpha proportion for blending. 59 | alpha = vals_rgba[:, :, 3] / 255.0 60 | # Blend with white background. 61 | vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * vals_rgba[:, :, :3] 62 | return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB") 63 | 64 | # def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor: 65 | # # Resize with aspect ratio preservation. 
66 | # s = min(img.size) 67 | # scale = target_image_size / s 68 | # new_size = (round(scale * img.size[0]), round(scale * img.size[1])) 69 | # img = img.resize(new_size, PIL.Image.LANCZOS) 70 | # 71 | # # Center crop. 72 | # x0 = (img.width - target_image_size) // 2 73 | # y0 = (img.height - target_image_size) // 2 74 | # img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size)) 75 | # 76 | # # Convert to tensor. 77 | # np_img = np.array(img) / 255.0 # Normalize to [0, 1] 78 | # np_img = np_img * 2 - 1 # Scale to [-1, 1] 79 | # tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() # (Channels, Height, Width) format. 80 | # 81 | # # Add batch dimension. 82 | # return tensor_img.unsqueeze(0) 83 | 84 | def img_tokens_from_pil(self, img: PIL.Image) -> list[int]: 85 | img = self._whiten_transparency(img) 86 | # Convert to tensor. 87 | np_img = np.array(img) / 255.0 # Normalize to [0, 1] 88 | np_img = np_img * 2 - 1 # Scale to [-1, 1] 89 | img = torch.from_numpy(np_img).permute(2, 0, 1).to(self._vq_model.encoder.conv_in.weight) 90 | img = img.unsqueeze(0) 91 | 92 | _, _, [_, _, img_toks] = self._vq_model.encode(img) 93 | return img_toks 94 | 95 | def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image: 96 | # Ensure detachment and move tensor to CPU. 97 | detached_chw_tensor = chw_tensor.detach().cpu() 98 | 99 | # Normalize tensor to [0, 1] range from [-1, 1] range. 100 | normalized_chw_tensor = (torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0) / 2.0 101 | 102 | # Permute CHW tensor to HWC format and convert to NumPy array. 103 | hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy() 104 | 105 | # Convert to an 8-bit unsigned integer format. 106 | image_array_uint8 = (hwc_array * 255).astype(np.uint8) 107 | 108 | # Convert NumPy array to PIL Image. 109 | pil_image = Image.fromarray(image_array_uint8) 110 | 111 | # Convert image to RGB if it is not already. 112 | if pil_image.mode != "RGB": 113 | pil_image = pil_image.convert("RGB") 114 | 115 | return pil_image 116 | 117 | def pil_from_img_toks(self, tokens: torch.Tensor, h_latent_dim=32, w_latent_dim=32) -> PIL.Image: 118 | emb_dim = self._vq_model.quantize.embedding.weight.shape[-1] 119 | codebook_entry = self._vq_model.quantize.get_codebook_entry(tokens, (1, h_latent_dim, w_latent_dim, emb_dim)) 120 | pixels = self._vq_model.decode(codebook_entry) 121 | return self._pil_from_chw_tensor(pixels[0]) 122 | 123 | def latent_embedding_from_pil(self, img: PIL.Image): 124 | img = self._whiten_transparency(img) 125 | 126 | # Convert to tensor. 127 | np_img = np.array(img) / 255.0 # Normalize to [0, 1] 128 | np_img = np_img * 2 - 1 # Scale to [-1, 1] 129 | img = torch.from_numpy(np_img).permute(2, 0, 1) # (Channels, Height, Width) format. 130 | img = img.unsqueeze(0).to(self._vq_model.encoder.conv_in.weight) 131 | latent_embedding, _, _ = self._vq_model.encode(img) 132 | return latent_embedding 133 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/chameleon_vae_ori/vocab.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Chameleon License found in the 4 | # LICENSE file in the root directory of this source tree. 
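# Note on the token-naming convention assumed below: image-codebook entries in the
# Chameleon vocabulary have names starting with "IMGIMG", whose remaining letters encode
# decimal digits (A -> 0, B -> 1, ..., J -> 9), with the final character dropped by
# `VocabTranslation.bpe2img`.  For example, a hypothetical vocabulary entry named
# "IMGIMGBCZ" would be translated to codebook index int("12") == 12.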
5 | 6 | from functools import cached_property 7 | 8 | import torch 9 | 10 | 11 | class VocabInfo: 12 | def __init__(self, vocab_map: dict[str, int]): 13 | self.name2val = vocab_map 14 | 15 | self.bos_id = vocab_map.get("") 16 | self.eos_id = vocab_map.get("") 17 | self.boi_id = vocab_map.get("") 18 | self.eoi_id = vocab_map.get("") 19 | self.pad_id = vocab_map.get("") 20 | self.eot_id = vocab_map.get("") 21 | 22 | @property 23 | def begin_sequence(self) -> int: 24 | return self.bos_id 25 | 26 | @property 27 | def end_sequence(self) -> int: 28 | return self.eos_id 29 | 30 | @property 31 | def begin_image(self) -> int: 32 | return self.boi_id 33 | 34 | @property 35 | def end_image(self) -> int: 36 | return self.eoi_id 37 | 38 | @property 39 | def padding(self) -> int: 40 | return self.pad_id 41 | 42 | @property 43 | def end_turn(self) -> int: 44 | return self.eot_id 45 | 46 | @cached_property 47 | def val2name(self) -> dict[int, str]: 48 | return {v: k for k, v in self.name2val.items()} 49 | 50 | @cached_property 51 | def all_tokens(self) -> list[int]: 52 | return sorted(self.name2val.values()) 53 | 54 | @cached_property 55 | def image_tokens(self) -> list[int]: 56 | return sorted([val for name, val in self.name2val.items() if name.startswith("IMGIMG")]) 57 | 58 | @cached_property 59 | def special_tokens(self) -> list[int]: 60 | return sorted([val for name, val in self.name2val.items() if name.startswith("<") and name != "<"]) 61 | 62 | @cached_property 63 | def text_tokens(self) -> list[int]: 64 | return sorted(set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)) 65 | 66 | 67 | class VocabTranslation: 68 | def __init__(self, vocab_info: VocabInfo, device: str | None = None): 69 | self._vocab = vocab_info 70 | self._device = device 71 | 72 | @cached_property 73 | def bpe2img(self) -> dict[int, int]: 74 | img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} 75 | 76 | def remap(old_name: str) -> str: 77 | return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) 78 | 79 | return {tok: int(remap(self._vocab.val2name[tok])) for tok in self._vocab.image_tokens} 80 | 81 | @cached_property 82 | def img2bpe(self) -> dict[int, int]: 83 | return {v: k for k, v in self.bpe2img.items()} 84 | 85 | @cached_property 86 | def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: 87 | sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device) 88 | sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device) 89 | return sorted_bpe, sorted_img 90 | 91 | @cached_property 92 | def img2bpe_mapping_tensor(self) -> torch.LongTensor: 93 | mapping = torch.zeros( 94 | max(self.img2bpe.keys()) + 1, 95 | dtype=torch.int, 96 | device=self._device, 97 | ) 98 | for k, v in self.img2bpe.items(): 99 | mapping[k] = v 100 | return mapping 101 | 102 | def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor: 103 | bpe_tok, img_tok = self.bpe2img_search_tensors 104 | return img_tok[torch.searchsorted(bpe_tok, bpe_batch)] 105 | 106 | def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor: 107 | return self.img2bpe_mapping_tensor[img_batch] 108 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/configuration_xllmx_chameleon.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | from .chameleon import ChameleonConfig 5 | 6 | logger = 
logging.getLogger(__name__) 7 | 8 | 9 | class ChameleonXLLMXConfig(ChameleonConfig): 10 | 11 | def __init__( 12 | self, 13 | z_loss_weight: float = 0.0, 14 | **kwargs, 15 | ): 16 | self.z_loss_weight = z_loss_weight 17 | super().__init__( 18 | **kwargs, 19 | ) 20 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/model/modeling_xllmx_chameleon.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import math 4 | from typing import List 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from .chameleon import ChameleonForConditionalGeneration 10 | from .configuration_xllmx_chameleon import ChameleonXLLMXConfig 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) 15 | 16 | 17 | __all__ = ["ChameleonXLLMXForConditionalGeneration"] 18 | 19 | 20 | class ChameleonXLLMXForConditionalGeneration(ChameleonForConditionalGeneration): 21 | config_class = ChameleonXLLMXConfig 22 | 23 | def __init__(self, config): 24 | super().__init__(config) 25 | 26 | def forward(self, input_ids=None, labels=None, training=True, **kwargs): 27 | 28 | max_tokens = max([len(_) for _ in input_ids]) 29 | max_tokens = min(max_tokens, self.config.max_position_embeddings) 30 | input_ids = [_[:max_tokens] for _ in input_ids] 31 | labels = [_[:max_tokens] for _ in labels] 32 | 33 | input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids] 34 | input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device) 35 | 36 | labels = [label + [-100] * (max_tokens - len(label)) for label in labels] 37 | labels = torch.tensor(labels, dtype=torch.int64, device=self.device) 38 | 39 | # explicit use_cache=False for the following 40 | # https://github.com/Lightning-AI/pytorch-lightning/issues/19267 41 | result = ChameleonForConditionalGeneration.forward( 42 | self, input_ids=input_ids, labels=labels, use_cache=False, **kwargs 43 | ) 44 | 45 | c_loss = result[0] 46 | 47 | additional_loss_dict = {} 48 | if self.config.z_loss_weight > 0: 49 | logits: torch.Tensor = result[1] 50 | valid_mask = labels >= 0 51 | z_loss = torch.logsumexp(logits, dim=-1).pow(2)[valid_mask].mean() 52 | additional_loss_dict["z_loss"] = (z_loss, self.config.z_loss_weight) 53 | return c_loss, additional_loss_dict 54 | 55 | def get_fsdp_wrap_module_list(self) -> List: 56 | modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens] 57 | if hasattr(self.model, "vqmodel"): # may be deleted 58 | modules.append(self.model.vqmodel) 59 | return modules 60 | 61 | def get_checkpointing_wrap_module_list(self) -> List: 62 | modules = [ 63 | *list(self.model.layers), 64 | ] 65 | return modules 66 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/pre_tokenize/Adobe_5k_pre_tokenize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | from argparse import ArgumentParser 7 | import json 8 | import math 9 | import pickle 10 | import random 11 | 12 | from data.convertsation import Conversation 13 | from data.item_processor import FlexARItemProcessor 14 | 15 | 16 | class ItemProcessor(FlexARItemProcessor): 17 | def __init__( 18 | self, 19 | tokenizer="../Lumina-mGPT-7B-768-tokenizer", 20 | 
conv_template=Conversation, 21 | target_size=512, 22 | ): 23 | super().__init__(tokenizer, conv_template, target_size) 24 | print(self.crop_size_list) 25 | self.data_root = "/Dataset/Adobe_5k" #Path to Adobe_5k Daraset 26 | 27 | def process_item(self, raw_item, training_mode=False, out_flatten=True): 28 | edit_tasks = [] 29 | Image_A = os.path.join(self.data_root, raw_item["input_img_A"]) 30 | Image_B = os.path.join(self.data_root, raw_item["output_img_B"]) 31 | Image_C = os.path.join(self.data_root, raw_item["output_img_C"]) 32 | Image_D = os.path.join(self.data_root, raw_item["output_img_D"]) 33 | Image_E = os.path.join(self.data_root, raw_item["output_img_E"]) 34 | Image_F = os.path.join(self.data_root, raw_item["output_img_F"]) 35 | 36 | edit_task_1 = { 37 | "conversations": [ 38 | {"from": "human", raw_item['Task_Descriptions_from_A_to_B'] + "value": "<|image|>"}, 39 | {"from": "gpt", "value": "<|image|>"} 40 | ], 41 | "image": [Image_A, Image_B] 42 | } 43 | edit_tasks.append(edit_task_1) 44 | 45 | edit_task_2 = { 46 | "conversations": [ 47 | {"from": "human", raw_item['Task_Descriptions_from_B_to_A'] + "value": "<|image|>"}, 48 | {"from": "gpt", "value": "<|image|>"} 49 | ], 50 | "image": [Image_B, Image_A] 51 | } 52 | edit_tasks.append(edit_task_2) 53 | 54 | edit_task_3 = { 55 | "conversations": [ 56 | {"from": "human", raw_item['Task_Descriptions_from_A_to_C'] + "value": "<|image|>"}, 57 | {"from": "gpt", "value": "<|image|>"} 58 | ], 59 | "image": [Image_A, Image_C] 60 | } 61 | edit_tasks.append(edit_task_3) 62 | 63 | edit_task_4 = { 64 | "conversations": [ 65 | {"from": "human", raw_item['Task_Descriptions_from_C_to_A'] + "value": "<|image|>"}, 66 | {"from": "gpt", "value": "<|image|>"} 67 | ], 68 | "image": [Image_C, Image_A] 69 | } 70 | edit_tasks.append(edit_task_4) 71 | 72 | edit_task_5 = { 73 | "conversations": [ 74 | {"from": "human", raw_item['Task_Descriptions_from_A_to_D'] + "value": "<|image|>"}, 75 | {"from": "gpt", "value": "<|image|>"} 76 | ], 77 | "image": [Image_A, Image_D] 78 | } 79 | edit_tasks.append(edit_task_5) 80 | 81 | edit_task_6 = { 82 | "conversations": [ 83 | {"from": "human", raw_item['Task_Descriptions_from_D_to_A'] + "value": "<|image|>"}, 84 | {"from": "gpt", "value": "<|image|>"} 85 | ], 86 | "image": [Image_D, Image_A] 87 | } 88 | edit_tasks.append(edit_task_6) 89 | 90 | edit_task_7 = { 91 | "conversations": [ 92 | {"from": "human", raw_item['Task_Descriptions_from_A_to_E'] + "value": "<|image|>"}, 93 | {"from": "gpt", "value": "<|image|>"} 94 | ], 95 | "image": [Image_A, Image_E] 96 | } 97 | edit_tasks.append(edit_task_7) 98 | 99 | edit_task_8 = { 100 | "conversations": [ 101 | {"from": "human", raw_item['Task_Descriptions_from_E_to_A'] + "value": "<|image|>"}, 102 | {"from": "gpt", "value": "<|image|>"} 103 | ], 104 | "image": [Image_E, Image_A] 105 | } 106 | edit_tasks.append(edit_task_8) 107 | 108 | edit_task_9 = { 109 | "conversations": [ 110 | {"from": "human", raw_item['Task_Descriptions_from_A_to_F'] + "value": "<|image|>"}, 111 | {"from": "gpt", "value": "<|image|>"} 112 | ], 113 | "image": [Image_A, Image_F] 114 | } 115 | edit_tasks.append(edit_task_9) 116 | 117 | edit_task_10 = { 118 | "conversations": [ 119 | {"from": "human", raw_item['Task_Descriptions_from_F_to_A'] + "value": "<|image|>"}, 120 | {"from": "gpt", "value": "<|image|>"} 121 | ], 122 | "image": [Image_F, Image_A] 123 | } 124 | edit_tasks.append(edit_task_10) 125 | 126 | task_list = [] 127 | for task in edit_tasks: 128 | result = super(ItemProcessor, 
self).process_item(task, training_mode, out_flatten) 129 | task_list.append(result) 130 | 131 | return task_list 132 | 133 | if __name__ == "__main__": 134 | 135 | parser = ArgumentParser() 136 | parser.add_argument( 137 | "--splits", 138 | type=int, 139 | default=8, 140 | ) 141 | parser.add_argument( 142 | "--rank", 143 | type=int, 144 | default=0, 145 | ) 146 | parser.add_argument( 147 | "--in_filename", 148 | type=str, 149 | ) 150 | parser.add_argument( 151 | "--out_dir", 152 | type=str, 153 | ) 154 | parser.add_argument("--target_size", type=int, default=512) 155 | args = parser.parse_args() 156 | 157 | item_processor = ItemProcessor(target_size=args.target_size) 158 | 159 | with open(args.in_filename) as f: 160 | ori_contents = json.load(f) 161 | 162 | per_max_len = 10 163 | 164 | num = len(ori_contents) 165 | 166 | splits = args.splits 167 | rank = args.rank 168 | output_dir = args.out_dir 169 | save_dir = os.path.join(output_dir, "files") 170 | os.makedirs(save_dir, exist_ok=True) 171 | 172 | num_per_rank = math.ceil(num / splits) 173 | 174 | try: 175 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "r") as f: 176 | start_idx = int(f.read()) + 1 177 | print(f"resume from {start_idx}") 178 | except: 179 | start_idx = num_per_rank * rank 180 | print(f"start from {start_idx}") 181 | 182 | end_idx = min(num_per_rank * (rank + 1), len(ori_contents)) 183 | for i in range(start_idx, end_idx): 184 | if i % 10 == 0: 185 | print(f"{i}/{end_idx}") 186 | 187 | records = [] 188 | try: 189 | processed_items = item_processor.process_item(ori_contents[i], training_mode=True) 190 | 191 | for j, (tokens, labels) in enumerate(processed_items): 192 | pkl_path = os.path.join(save_dir, f"{i}_{j}.pkl") 193 | new_item = {"token": tokens, "label": labels, "id": i * per_max_len + j} 194 | with open(pkl_path, "wb") as f: 195 | pickle.dump(new_item, f) 196 | record = {"file": pkl_path, "len": len(tokens), "id": i * per_max_len + j} 197 | records.append(record) 198 | 199 | except Exception as e: 200 | from traceback import format_exc 201 | 202 | print(f"item {i} error: \n{ori_contents[i]}") 203 | print(format_exc()) 204 | 205 | if records: 206 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-record.jsonl"), "a") as f: 207 | for record in records: 208 | f.write(json.dumps(record) + "\n") 209 | 210 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "w") as f: 211 | if i == end_idx - 1: 212 | f.write("finished") 213 | else: 214 | f.write(f"{i}") 215 | 216 | 217 | ''' 218 | for i in {0..7} 219 | do 220 | export CUDA_VISIBLE_DEVICES=${i} 221 | python -u pre_tokenize/Adobe_5k_pre_tokenize.py \ 222 | --splits=8 \ 223 | --rank=${i} \ 224 | --in_filename /Dataset/Adobe_5k/Adobe_5k_data.json \ 225 | --out_dir ./json/edit_resolution_448/Adobe_5k/ \ 226 | --target_size 448 &> ${i}.log & 227 | done 228 | 229 | python -u pre_tokenize/concat_record.py --sub_record_dir ./json/edit_resolution_448/Adobe_5k/ --save_path ./json/edit_resolution_448/Adobe_5k/record.json 230 | ''' -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/pre_tokenize/allweather_pre_tokenize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | from argparse import ArgumentParser 7 | import json 8 | import math 9 | import pickle 10 | import random 11 | 12 | from data.convertsation import Conversation 13 | 
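# Expected input format: each entry of allweather_data.json (passed via --in_filename)
# must provide the keys consumed by `process_item` below, namely "source_img_A",
# "target_img_B", "Task_Descriptions_from_A_to_B" and "Task_Descriptions_from_B_to_A".
# A hypothetical item (illustrative paths and text only):
#   {
#     "source_img_A": "rainy/0001.png",
#     "target_img_B": "gt/0001.png",
#     "Task_Descriptions_from_A_to_B": "Remove the rain streaks and restore a clear scene.",
#     "Task_Descriptions_from_B_to_A": "Add dense rain streaks to the clear image."
#   }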
from data.item_processor import FlexARItemProcessor 14 | 15 | 16 | class ItemProcessor(FlexARItemProcessor): 17 | def __init__( 18 | self, 19 | tokenizer="../Lumina-mGPT-7B-768-tokenizer", 20 | conv_template=Conversation, 21 | target_size=512, 22 | ): 23 | super().__init__(tokenizer, conv_template, target_size) 24 | print(self.crop_size_list) 25 | self.data_root = "/Dataset/Allweather" #Path to Allweather Dataset 26 | 27 | def process_item(self, raw_item, training_mode=False, out_flatten=True): 28 | edit_tasks = [] 29 | Image_A = os.path.join(self.data_root, raw_item["source_img_A"]) 30 | Image_B = os.path.join(self.data_root, raw_item["target_img_B"]) 31 | 32 | edit_task_1 = { 33 | "conversations": [ 34 | {"from": "human", raw_item['Task_Descriptions_from_A_to_B'] + "value": "<|image|>"}, 35 | {"from": "gpt", "value": "<|image|>"} 36 | ], 37 | "image": [Image_A, Image_B] 38 | } 39 | edit_tasks.append(edit_task_1) 40 | 41 | edit_task_2 = { 42 | "conversations": [ 43 | {"from": "human", raw_item['Task_Descriptions_from_B_to_A'] + "value": "<|image|>"}, 44 | {"from": "gpt", "value": "<|image|>"} 45 | ], 46 | "image": [Image_B, Image_A] 47 | } 48 | edit_tasks.append(edit_task_2) 49 | 50 | task_list = [] 51 | for task in edit_tasks: 52 | # print(task) 53 | result = super(ItemProcessor, self).process_item(task, training_mode, out_flatten) 54 | task_list.append(result) 55 | 56 | return task_list 57 | 58 | if __name__ == "__main__": 59 | 60 | parser = ArgumentParser() 61 | parser.add_argument( 62 | "--splits", 63 | type=int, 64 | default=8, 65 | ) 66 | parser.add_argument( 67 | "--rank", 68 | type=int, 69 | default=0, 70 | ) 71 | parser.add_argument( 72 | "--in_filename", 73 | type=str, 74 | ) 75 | parser.add_argument( 76 | "--out_dir", 77 | type=str, 78 | ) 79 | parser.add_argument("--target_size", type=int, default=512) 80 | args = parser.parse_args() 81 | 82 | item_processor = ItemProcessor(target_size=args.target_size) 83 | 84 | with open(args.in_filename) as f: 85 | ori_contents = json.load(f) 86 | 87 | per_max_len = 2 88 | 89 | num = len(ori_contents) 90 | 91 | splits = args.splits 92 | rank = args.rank 93 | output_dir = args.out_dir 94 | save_dir = os.path.join(output_dir, "files") 95 | os.makedirs(save_dir, exist_ok=True) 96 | 97 | num_per_rank = math.ceil(num / splits) 98 | 99 | try: 100 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "r") as f: 101 | start_idx = int(f.read()) + 1 102 | print(f"resume from {start_idx}") 103 | except: 104 | start_idx = num_per_rank * rank 105 | print(f"start from {start_idx}") 106 | 107 | end_idx = min(num_per_rank * (rank + 1), len(ori_contents)) 108 | for i in range(start_idx, end_idx): 109 | if i % 10 == 0: 110 | print(f"{i}/{end_idx}") 111 | 112 | records = [] 113 | try: 114 | processed_items = item_processor.process_item(ori_contents[i], training_mode=True) 115 | 116 | for j, (tokens, labels) in enumerate(processed_items): 117 | pkl_path = os.path.join(save_dir, f"{i}_{j}.pkl") 118 | new_item = {"token": tokens, "label": labels, "id": i * per_max_len + j} 119 | with open(pkl_path, "wb") as f: 120 | pickle.dump(new_item, f) 121 | record = {"file": pkl_path, "len": len(tokens), "id": i * per_max_len + j} 122 | records.append(record) 123 | 124 | except Exception as e: 125 | from traceback import format_exc 126 | 127 | print(f"item {i} error: \n{ori_contents[i]}") 128 | print(format_exc()) 129 | 130 | if records: 131 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-record.jsonl"), "a") as f: 132 | for record in records: 
133 | f.write(json.dumps(record) + "\n") 134 | 135 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "w") as f: 136 | if i == end_idx - 1: 137 | f.write("finished") 138 | else: 139 | f.write(f"{i}") 140 | 141 | 142 | ''' 143 | for i in {0..7} 144 | do 145 | export CUDA_VISIBLE_DEVICES=${i} 146 | python -u pre_tokenize/allweather_pre_tokenize.py \ 147 | --splits=8 \ 148 | --rank=${i} \ 149 | --in_filename /Dataset/Allweather/allweather_data.json \ 150 | --out_dir ./json/edit_resolution_448/Allweather/ \ 151 | --target_size 448 &> ${i}.log & 152 | done 153 | 154 | python -u pre_tokenize/concat_record.py --sub_record_dir ./json/edit_resolution_448/Allweather --save_path ./json/edit_resolution_448/Allweather/record.json 155 | ''' -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/pre_tokenize/concat_record.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import json 3 | import os 4 | import re 5 | import warnings 6 | 7 | 8 | def find_sub_records(directory: str): 9 | pattern = re.compile(r"\d+-of-\d+-record\.json(l)?") 10 | 11 | sub_record_files = [f for f in os.listdir(directory) if pattern.match(f)] 12 | sorted_files = sorted(sub_record_files, key=lambda filename: int(filename.split("-of")[0])) 13 | return sorted_files 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = ArgumentParser() 18 | parser.add_argument( 19 | "--sub_record_dir", 20 | type=str, 21 | default=None, 22 | ) 23 | parser.add_argument( 24 | "--save_path", 25 | type=str, 26 | default=None, 27 | ) 28 | args = parser.parse_args() 29 | 30 | l_sub_records = find_sub_records(args.sub_record_dir) 31 | 32 | print(f"find {len(l_sub_records)} sub-records in {args.sub_record_dir}") 33 | print(str(l_sub_records) + "\n\n") 34 | 35 | complete_record = [] 36 | for sub_record in l_sub_records: 37 | with open(os.path.join(args.sub_record_dir, sub_record)) as f: 38 | lines = f.readlines() 39 | for i, l in enumerate(lines): 40 | try: 41 | l_item = json.loads(l) 42 | complete_record.append(l_item) 43 | except: 44 | if i == len(lines) - 1: 45 | print(f"{sub_record} seems still writing, skip last incomplete record") 46 | else: 47 | warnings.warn(f"read line failed: {l}") 48 | 49 | with open(args.save_path, "w") as f: 50 | json.dump(complete_record, f) 51 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/pre_tokenize/seed_multi_turn_pre_tokenize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | from argparse import ArgumentParser 7 | import json 8 | import math 9 | import pickle 10 | import random 11 | 12 | from data.convertsation import Conversation 13 | from data.item_processor import FlexARItemProcessor 14 | 15 | 16 | class ItemProcessor(FlexARItemProcessor): 17 | def __init__( 18 | self, 19 | tokenizer="../Lumina-mGPT-7B-768-tokenizer", 20 | conv_template=Conversation, 21 | target_size=512, 22 | ): 23 | super().__init__(tokenizer, conv_template, target_size) 24 | print(self.crop_size_list) 25 | self.data_root = "/Dataset/SEED/multi_turn_editing/data" #Path to SEED Dataset 26 | 27 | def process_item(self, raw_item, training_mode=False, out_flatten=True): 28 | def ensure_period_at_end(text): 29 | if not text.endswith('.'): 30 | return text + '.' 
31 | return text 32 | 33 | Image_Paths = [os.path.join(self.data_root, raw_item[f"edit_image_{i+1}"]) for i in range((len(raw_item) - 1)//3)] 34 | Image_Paths.insert(0, os.path.join(self.data_root, raw_item["source_image"])) 35 | 36 | instructions = [ensure_period_at_end(raw_item[f"instruction_{i+1}"]) for i in range((len(raw_item) - 1) // 3)] 37 | inverse_instructions = [ensure_period_at_end(raw_item[f"inverse_instruction_{i+1}"]) for i in range((len(raw_item) - 1) // 3)] 38 | 39 | edit_tasks = [] 40 | 41 | for i in range(len(instructions)): 42 | for j in range(i, len(instructions)): 43 | combined_instruction = ' '.join(instructions[i:j+1]) 44 | source_image = Image_Paths[i] 45 | target_image = Image_Paths[j+1] 46 | edit_task = { 47 | "conversations": [ 48 | {"from": "human", combined_instruction + "value": "<|image|>"}, 49 | {"from": "gpt", "value": "<|image|>"} 50 | ], 51 | "image": [source_image, target_image] 52 | } 53 | edit_tasks.append(edit_task) 54 | 55 | for i in range(len(inverse_instructions)): 56 | for j in range(i, len(inverse_instructions)): 57 | combined_inverse_instruction = ' '.join(inverse_instructions[j:i-1:-1] if i > 0 else inverse_instructions[j::-1]) 58 | source_image = Image_Paths[-(j+2)] 59 | target_image = Image_Paths[-(i+1)] 60 | edit_task = { 61 | "conversations": [ 62 | {"from": "human", combined_inverse_instruction + "value": "<|image|>"}, 63 | {"from": "gpt", "value": "<|image|>"} 64 | ], 65 | "image": [source_image, target_image] 66 | } 67 | edit_tasks.append(edit_task) 68 | 69 | task_list = [] 70 | for task in edit_tasks: 71 | # print(task) 72 | result = super(ItemProcessor, self).process_item(task, training_mode, out_flatten) 73 | task_list.append(result) 74 | 75 | return task_list 76 | 77 | if __name__ == "__main__": 78 | 79 | parser = ArgumentParser() 80 | parser.add_argument( 81 | "--splits", 82 | type=int, 83 | default=8, 84 | ) 85 | parser.add_argument( 86 | "--rank", 87 | type=int, 88 | default=0, 89 | ) 90 | parser.add_argument( 91 | "--in_filename", 92 | type=str, 93 | ) 94 | parser.add_argument( 95 | "--out_dir", 96 | type=str, 97 | ) 98 | parser.add_argument("--target_size", type=int, default=512) 99 | args = parser.parse_args() 100 | 101 | item_processor = ItemProcessor(target_size=args.target_size) 102 | 103 | with open(args.in_filename) as f: 104 | ori_contents = json.load(f) 105 | 106 | per_max_len = 0 107 | for i in range(len(ori_contents)): 108 | per_max_len = max(per_max_len, (len(ori_contents[i]) - 1) // 3) 109 | print(per_max_len) 110 | per_max_len = (per_max_len + 1) * per_max_len 111 | print(per_max_len) 112 | 113 | num = len(ori_contents) 114 | 115 | splits = args.splits 116 | rank = args.rank 117 | output_dir = args.out_dir 118 | save_dir = os.path.join(output_dir, "files") 119 | os.makedirs(save_dir, exist_ok=True) 120 | 121 | num_per_rank = math.ceil(num / splits) 122 | 123 | try: 124 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "r") as f: 125 | start_idx = int(f.read()) + 1 126 | print(f"resume from {start_idx}") 127 | except: 128 | start_idx = num_per_rank * rank 129 | print(f"start from {start_idx}") 130 | 131 | end_idx = min(num_per_rank * (rank + 1), len(ori_contents)) 132 | for i in range(start_idx, end_idx): 133 | if i % 10 == 0: 134 | print(f"{i}/{end_idx}") 135 | 136 | records = [] 137 | try: 138 | processed_items = item_processor.process_item(ori_contents[i], training_mode=True) 139 | 140 | for j, (tokens, labels) in enumerate(processed_items): 141 | pkl_path = os.path.join(save_dir, 
f"{i}_{j}.pkl") 142 | new_item = {"token": tokens, "label": labels, "id": i * per_max_len + j} 143 | with open(pkl_path, "wb") as f: 144 | pickle.dump(new_item, f) 145 | record = {"file": pkl_path, "len": len(tokens), "id": i * per_max_len + j} 146 | records.append(record) 147 | 148 | except Exception as e: 149 | from traceback import format_exc 150 | 151 | print(f"item {i} error: \n{ori_contents[i]}") 152 | print(format_exc()) 153 | 154 | if records: 155 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-record.jsonl"), "a") as f: 156 | for record in records: 157 | f.write(json.dumps(record) + "\n") 158 | 159 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "w") as f: 160 | if i == end_idx - 1: 161 | f.write("finished") 162 | else: 163 | f.write(f"{i}") 164 | 165 | 166 | ''' 167 | for i in {0..7} 168 | do 169 | export CUDA_VISIBLE_DEVICES=${i} 170 | python -u pre_tokenize/seed_multi_turn_pre_tokenize.py \ 171 | --splits=8 \ 172 | --rank=${i} \ 173 | --in_filename /Dataset/seed_multi_turn_editing_data.json \ 174 | --out_dir ./json/edit_resolution_448/SEED/Multi_Turn \ 175 | --target_size 448 &> ${i}.log & 176 | done 177 | 178 | python -u pre_tokenize/concat_record.py --sub_record_dir ./json/edit_resolution_448/SEED/Multi_Turn/ --save_path ./json/edit_resolution_448/SEED/Multi_Turn/record.json 179 | ''' 180 | 181 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/xllmx/__init__.py -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/xllmx/data/__init__.py -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/data/conversation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/xllmx/data/conversation/__init__.py -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/data/conversation/template.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class ConversationBase: 5 | roles = ["Human", "Assistant"] 6 | 7 | def __init__(self, messages=None): 8 | self.messages = messages or [] 9 | 10 | def process(self): 11 | raise NotImplementedError 12 | 13 | def get_prompt(self): 14 | return self.process()["conv"] 15 | 16 | def append_message(self, role, message): 17 | self.messages.append([role, message]) 18 | 19 | def copy(self): 20 | return ConversationBase( 21 | messages=[[x, y] for x, y in self.messages], 22 | ) 23 | 24 | def load_qas(self, qas: List[List[str]]): 25 | self.messages = [] 26 | for q, a in qas: 27 | self.append_message(self.roles[0], q) 28 | self.append_message(self.roles[1], a) 29 | -------------------------------------------------------------------------------- 
/Explanatory_Instructions_Tuning/xllmx/data/data_reader.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import logging 3 | import time 4 | from typing import Union 5 | 6 | from PIL import Image 7 | 8 | Image.MAX_IMAGE_PIXELS = None 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def read_general(path) -> Union[str, BytesIO]: 13 | if "s3://" in path: 14 | init_ceph_client_if_needed() 15 | file_bytes = BytesIO(client.get(path)) 16 | return file_bytes 17 | else: 18 | return path 19 | 20 | 21 | def init_ceph_client_if_needed(): 22 | global client 23 | if client is None: 24 | logger.info(f"initializing ceph client ...") 25 | st = time.time() 26 | from petrel_client.client import Client # noqa 27 | 28 | client = Client("/path/to/petreloss.conf") 29 | ed = time.time() 30 | logger.info(f"initialize client cost {ed - st:.2f} s") 31 | 32 | 33 | client = None 34 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/data/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import os 5 | from pathlib import Path 6 | import pickle 7 | from time import sleep 8 | import traceback 9 | import warnings 10 | 11 | import h5py 12 | import torch 13 | import torch.distributed as dist 14 | from torch.utils.data import Dataset 15 | import yaml 16 | 17 | from .item_processor import ItemProcessorBase 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class FinetuneConversationDataset(Dataset): 23 | def __init__(self, config_path, item_processor: ItemProcessorBase, cache_on_disk=False): 24 | 25 | self.item_processor = item_processor 26 | 27 | logger.info(f"read dataset config from {config_path}") 28 | with open(config_path, "r") as f: 29 | self.config = yaml.load(f, Loader=yaml.FullLoader) 30 | logger.info("DATASET CONFIG:") 31 | logger.info(self.config) 32 | 33 | self.cache_on_disk = cache_on_disk 34 | if self.cache_on_disk: 35 | cache_dir = self._get_cache_dir(config_path) 36 | if dist.get_rank() == 0: 37 | self._collect_annotations_and_save_to_cache(cache_dir) 38 | dist.barrier() 39 | self.meta_collection, self.annotations_collection = self._load_annotations_from_cache(cache_dir) 40 | else: 41 | cache_dir = None 42 | self.meta_collection, self.annotations_collection = self._collect_annotations() 43 | 44 | def __len__(self): 45 | return sum([_["len"] for _ in self.meta_collection]) 46 | 47 | def _collect_annotations(self): 48 | meta_collection = [] 49 | annotations_collection = [] 50 | 51 | for meta in self.config["META"]: 52 | meta, annotations = self._load_meta(meta) 53 | meta_collection.append(meta) 54 | annotations_collection.append(annotations) 55 | 56 | return meta_collection, annotations_collection 57 | 58 | def _load_meta(self, meta): 59 | if "type" not in meta: 60 | meta["type"] = "default" 61 | 62 | meta_path, meta_type = meta["path"], meta["type"] 63 | meta_ext = os.path.splitext(meta_path)[-1] 64 | if meta_ext == ".json": 65 | with open(meta_path) as f: 66 | annotations = json.load(f) 67 | elif meta_ext == ".jsonl": 68 | annotations = [] 69 | with open(meta_path) as f: 70 | for i, line in enumerate(f): 71 | try: 72 | annotations.append(json.loads(line)) 73 | except json.decoder.JSONDecodeError as e: 74 | logger.error(f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}") 75 | raise e 76 | elif meta_ext == ".pkl": 77 | with open(meta_path, "rb") 
as f: 78 | annotations = pickle.load(f) 79 | assert isinstance(annotations, list) 80 | elif meta_ext == ".pth": 81 | annotations = torch.load(meta_path) 82 | assert isinstance(annotations, list) 83 | else: 84 | raise NotImplementedError( 85 | f'Unknown meta file extension: "{meta_ext}". ' 86 | f"Currently, .json, .jsonl are supported. " 87 | "If you are using a supported format, please set the file extension so that the proper parsing " 88 | "routine can be called." 89 | ) 90 | logger.info(f"{meta_path}, type{meta_type}: len {len(annotations)}") 91 | 92 | meta["len"] = len(annotations) 93 | 94 | meta["item_len_list"] = [self.item_processor.predict_item_token_length(_) for _ in annotations] 95 | 96 | return meta, annotations 97 | 98 | def _collect_annotations_and_save_to_cache(self, cache_dir): 99 | if (Path(cache_dir) / "data.h5").exists() and (Path(cache_dir) / "ready").exists(): 100 | # off-the-shelf annotation cache exists 101 | warnings.warn( 102 | f"Use existing h5 data cache: {Path(cache_dir)}\n" 103 | f"Note: if the actual data defined by the data config has changed since your last run, " 104 | f"please delete the cache manually and re-run this experiment, or the data actually used " 105 | f"will not be updated" 106 | ) 107 | return 108 | 109 | Path(cache_dir).mkdir(parents=True, exist_ok=True) 110 | meta_collection, annotations_collection = self._collect_annotations() 111 | 112 | # when cache on disk, rank0 saves items to an h5 file 113 | logger.info(f"start to build data cache to: {Path(cache_dir)}") 114 | with h5py.File(Path(cache_dir) / "data.h5", "w") as file: 115 | dt = h5py.vlen_dtype(str) 116 | for i, annotations in enumerate(annotations_collection): 117 | serialized_ann = [json.dumps(_) for _ in annotations] 118 | h5_ann = file.create_dataset(f"ann{i}", (len(serialized_ann),), dtype=dt) 119 | h5_ann[:] = serialized_ann 120 | 121 | file.create_dataset("meta_collection", data=json.dumps(meta_collection)) 122 | with open(Path(cache_dir) / "ready", "w") as f: 123 | f.write("ready") 124 | logger.info(f"data cache built") 125 | 126 | @staticmethod 127 | def _get_cache_dir(config_path): 128 | config_identifier = config_path 129 | disallowed_chars = ["/", "\\", ".", "?", "!"] 130 | for _ in disallowed_chars: 131 | config_identifier = config_identifier.replace(_, "-") 132 | cache_dir = f"./xllmx_data_cache/{config_identifier}" 133 | return cache_dir 134 | 135 | @staticmethod 136 | def _load_annotations_from_cache(cache_dir): 137 | while not (Path(cache_dir) / "ready").exists(): 138 | # cache has not yet been completed by rank 0 139 | assert dist.get_rank() != 0 140 | sleep(1) 141 | cache_file = h5py.File(Path(cache_dir) / "data.h5", "r") 142 | meta_collection = json.loads(cache_file["meta_collection"].asstr()[()]) 143 | annotations_collection = [cache_file[f"ann{i}"] for i in range(len(meta_collection))] 144 | return meta_collection, annotations_collection 145 | 146 | def get_item_func(self, meta_idx, idx_in_meta): 147 | data_item = self.annotations_collection[meta_idx][idx_in_meta] 148 | if self.cache_on_disk: 149 | data_item = json.loads(data_item) 150 | else: 151 | data_item = copy.deepcopy(data_item) 152 | 153 | return self.item_processor.process_item(data_item, training_mode=True) 154 | 155 | def tie_index_to_meta(self, idx: int): 156 | # Initialize the starting index 157 | start_idx = 0 158 | 159 | # Iterate through the list of dictionaries 160 | for i, meta in enumerate(self.meta_collection): 161 | # Calculate the ending index for the current collection 162 | end_idx = 
start_idx + meta["len"] 163 | 164 | # Check if the given index falls within the current collection 165 | if start_idx <= idx < end_idx: 166 | # Calculate the new index within the current collection 167 | new_index = idx - start_idx 168 | return i, new_index 169 | 170 | # Update the starting index for the next collection 171 | start_idx = end_idx 172 | 173 | # If the index is out of range of all collections, raise an error 174 | raise IndexError("Index out of range") 175 | 176 | def __getitem__(self, index): 177 | meta_idx, idx_in_meta = self.tie_index_to_meta(index) 178 | 179 | try: 180 | return self.get_item_func(meta_idx, idx_in_meta) 181 | except Exception as e: 182 | logger.info( 183 | f"Item {index} errored, annotation:\n" 184 | f"{self.annotations_collection[meta_idx][idx_in_meta]}\n" 185 | f"Error:\n" 186 | f"{traceback.format_exc()}" 187 | ) 188 | if idx_in_meta != 0: 189 | return self[index - 1] 190 | else: 191 | return self[index + self.meta_collection[meta_idx]["len"] - 1] 192 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/data/item_processor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import copy 3 | import logging 4 | from typing import Any, Callable, Dict, List, Tuple, Union 5 | 6 | from xllmx.model.tokenizer import Tokenizer 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class LabelAllZeroError(Exception): 12 | def __init__(self, message=None): 13 | self.message = message 14 | 15 | def __str__(self): 16 | return f"LabelAllZeroError: {self.message}" 17 | 18 | 19 | class ItemProcessorBase(ABC): 20 | @abstractmethod 21 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]: 22 | raise NotImplementedError 23 | 24 | def predict_item_token_length(self, data_item: dict) -> int: 25 | """ 26 | estimate the token length of the data item for gathering items of similar lengths into a batch 27 | """ 28 | return 1 29 | 30 | 31 | class MMConvItemProcessor(ItemProcessorBase): 32 | def __init__( 33 | self, 34 | transform: Dict[str, Callable[[Any], Dict]], 35 | media_symbols: List[str], 36 | tokenizer: str | Tokenizer, 37 | conv_template, 38 | ): 39 | self.transform = transform 40 | logger.info(f"transform:\n{self.transform}") 41 | 42 | self.media_symbols = media_symbols 43 | logger.info(f"media_symbols:\n{self.media_symbols}") 44 | 45 | if isinstance(tokenizer, str): 46 | self.tokenizer = Tokenizer(model_path=tokenizer) 47 | else: 48 | self.tokenizer = copy.deepcopy(tokenizer) 49 | 50 | # todo should not already exist 51 | self.tokenizer.tokenizer.add_tokens(media_symbols) 52 | self.d_media_symbol2token = {} 53 | self.d_media_token2symbol = {} 54 | for media_symbol in media_symbols: 55 | tokenized_symbol = self.tokenizer.encode(media_symbol, bos=False, eos=False) 56 | assert len(tokenized_symbol) == 1 57 | self.d_media_symbol2token[media_symbol] = tokenized_symbol[0] 58 | self.d_media_token2symbol[tokenized_symbol[0]] = media_symbol 59 | 60 | # implicit_at_beginning means media without explict location specification are arranged right after bos token 61 | # if false, then these medias are arranged at the beginning of the first question 62 | self.implicit_at_beginning = False 63 | self.conv_template = conv_template 64 | 65 | def collect_and_process_media(self, data_item): 66 | """ 67 | this function receives a raw piece of data (e.g. 
read from `.json` data file), 68 | and returns d_media, containing the prepared media readily usable by model 69 | YOU MAY OVERRIDE THIS FUNCTION TO SUPPORT COMPLEX LOADING OF VARIOUS FORMS OF DATA 70 | """ 71 | d_media = {} 72 | for media_symbol in self.media_symbols: 73 | if media_symbol in data_item: 74 | l_media = data_item[media_symbol] # a list of media paths 75 | elif media_symbol.lstrip("<|").rstrip("|>") in data_item: 76 | l_media = data_item[media_symbol.lstrip("<|").rstrip("|>")] 77 | else: 78 | l_media = [] 79 | if not isinstance(l_media, list): # data with only one media, in format {"image": image_name, ...} 80 | l_media = [l_media] 81 | 82 | d_media[media_symbol] = [] 83 | for media in l_media: 84 | media = self.transform[media_symbol](media) 85 | assert isinstance(media, Dict) 86 | media["type"] = media_symbol 87 | d_media[media_symbol].append(media) 88 | 89 | return d_media 90 | 91 | def replace_media_token_with_media( 92 | self, tokens: List[int], labels: Union[List[int], None], d_media: Dict[str, List] 93 | ): 94 | d_media_counter = {key: 0 for key in d_media} 95 | for i, t in enumerate(tokens): 96 | if t in self.d_media_token2symbol: 97 | media_symbol = self.d_media_token2symbol[t] 98 | media = d_media[media_symbol][d_media_counter[media_symbol]] 99 | d_media_counter[media_symbol] += 1 100 | tokens[i] = media 101 | media["to_predict"] = labels[i] > 0 102 | 103 | assert all([d_media_counter[key] == len(d_media[key]) for key in d_media]) 104 | 105 | if labels is not None: 106 | return tokens, labels 107 | else: 108 | return tokens 109 | 110 | @staticmethod 111 | def insert_implicit_media_symbol_in_q1(conv_list: List[Dict], d_media: Dict): 112 | """ 113 | Add the media tokens to the beginning of the first instruction from 114 | human. This logic may be more reasonable. However, it is incompatible 115 | with old-version Accessory models, which are trained with image tokens 116 | inserted directly behind the first token (). 117 | :param conv_list: [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}, ...] 118 | :param d_media: a dict of media for all media types 119 | """ 120 | conv_list = copy.deepcopy(conv_list) 121 | 122 | for media_symbol, l_media in d_media.items(): 123 | media_symbol_count = "".join([_["value"] for _ in conv_list if _["value"] is not None]).count(media_symbol) 124 | if media_symbol_count > 0: 125 | assert media_symbol_count == len( 126 | l_media 127 | ), f"{media_symbol_count} {media_symbol} exists in text, but {len(l_media)} actual media are given" 128 | else: 129 | conv_list[0]["value"] = (media_symbol + " ") * len(l_media) + conv_list[0]["value"] 130 | 131 | return conv_list 132 | 133 | @staticmethod 134 | def insert_implicit_media_symbol_at_beginning(conv: str, d_media: Dict): 135 | """ 136 | Legacy versions of LLaMA2-Accessory handled media in a non-interleaved 137 | manner, where image tokens are inserted directly behind the first token, 138 | namely . To support interleaved media comprehension and generation, 139 | Accessory now supports the explicit specification of media occurrence, 140 | which is achieved by adding media symbols, e.g. , within the 141 | conversations. On the other hand, for media without explicit 142 | specification, this function realizes the legacy behavior to arrange 143 | them at the beginning of the conversation. 
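For example, with the "<|image|>" media symbol used elsewhere in this repo: if d_media maps "<|image|>" to two media items and conv contains no "<|image|>" occurrence, the returned conversation is "<|image|> <|image|> " followed by the original conv.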
144 | :param conv: conversation 145 | :param d_media: a dict of media for all media types, for determining how 146 | many media tokens need to be inserted 147 | """ 148 | conv = copy.deepcopy(conv) 149 | 150 | for media_symbol, l_media in d_media.items(): 151 | media_symbol_count = conv.count(media_symbol) 152 | if media_symbol_count > 0: 153 | assert media_symbol_count == len( 154 | l_media 155 | ), f"{media_symbol_count} {media_symbol} exists in text, but {len(l_media)} actual media are given" 156 | else: 157 | conv = (media_symbol + " ") * len(l_media) + conv 158 | 159 | return conv 160 | 161 | def preprocess_item(self, data_item): 162 | return data_item 163 | 164 | def add_speaker_and_signal(self, source: List): 165 | """ 166 | Given source instruction and response pieces, return the text containing the complete conversation, 167 | and the list of values that the model should learn to predict during training 168 | :param source: [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}, ...] 169 | :return: `conversation`: string containing the complete conversation; 170 | `to_predict_list`: the list of values that the model should learn to predict during training 171 | """ 172 | conv = self.conv_template() 173 | 174 | for i, sentence in enumerate(source): 175 | from_str = sentence["from"] 176 | if i % 2 == 0: 177 | assert from_str.lower() in ["human"] 178 | role = conv.roles[0] 179 | elif i % 2 == 1: 180 | assert from_str.lower() in ["gpt", "assistant"] 181 | role = conv.roles[1] 182 | else: 183 | raise ValueError(f"unknown dialog role: {from_str.lower()}") 184 | 185 | value = sentence["value"] 186 | 187 | conv.append_message(role, value) 188 | 189 | processed = conv.process() 190 | conversation, pieces = processed["conv"], processed["pieces"] 191 | 192 | return conversation, pieces 193 | 194 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]: 195 | data_item = self.preprocess_item(data_item) 196 | 197 | d_media = self.collect_and_process_media(data_item) 198 | 199 | source = data_item["conversations"] 200 | 201 | # implicit_at_beginning means media without explict location specification are arranged right after bos token 202 | # if false, then these medias are arranged at the beginning of the first question 203 | if not self.implicit_at_beginning: 204 | source = self.insert_implicit_media_symbol_in_q1(source, d_media) 205 | 206 | conversation, pieces = self.add_speaker_and_signal(source) 207 | 208 | if self.implicit_at_beginning: 209 | conversation = self.insert_implicit_media_symbol_at_beginning(conversation, d_media) 210 | 211 | # dialog does not need eos 212 | tokens = self.tokenizer.encode(conversation, bos=True, eos=False) 213 | labels = [-100 for _ in tokens] 214 | 215 | # check special token num as expected 216 | for media_symbol, l_media in d_media.items(): 217 | media_token = self.d_media_symbol2token[media_symbol] 218 | media_token_count = tokens.count(media_token) 219 | assert media_token_count == len(l_media), ( 220 | f"{media_token_count} {media_token} (for {media_symbol}) exists in tokenized conversation, " 221 | f"but {len(l_media)} actual media are given" 222 | ) 223 | 224 | check_pos = 0 225 | for i, p in enumerate(pieces): 226 | if i == 0: 227 | tokenized_value = self.tokenizer.encode(p["data"], bos=True, eos=False) 228 | else: 229 | tokenized_value = self.tokenizer.encode_wo_prefix_space(p["data"]) 230 | 231 | assert ( 232 | tokens[check_pos : check_pos + len(tokenized_value)] == tokenized_value 233 | ), "inconsistent 
complete conversation and corresponding piece after tokenization" 234 | 235 | if p["predict"]: 236 | labels[check_pos : check_pos + len(tokenized_value)] = tokenized_value 237 | 238 | check_pos = check_pos + len(tokenized_value) 239 | 240 | if training_mode and all([_ <= 0 for _ in labels]): # nothing to predict 241 | raise LabelAllZeroError() 242 | 243 | # labels will be processed later by the model 244 | tokens, labels = self.replace_media_token_with_media(tokens, labels, d_media) 245 | 246 | assert len(tokens) == len(labels) 247 | 248 | if training_mode: 249 | return tokens, labels 250 | else: 251 | return tokens 252 | 253 | def predict_item_token_length(self, data_item: dict) -> int: 254 | """ 255 | estimate the length of each item 256 | """ 257 | 258 | if "conversations" in data_item: 259 | return sum([len(_["value"]) for _ in data_item["conversations"]]) 260 | else: 261 | return 1 262 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/data/sampler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | from typing import Iterator, List, Optional 4 | 5 | import numpy as np 6 | from torch.utils.data import Sampler 7 | 8 | from xllmx.data.dataset import FinetuneConversationDataset 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | # todo too slow to be used 14 | def mild_shuffle(items: List, shuffle_factor, engine: np.random.Generator): 15 | """ 16 | Perform a mild shuffle on the list of items. 17 | 18 | Args: 19 | engine: random engine 20 | items (list): The list of items to shuffle. 21 | shuffle_factor (float): max swap range is computed as len(item) * shuffle_factor. 22 | 23 | Returns: 24 | list: The mildly shuffled list. 25 | """ 26 | 27 | n = len(items) 28 | swap_range = int(shuffle_factor * n) 29 | shuffled_items = [None for _ in items] 30 | cache = list(range(swap_range)) 31 | for i in range(n): 32 | if i + swap_range < n: 33 | cache.append(i + swap_range) 34 | if len(cache) == 0 or cache[0] != i: # already swapped 35 | assert shuffled_items[i] is not None 36 | continue 37 | else: 38 | cache = cache[1:] 39 | if len(cache) == 0: 40 | shuffled_items[i] = items[i] 41 | else: 42 | cache_idx = engine.integers(low=0, high=len(cache)) 43 | j = cache[cache_idx] 44 | del cache[cache_idx] 45 | shuffled_items[i], shuffled_items[j] = items[j], items[i] 46 | 47 | return shuffled_items 48 | 49 | 50 | class FinetuneDistSampler(Sampler): 51 | def __init__( 52 | self, 53 | dataset: FinetuneConversationDataset, 54 | num_replicas: Optional[int] = None, 55 | rank: Optional[int] = None, 56 | shuffle: bool = True, 57 | seed: int = 0, 58 | batch_size=None, 59 | acc_grad=1, 60 | length_clustering=True, 61 | allow_mixed_task_among_acc=False, 62 | ): 63 | """ 64 | Distributed Sampler ensuring data in a batch are of the same type (e.g. 
text, image-text) 65 | :param dataset: 66 | :param num_replicas: 67 | :param rank: 68 | :param shuffle: 69 | :param seed: 70 | :param batch_size: 71 | :param acc_grad: 72 | :param length_clustering: 73 | :param allow_mixed_task_among_acc: 74 | """ 75 | # super().__init__() 76 | 77 | if num_replicas is None or rank is None or rank >= num_replicas or rank < 0: 78 | raise ValueError(f"Invalid num_replicas ({num_replicas}) or rank ({rank})") 79 | assert batch_size is not None 80 | 81 | self.dataset = dataset 82 | self.num_replicas = num_replicas 83 | self.rank = rank 84 | self.shuffle = shuffle 85 | self.seed = seed 86 | self.batch_size = batch_size 87 | self.acc_grad = acc_grad 88 | self.length_clustering = length_clustering 89 | self.allow_mixed_task_among_acc = allow_mixed_task_among_acc 90 | 91 | self.epoch = 0 92 | self.start_iter = 0 93 | 94 | global_bsz_acc = batch_size * num_replicas * acc_grad 95 | 96 | group_len = defaultdict(int) 97 | for i, meta in enumerate(dataset.meta_collection): 98 | group_len[meta["type"]] += int(meta["len"] * meta.get("ratio", 1.0)) 99 | 100 | group_len = {key: val // global_bsz_acc * global_bsz_acc for key, val in group_len.items()} 101 | 102 | self.total_size = sum(list(group_len.values())) 103 | assert self.total_size % num_replicas == 0 104 | self.num_samples = self.total_size // num_replicas 105 | 106 | def __iter__(self) -> Iterator: 107 | global_batch_size = self.batch_size * self.num_replicas 108 | global_bsz_acc = self.batch_size * self.num_replicas * self.acc_grad 109 | rng = np.random.default_rng(self.seed + self.epoch) 110 | 111 | group_indices_and_len = defaultdict(list) 112 | 113 | # Initialize the starting index 114 | start_idx = 0 115 | 116 | # Iterate through the list of dictionaries 117 | for i, meta in enumerate(self.dataset.meta_collection): 118 | # Calculate the ending index for the current collection 119 | end_idx = start_idx + meta["len"] 120 | indices = list(range(start_idx, end_idx)) 121 | assert len(indices) == len(meta["item_len_list"]) 122 | indices_and_len = [[idx, length] for idx, length in zip(indices, meta["item_len_list"])] 123 | if meta.get("ratio", 1.0) != 1.0: 124 | indices_and_len = list(rng.choice(indices_and_len, int(meta["len"] * meta["ratio"]), replace=False)) 125 | logger.info(f"meta{i}: sample (ratio = {meta['ratio']}) {len(indices_and_len)} items") 126 | group_indices_and_len[meta["type"]].extend(indices_and_len) 127 | 128 | # Update the starting index for the next collection 129 | start_idx = end_idx 130 | 131 | for group_name, indices_and_len in group_indices_and_len.items(): 132 | group_indices_and_len[group_name] = indices_and_len[ 133 | : len(indices_and_len) // global_bsz_acc * global_bsz_acc 134 | ] 135 | 136 | if self.shuffle: 137 | group_indices = {} 138 | if self.length_clustering: 139 | for group_name, indices_and_len in group_indices_and_len.items(): 140 | indices_and_len.sort(key=lambda x: x[1]) 141 | group_indices[group_name] = [_[0] for _ in indices_and_len] 142 | 143 | # option1: shuffle among neighboring items 144 | for group_name, indices in group_indices.items(): 145 | result = [] 146 | for pos in range(0, len(indices), global_batch_size * 500): 147 | sublist = indices[pos : pos + global_batch_size * 500] 148 | rng.shuffle(sublist) 149 | result.extend(sublist) 150 | group_indices[group_name] = result 151 | # option2: mild shuffle 152 | # group_indices[group_name] = mild_shuffle(indices, 0.1, rng) 153 | # option3: do nothing 154 | # pass 155 | else: 156 | for group_name, indices_and_len in 
group_indices_and_len.items(): 157 | rng.shuffle(indices_and_len) 158 | group_indices[group_name] = [_[0] for _ in indices_and_len] 159 | 160 | del group_indices_and_len 161 | 162 | if self.allow_mixed_task_among_acc: 163 | global_batched_indices = [ 164 | indices[i : i + global_batch_size] 165 | for group_name, indices in group_indices.items() 166 | for i in range(0, len(indices), global_batch_size) 167 | ] 168 | else: 169 | global_batched_indices = [] 170 | for group_name, indices in group_indices.items(): 171 | group_batched_indices = [ 172 | indices[i : i + global_batch_size] for i in range(0, len(indices), global_batch_size) 173 | ] 174 | rng.shuffle(group_batched_indices) 175 | group_batched_indices = [ 176 | sum(group_batched_indices[i : i + self.acc_grad], start=[]) 177 | for i in range(0, len(group_batched_indices), self.acc_grad) 178 | ] 179 | global_batched_indices.extend(group_batched_indices) 180 | rng.shuffle(global_batched_indices) 181 | indices = [_ for batch_indices in global_batched_indices for _ in batch_indices] 182 | else: 183 | raise NotImplementedError() 184 | 185 | assert len(indices) == self.total_size 186 | 187 | own_indices = [] 188 | for start_pos in range(self.rank * self.batch_size, len(indices), self.num_replicas * self.batch_size): 189 | own_indices += indices[start_pos : start_pos + self.batch_size] 190 | # subsample 191 | assert len(own_indices) == self.num_samples 192 | 193 | if self.start_iter * self.batch_size > len(own_indices): 194 | own_indices = [] 195 | else: 196 | own_indices = own_indices[self.start_iter * self.batch_size :] 197 | 198 | return iter(own_indices) 199 | 200 | def __len__(self) -> int: 201 | return self.num_samples 202 | 203 | def set_epoch(self, epoch: int, start_iter: int = 0) -> None: 204 | r""" 205 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 206 | use a different random ordering for each epoch. Otherwise, the next iteration of this 207 | sampler will yield the same ordering. 208 | 209 | Args: 210 | epoch (int): Epoch number. 211 | start_iter (int): start iter number. 212 | """ 213 | self.epoch = epoch 214 | self.start_iter = start_iter 215 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/xllmx/model/__init__.py -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/model/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 
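Note: forward computes output = x * rsqrt(mean(x ** 2, dim=-1, keepdim=True) + eps) * weight, i.e. RMS normalization over the last dimension followed by the learned elementwise scale (see _norm and forward below).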
23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/model/tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | from typing import List, Optional 5 | 6 | from sentencepiece import SentencePieceProcessor 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ["Tokenizer", "probe_tokenizer_path_from_pretrained"] 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Tokenizer: 16 | def __init__(self, model_path: str): 17 | """ 18 | Create a tokenizer, with inner implementation either spm or HF transformers tokenzier 19 | :param model_path: 20 | - when using spm tokenizer, should be path to a sentencepiece model with suffix `.model` 21 | - when using huggingface transformers tokenizer, should be an HF model repo or a local directory, 22 | containing tokenizer.json and tokenizer_config.json. 23 | """ 24 | if model_path.endswith(".model"): # spm tokenizer 25 | self.tokenizer_type = "spm" 26 | # reload tokenizer 27 | assert os.path.isfile(model_path), model_path 28 | self.tokenizer = SentencePieceProcessor(model_file=model_path) 29 | logger.info(f"Reloaded SentencePiece model from {model_path}") 30 | 31 | # BOS / EOS token IDs 32 | self.bos_id: int = self.tokenizer.bos_id() 33 | self.eos_id: int = self.tokenizer.eos_id() 34 | assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size() 35 | else: 36 | self.tokenizer_type = "transformers" 37 | print(model_path) 38 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 39 | logger.info(f"load HF transformers tokenizer from {model_path}") 40 | # BOS / EOS token IDs 41 | self.bos_id: int = self.tokenizer.bos_token_id 42 | if self.bos_id is None: 43 | self.bos_id = self.tokenizer.eos_token_id 44 | self.eos_id: int = self.tokenizer.eos_token_id 45 | assert self.eos_id is not None 46 | 47 | self._probe_tokenizer_style() 48 | 49 | logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") 50 | 51 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 52 | assert type(s) is str 53 | if self.tokenizer_type == "transformers": 54 | t = self.tokenizer.encode(s, truncation=False, add_special_tokens=False) 55 | else: 56 | t = self.tokenizer.encode(s) 57 | if bos: 58 | t = [self.bos_id] + t 59 | if eos: 60 | t = t + [self.eos_id] 61 | return t 62 | 63 | def encode_segment(self, s: str): 64 | s = s.lstrip(" ") 65 | if self.need_space_before_segment: 66 | return self.encode(" " + s, bos=False, eos=False) 67 | else: 68 | return self.encode(s, bos=False, eos=False) 69 | 70 | def encode_wo_prefix_space(self, s: str): 71 | if self.need_space_before_segment: 72 | return self.encode(s, bos=False, eos=False) 
73 | else: 74 | # prefix chars that, when preceding other strings without separator in between, 75 | # are relatively more likely to be tokenized independently rather than getting 76 | # merged into the following strings. 77 | l_prefix = ["@", "\n", "\\", "=", ">", "`"] 78 | for prefix in l_prefix: 79 | prefix_tokens = self.encode(prefix, bos=False, eos=False) 80 | cat_tokens = self.encode(prefix + s, bos=False, eos=False) 81 | if cat_tokens[: len(prefix_tokens)] == prefix_tokens: 82 | return cat_tokens[len(prefix_tokens) :] 83 | 84 | raise NotImplementedError( 85 | f"All prefixes are merged into {s} during tokenization, " 86 | f"this is weird behavior, please open an issue to report this problem", 87 | ) 88 | 89 | def _probe_tokenizer_style(self): 90 | """ 91 | Given a sentence, e.g. "Hi my darling", some tokenizers (e.g. LLaMA's) will show the following behavior: 92 | >>> # leading characters will be treated as if there were an " " in the beginning 93 | >>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode("my darling") 94 | >>> # leading space " " is redundant and should not be added 95 | >>> tokenizer.encode("Hi my darling") != tokenizer.encode("Hi") + tokenizer.encode(" my darling") 96 | However, some others (e.g. InternLM's) will behave differently: 97 | >>> # leading space " " has to be explicitly added 98 | >>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode(" my darling") 99 | Knowing which style the tokenizer takes is necessary when tokenizing a segment cut from the complete 100 | text, so that the result is the same as the corresponding part in the tokenized original text. 101 | """ 102 | sentence1 = self.encode("Hi my darling", bos=False, eos=False) 103 | sentence2 = self.encode("my darling", bos=False, eos=False) 104 | if sentence1[-len(sentence2) :] == sentence2: 105 | self.need_space_before_segment = False 106 | else: 107 | sentence3 = self.encode(" my darling", bos=False, eos=False) 108 | assert sentence1[-len(sentence3) :] == sentence3 109 | self.need_space_before_segment = True 110 | 111 | def decode(self, t: List[int]) -> str: 112 | return self.tokenizer.decode(t) 113 | 114 | def save(self, save_dir: str): 115 | if self.tokenizer_type == "transformers": 116 | self.tokenizer.save_pretrained(save_dir) 117 | else: 118 | with open(Path(save_dir) / "tokenizer.model", "wb") as f: 119 | f.write(self.tokenizer.serialized_model_proto()) 120 | 121 | @property 122 | def n_words(self): 123 | if self.tokenizer_type == "spm": 124 | return self.tokenizer.vocab_size() 125 | elif self.tokenizer_type == "transformers": 126 | return len(self.tokenizer) 127 | else: 128 | raise RuntimeError 129 | 130 | 131 | def probe_tokenizer_path_from_pretrained(pretrained_path: str): 132 | tokenizer_path = None 133 | 134 | # try to find spm-style tokenizer 135 | logger.info(f"trying to find sentencepiece-style tokenizer at {Path(pretrained_path) / 'tokenizer.model'}") 136 | if (Path(pretrained_path) / "tokenizer.model").exists(): 137 | logger.info(f"Found {Path(pretrained_path) / 'tokenizer.model'}, use it.") 138 | tokenizer_path = str(Path(pretrained_path) / "tokenizer.model") 139 | else: 140 | logger.info("Not Found") 141 | 142 | # then try huggingface style 143 | if tokenizer_path is None: 144 | logger.info( 145 | f"trying to find huggingface-style tokenizer at " 146 | f"{Path(pretrained_path) / '(tokenizer.json, tokenizer_config.json)'}" 147 | ) 148 | if (Path(pretrained_path) / "tokenizer.json").exists() and ( 149 | Path(pretrained_path) /
"tokenizer_config.json" 150 | ).exists(): 151 | logger.info(f"Found {Path(pretrained_path) / '(tokenizer.json, tokenizer_config.json)'}, use them.") 152 | tokenizer_path = pretrained_path 153 | else: 154 | logger.info("Not Found") 155 | if tokenizer_path is None: 156 | logger.info("No usable tokenizer found") 157 | return tokenizer_path 158 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/solvers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/Explanatory_Instructions_Tuning/xllmx/solvers/__init__.py -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/solvers/finetune/__init__.py: -------------------------------------------------------------------------------- 1 | from .finetune import FinetuneSolverBase 2 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/util/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ckpt, dist 2 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/util/ckpt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | from typing import Dict, Optional 6 | 7 | import torch 8 | from torch import distributed as dist 9 | from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, StateDictType 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def split_ckpt_str_into_epoch_iter(ckpt_str: str): 15 | # divide ckpt directory names into epoch and iter parts 16 | parts = ckpt_str.split("-") 17 | epoch = int(parts[0].replace("epoch", "")) 18 | if len(parts) == 2: 19 | iter_part = int(parts[1].replace("iter", "")) 20 | else: 21 | iter_part = None 22 | return epoch, iter_part 23 | 24 | 25 | def remove_early_ckpts(out_dir, max_keep=2): 26 | 27 | if max_keep <= 0: 28 | return 29 | 30 | def ckpt_sort_key(s): 31 | # divide ckpt directory names into epoch and iter parts 32 | epoch, iteration = split_ckpt_str_into_epoch_iter(s) 33 | if iteration is None: 34 | iteration = float("inf") 35 | return epoch, iteration 36 | 37 | existing_checkpoints = [_ for _ in os.listdir(out_dir) if "epoch" in _] 38 | existing_checkpoints = sorted(existing_checkpoints, key=ckpt_sort_key, reverse=True) 39 | 40 | for dir_to_remove in existing_checkpoints[max_keep:]: 41 | dir_to_remove = os.path.join(out_dir, dir_to_remove) 42 | shutil.rmtree(dir_to_remove) 43 | logger.info(f"Deleted {dir_to_remove}") 44 | 45 | 46 | def save( 47 | output_dir, 48 | is_main_process, 49 | model: FSDP, 50 | optimizer: Optional[torch.optim.Optimizer] = None, 51 | tokenizer=None, 52 | args=None, 53 | epoch=None, 54 | iteration=None, 55 | additional_rank_common: Optional[Dict] = None, 56 | additional_rank_specific: Optional[Dict] = None, 57 | max_keep=2, 58 | ): 59 | save_name = f"epoch{epoch}" 60 | if iteration is not None: 61 | save_name += f"-iter{iteration}" 62 | save_dir = os.path.join(output_dir, save_name) 63 | 64 | os.makedirs(save_dir, exist_ok=True) 65 | 66 | # save model 67 | with FSDP.state_dict_type( 68 | model, 69 | StateDictType.FULL_STATE_DICT, 70 | FullStateDictConfig(rank0_only=True, 
offload_to_cpu=True), 71 | ): 72 | # run saving in separate functions to save memory 73 | def _save_model(): 74 | save_dtype = { 75 | "fp16": torch.float16, 76 | "bf16": torch.bfloat16, 77 | "tf32": torch.float, 78 | }[ 79 | args.precision 80 | ] # todo make saving precision optional 81 | if getattr(args, "only_save_trainable", False): 82 | model_trainable_params = model.get_trainable_params() 83 | model_trainable_params = [ 84 | ".".join([_ for _ in key.split(".") if not _.startswith("_")]) 85 | for key in model_trainable_params.keys() 86 | ] 87 | consolidated_model_state_dict = { 88 | key: val.to(save_dtype) for key, val in model.state_dict().items() if key in model_trainable_params 89 | } 90 | else: 91 | consolidated_model_state_dict = {key: val.to(save_dtype) for key, val in model.state_dict().items()} 92 | 93 | if is_main_process: 94 | model.save_pretrained(save_dir, state_dict=consolidated_model_state_dict) 95 | 96 | _save_model() 97 | logger.info("model saved") 98 | 99 | # save optimizer 100 | if optimizer is not None: 101 | with FSDP.state_dict_type( 102 | model, 103 | StateDictType.LOCAL_STATE_DICT, 104 | ): 105 | opt_path = os.path.join( 106 | save_dir, 107 | f"optimizer.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth", 108 | ) 109 | torch.save(optimizer.state_dict(), opt_path) 110 | logger.info("optimizer saved") 111 | else: 112 | logger.info("optimizer is None, skip saving") 113 | 114 | if additional_rank_specific is not None: 115 | torch.save( 116 | additional_rank_specific, 117 | os.path.join(save_dir, f"additional.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth"), 118 | ) 119 | logger.info(f"additional_rank_specific {list(additional_rank_specific.keys())} saved") 120 | 121 | if not is_main_process: 122 | dist.barrier() 123 | return 124 | 125 | # =========The followings are for main process only========= 126 | if tokenizer is not None: 127 | tokenizer.save(save_dir) 128 | logger.info("tokenizer saved") 129 | else: 130 | logger.info("tokenizer is None, skip saving") 131 | 132 | if args is not None: 133 | with open(os.path.join(save_dir, "args.json"), "w") as f: 134 | json.dump(vars(args), f, indent=2) 135 | logger.info("args saved") 136 | else: 137 | logger.info("args is None, skip saving") 138 | 139 | if additional_rank_common is not None: 140 | torch.save(additional_rank_common, os.path.join(save_dir, "additional_rank_common.pth")) 141 | logger.info(f"additional_resources {list(additional_rank_common.keys())} saved") 142 | 143 | remove_early_ckpts(output_dir, max_keep=max_keep) 144 | 145 | dist.barrier() 146 | return 147 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/util/dist.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import socket 5 | import subprocess 6 | import time 7 | from types import SimpleNamespace 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | from xllmx.util.misc import random_seed 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def find_free_port(start_port: int, end_port: int): 18 | """ 19 | Find a free port within the specified range. 
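Each candidate port in [start_port, end_port) is probed by binding a TCP socket; the first port that binds successfully is returned. For example, init_distributed_mode below falls back to find_free_port(9000, 10000) when launching a single-process run.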
20 | """ 21 | for port in range(start_port, end_port): 22 | try: 23 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 24 | s.bind(("", port)) # Try to bind to the port 25 | s.close() # Close the socket if successful 26 | return port 27 | except OSError as e: 28 | # print(f"Port {port} is in use, trying next port.") 29 | continue 30 | raise RuntimeError(f"No free ports found in range {start_port}-{end_port}") 31 | 32 | 33 | def init_distributed_mode(args=SimpleNamespace()): 34 | random_seed(getattr(args, "seed", 0)) 35 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ: 36 | args.world_size = int(os.environ["WORLD_SIZE"]) 37 | args.rank = int(os.environ["RANK"]) 38 | args.gpu = int(os.environ["LOCAL_RANK"]) 39 | args.local_rank = args.gpu 40 | args.dist_url = "env://" 41 | elif "SLURM_PROCID" in os.environ: 42 | os.environ["MASTER_PORT"] = "8966" 43 | while "MASTER_ADDR" not in os.environ or len(os.environ["MASTER_ADDR"].strip()) == 0: 44 | os.environ["MASTER_ADDR"] = ( 45 | subprocess.check_output( 46 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 47 | shell=True, 48 | ) 49 | .decode() 50 | .strip() 51 | ) 52 | time.sleep(1) 53 | print(os.environ["MASTER_ADDR"]) 54 | args.world_size = int(os.environ["SLURM_NPROCS"]) 55 | args.rank = int(os.environ["SLURM_PROCID"]) 56 | args.gpu = args.rank % torch.cuda.device_count() 57 | args.local_rank = args.gpu 58 | args.dist_url = "env://" 59 | os.environ["LOCAL_RANK"] = str(args.gpu) 60 | os.environ["WORLD_SIZE"] = str(args.world_size) 61 | os.environ["RANK"] = str(args.rank) 62 | else: 63 | os.environ["MASTER_ADDR"] = "127.0.0.1" 64 | os.environ["MASTER_PORT"] = str(find_free_port(9000, 10000)) 65 | os.environ["RANK"] = "0" 66 | os.environ["LOCAL_RANK"] = "0" 67 | os.environ["WORLD_SIZE"] = "1" 68 | args.rank = 0 69 | args.gpu = args.local_rank = 0 70 | args.world_size = 1 71 | args.dist_url = "env://" 72 | 73 | args.distributed = True 74 | 75 | torch.cuda.set_device(args.gpu) 76 | args.dist_backend = "nccl" 77 | print("| distributed init (rank {}): {}, gpu {}".format(args.rank, args.dist_url, args.gpu), flush=True) 78 | torch.distributed.init_process_group( 79 | backend=args.dist_backend, 80 | init_method=args.dist_url, 81 | world_size=args.world_size, 82 | rank=args.rank, 83 | timeout=datetime.timedelta(seconds=2 * 60 * 60), 84 | ) 85 | torch.distributed.barrier() 86 | 87 | 88 | def all_reduce_mean(x, group=None): 89 | world_size = dist.get_world_size(group=group) 90 | if world_size > 1: 91 | if isinstance(x, torch.Tensor): 92 | x_reduce = x.clone().cuda() 93 | else: 94 | x_reduce = torch.tensor(x).cuda() 95 | dist.all_reduce(x_reduce, group=group) 96 | x_reduce /= world_size 97 | return x_reduce.item() 98 | else: 99 | return x 100 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, it, args): 5 | """Decay the learning rate with half-cycle cosine after warmup""" 6 | if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps 7 | lr = args.lr * it / args.warmup_iters 8 | elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate 9 | lr = args.min_lr 10 | else: # 3) in between, use cosine decay down to min learning rate 11 | decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters) 12 
| assert 0 <= decay_ratio <= 1 13 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 14 | lr = args.min_lr + (args.lr - args.min_lr) * coeff 15 | 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | 23 | 24 | def adjust_learning_rate_epoch(optimizer, epoch, args): 25 | """Decay the learning rate with half-cycle cosine after warmup""" 26 | if epoch < args.warmup_epochs: 27 | lr = args.lr * epoch / args.warmup_epochs 28 | else: 29 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 30 | 1.0 + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)) 31 | ) 32 | for param_group in optimizer.param_groups: 33 | if "lr_scale" in param_group: 34 | param_group["lr"] = lr * param_group["lr_scale"] 35 | else: 36 | param_group["lr"] = lr 37 | return lr 38 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/util/misc.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | import datetime 3 | import logging 4 | import random 5 | import time 6 | 7 | from fairscale.nn.model_parallel import initialize as fs_init 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def random_seed(seed=0): 16 | random.seed(seed) 17 | torch.random.manual_seed(seed) 18 | np.random.seed(seed) 19 | 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window_size=1000, fmt=None): 27 | if fmt is None: 28 | fmt = "{avg:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window_size) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 
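Only count and total (and therefore global_avg) are all-reduced across ranks; window statistics such as median, avg, and max remain rank-local.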
42 | """ 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value 75 | ) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, **kwargs): 84 | for k, v in kwargs.items(): 85 | if v is None: 86 | continue 87 | elif isinstance(v, (torch.Tensor, float, int)): 88 | self.meters[k].update(v.item() if isinstance(v, torch.Tensor) else v) 89 | elif isinstance(v, list): 90 | for i, sub_v in enumerate(v): 91 | self.meters[f"{k}_{i}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v) 92 | elif isinstance(v, dict): 93 | for sub_key, sub_v in v.items(): 94 | self.meters[f"{k}_{sub_key}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v) 95 | else: 96 | raise TypeError(f"Unsupported type {type(v)} for metric {k}") 97 | 98 | def __str__(self): 99 | loss_str = [] 100 | for name, meter in self.meters.items(): 101 | loss_str.append("{}: {}".format(name, str(meter))) 102 | return self.delimiter.join(loss_str) 103 | 104 | def synchronize_between_processes(self): 105 | for meter in self.meters.values(): 106 | meter.synchronize_between_processes() 107 | 108 | def add_meter(self, name, meter): 109 | self.meters[name] = meter 110 | 111 | def log_every(self, iterable, print_freq, header=None, start_iter=0, samples_per_iter=None): 112 | i = start_iter 113 | if not header: 114 | header = "" 115 | start_time = time.time() 116 | end = time.time() 117 | iter_time = SmoothedValue(fmt="{avg:.4f}") 118 | data_time = SmoothedValue(fmt="{avg:.4f}") 119 | log_msg = [header, "[{0" + "}/{1}]", "{meters}", "time: {time}", "data: {data}"] 120 | if samples_per_iter is not None: 121 | log_msg.append("samples/sec: {samples_per_sec:.2f}") 122 | if torch.cuda.is_available(): 123 | log_msg.append("max mem: {memory:.0f}") 124 | log_msg = self.delimiter.join(log_msg) 125 | MB = 1024.0 * 1024.0 126 | for obj in iterable: 127 | data_time.update(time.time() - end) 128 | yield obj 129 | iter_time.update(time.time() - end) 130 | if i % print_freq == 0: 131 | try: 132 | total_len = len(iterable) 133 | except: 134 | total_len = "unknown" 135 | 136 | msg_kwargs = { 137 | "meters": str(self), 138 | "time": str(iter_time), 139 | "data": str(data_time), 140 | } 141 | if samples_per_iter is not None: 142 | msg_kwargs["samples_per_sec"] = samples_per_iter / iter_time.avg 143 | if torch.cuda.is_available(): 144 | msg_kwargs["memory"] = torch.cuda.max_memory_allocated() / MB 145 | 146 | logger.info(log_msg.format(i, total_len, **msg_kwargs)) 147 | i += 1 148 | end = time.time() 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | logger.info("{} Total time: 
{}".format(header, total_time_str)) 152 | 153 | 154 | def add_weight_decay(model, lr, weight_decay=1e-5): 155 | decay = [] 156 | no_decay = [] 157 | for name, param in model.named_parameters(): 158 | if not param.requires_grad: 159 | continue # frozen weights 160 | if name.endswith(".bias") or name.endswith("norm.weight"): 161 | no_decay.append(param) 162 | else: 163 | decay.append(param) 164 | return [ 165 | {"params": no_decay, "lr": lr, "weight_decay": weight_decay}, 166 | {"params": decay, "lr": lr, "weight_decay": weight_decay}, 167 | ] 168 | 169 | 170 | def broadcast_nonmp_parameters(model): 171 | if fs_init.get_model_parallel_world_size() == 1: 172 | return 173 | logger.info("starting broadcast non-model-parallel parameters within model parallel group") 174 | memo = set() 175 | modules = model.named_modules(prefix="", remove_duplicate=True) 176 | for module_prefix, module in modules: 177 | members = dict(module._parameters.items()) 178 | for k, v in members.items(): 179 | name = module_prefix + ("." if module_prefix else "") + k 180 | if v is None or v in memo: 181 | continue 182 | if getattr(v, "model_parallel", False): 183 | logger.info(f"ignore: {name}") 184 | continue 185 | memo.add(v) 186 | dist.broadcast(v, src=fs_init.get_model_parallel_src_rank(), group=fs_init.get_model_parallel_group()) 187 | logger.info("braodcast done") 188 | 189 | 190 | def mark_mp_params(model: torch.nn.Module): 191 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear 192 | 193 | for m in model.modules(): 194 | if isinstance(m, ColumnParallelLinear): 195 | m.weight.model_parallel = True 196 | if m.bias is not None: 197 | m.bias.model_parallel = True 198 | 199 | if isinstance(m, RowParallelLinear): 200 | m.weight.model_parallel = True 201 | 202 | if isinstance(m, ParallelEmbedding): 203 | m.weight.model_parallel = True 204 | 205 | 206 | def print_param_status(model: torch.nn.Module) -> None: 207 | require_grad_set = [] 208 | no_grad_set = [] 209 | for name, param in model.named_parameters(): 210 | if param.requires_grad: 211 | require_grad_set.append((name, param)) 212 | else: 213 | no_grad_set.append((name, param)) 214 | 215 | logger.info("Params that require gradient:\n") 216 | for name, param in require_grad_set: 217 | model_parallel = getattr(param, "model_parallel", False) 218 | logger.info( 219 | f"Param {name}: requires_grad {param.requires_grad}, local_size {param.shape}, model_parallel {model_parallel}, dtype {param.dtype}" 220 | ) 221 | 222 | logger.info("\nParams that do not require gradient:\n") 223 | for name, param in no_grad_set: 224 | model_parallel = getattr(param, "model_parallel", False) 225 | logger.info( 226 | f"Param {name}: requires_grad {param.requires_grad}, local_size {param.shape}, model_parallel {model_parallel}, dtype {param.dtype}" 227 | ) 228 | -------------------------------------------------------------------------------- /Explanatory_Instructions_Tuning/xllmx/util/tensor_type.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def promote_param_to_fp32(param: nn.Parameter) -> None: 6 | if param.is_floating_point() and torch.finfo(param.dtype).bits < 32: 7 | param.data = param.data.float() 8 | if param.is_complex() and torch.finfo(param.dtype).bits < 32: 9 | param.data = param.data.to(torch.complex64) 10 | -------------------------------------------------------------------------------- /assets/0016_0.8_0.08.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0016_0.8_0.08.jpg -------------------------------------------------------------------------------- /assets/0016_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0016_gt.jpg -------------------------------------------------------------------------------- /assets/0016_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0016_output.jpg -------------------------------------------------------------------------------- /assets/0017_0.9_0.08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0017_0.9_0.08.jpg -------------------------------------------------------------------------------- /assets/0017_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0017_gt.jpg -------------------------------------------------------------------------------- /assets/0017_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0017_output.jpg -------------------------------------------------------------------------------- /assets/0018_0.9_0.2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0018_0.9_0.2.jpg -------------------------------------------------------------------------------- /assets/0018_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0018_gt.jpg -------------------------------------------------------------------------------- /assets/0018_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/0018_output.jpg -------------------------------------------------------------------------------- /assets/canny_edge_source_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/canny_edge_source_1.png -------------------------------------------------------------------------------- /assets/deblur_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deblur_input.jpg -------------------------------------------------------------------------------- /assets/deblur_output_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deblur_output_1.jpg -------------------------------------------------------------------------------- /assets/deblur_output_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deblur_output_2.jpg -------------------------------------------------------------------------------- /assets/deblur_output_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deblur_output_3.jpg -------------------------------------------------------------------------------- /assets/deblur_output_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deblur_output_4.jpg -------------------------------------------------------------------------------- /assets/dehazing_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/dehazing_input.jpg -------------------------------------------------------------------------------- /assets/dehazing_output_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/dehazing_output_1.jpg -------------------------------------------------------------------------------- /assets/dehazing_output_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/dehazing_output_2.jpg -------------------------------------------------------------------------------- /assets/dehazing_output_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/dehazing_output_3.jpg -------------------------------------------------------------------------------- /assets/dehazing_output_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/dehazing_output_4.jpg -------------------------------------------------------------------------------- /assets/depth_gt_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/depth_gt_15.png -------------------------------------------------------------------------------- /assets/depth_gt_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/depth_gt_18.png 
-------------------------------------------------------------------------------- /assets/depth_gt_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/depth_gt_21.png -------------------------------------------------------------------------------- /assets/depth_output_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/depth_output_15.jpg -------------------------------------------------------------------------------- /assets/depth_output_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/depth_output_18.jpg -------------------------------------------------------------------------------- /assets/depth_output_21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/depth_output_21.jpg -------------------------------------------------------------------------------- /assets/deraining_gt_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_gt_5.jpg -------------------------------------------------------------------------------- /assets/deraining_gt_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_gt_6.jpg -------------------------------------------------------------------------------- /assets/deraining_gt_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_gt_9.jpg -------------------------------------------------------------------------------- /assets/deraining_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_input.jpg -------------------------------------------------------------------------------- /assets/deraining_output_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_1.jpg -------------------------------------------------------------------------------- /assets/deraining_output_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_2.jpg -------------------------------------------------------------------------------- /assets/deraining_output_3.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_3.jpg -------------------------------------------------------------------------------- /assets/deraining_output_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_4.jpg -------------------------------------------------------------------------------- /assets/deraining_output_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_5.jpg -------------------------------------------------------------------------------- /assets/deraining_output_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_6.jpg -------------------------------------------------------------------------------- /assets/deraining_output_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/deraining_output_9.jpg -------------------------------------------------------------------------------- /assets/desnow_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/desnow_input.jpg -------------------------------------------------------------------------------- /assets/desnow_output_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/desnow_output_1.jpg -------------------------------------------------------------------------------- /assets/desnow_output_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/desnow_output_2.jpg -------------------------------------------------------------------------------- /assets/desnow_output_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/desnow_output_3.jpg -------------------------------------------------------------------------------- /assets/desnow_output_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/desnow_output_4.jpg -------------------------------------------------------------------------------- /assets/hed_gt_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_gt_10.png -------------------------------------------------------------------------------- /assets/hed_gt_14.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_gt_14.png -------------------------------------------------------------------------------- /assets/hed_gt_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_gt_6.png -------------------------------------------------------------------------------- /assets/hed_input_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_input_10.jpg -------------------------------------------------------------------------------- /assets/hed_input_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_input_14.jpg -------------------------------------------------------------------------------- /assets/hed_input_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_input_6.png -------------------------------------------------------------------------------- /assets/hed_output_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_output_10.png -------------------------------------------------------------------------------- /assets/hed_output_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_output_14.png -------------------------------------------------------------------------------- /assets/hed_output_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/hed_output_6.png -------------------------------------------------------------------------------- /assets/norain-5x2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/norain-5x2.jpg -------------------------------------------------------------------------------- /assets/norain-6x2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/norain-6x2.jpg -------------------------------------------------------------------------------- /assets/norain-9x2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/norain-9x2.jpg -------------------------------------------------------------------------------- /assets/rgb_00015.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/rgb_00015.jpg -------------------------------------------------------------------------------- /assets/rgb_00018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/rgb_00018.jpg -------------------------------------------------------------------------------- /assets/rgb_00021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/rgb_00021.jpg -------------------------------------------------------------------------------- /assets/seg_gt_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_gt_1.jpg -------------------------------------------------------------------------------- /assets/seg_gt_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_gt_2.jpg -------------------------------------------------------------------------------- /assets/seg_gt_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_gt_3.jpg -------------------------------------------------------------------------------- /assets/seg_input_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_input_1.jpg -------------------------------------------------------------------------------- /assets/seg_input_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_input_2.jpg -------------------------------------------------------------------------------- /assets/seg_input_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_input_3.jpg -------------------------------------------------------------------------------- /assets/seg_output_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_output_1.jpg -------------------------------------------------------------------------------- /assets/seg_output_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_output_2.jpg -------------------------------------------------------------------------------- /assets/seg_output_3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/seg_output_3.jpg -------------------------------------------------------------------------------- /assets/surface_gt_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/surface_gt_15.jpg -------------------------------------------------------------------------------- /assets/surface_gt_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/surface_gt_18.jpg -------------------------------------------------------------------------------- /assets/surface_gt_21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/surface_gt_21.jpg -------------------------------------------------------------------------------- /assets/surface_output_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/surface_output_15.jpg -------------------------------------------------------------------------------- /assets/surface_output_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/surface_output_18.jpg -------------------------------------------------------------------------------- /assets/surface_output_21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/surface_output_21.jpg -------------------------------------------------------------------------------- /assets/zero_shot_canny_edge_source_1_out_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zero_shot_canny_edge_source_1_out_1.jpg -------------------------------------------------------------------------------- /assets/zero_shot_canny_edge_source_1_out_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zero_shot_canny_edge_source_1_out_2.jpg -------------------------------------------------------------------------------- /assets/zero_shot_canny_edge_source_1_out_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zero_shot_canny_edge_source_1_out_3.jpg -------------------------------------------------------------------------------- /assets/zero_shot_canny_edge_source_1_out_4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zero_shot_canny_edge_source_1_out_4.jpg -------------------------------------------------------------------------------- /assets/zs_low_light_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zs_low_light_1.jpg -------------------------------------------------------------------------------- /assets/zs_low_light_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zs_low_light_2.jpg -------------------------------------------------------------------------------- /assets/zs_low_light_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zs_low_light_3.jpg -------------------------------------------------------------------------------- /assets/zs_low_light_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zs_low_light_4.jpg -------------------------------------------------------------------------------- /assets/zs_low_light_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/assets/zs_low_light_input.jpg -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aassxun/Understanding-Vision-Tasks/9acea841632f33a71ebc68cecbd1a13db9c21ba2/paper.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | tensorboard 3 | fairscale 4 | sentencepiece 5 | gradio==4.19.0 6 | packaging 7 | transformers>=4.43.3 8 | pyyaml 9 | pathlib 10 | Ninja 11 | bitsandbytes 12 | httpx[socks] 13 | einops 14 | regex 15 | h5py 16 | accelerate 17 | pre-commit 18 | --------------------------------------------------------------------------------
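
Usage sketch for xllmx/util/misc.py (illustrative only). The snippet below exercises MetricLogger and add_weight_decay outside the provided fine-tuning solver. It assumes the Explanatory_Instructions_Tuning directory is on PYTHONPATH so that xllmx is importable, it runs in a single process and therefore never calls synchronize_between_processes() (which issues torch.distributed collectives and requires an initialized process group), and it substitutes a toy model and random batches for the Chameleon fine-tuning data.

import logging
from collections import OrderedDict

import torch
import torch.nn as nn

from xllmx.util.misc import MetricLogger, add_weight_decay

logging.basicConfig(level=logging.INFO)  # log_every() reports through logger.info

# Toy model whose parameter names exercise the ".bias" / "norm.weight" grouping rule.
model = nn.Sequential(OrderedDict([
    ("proj_in", nn.Linear(16, 32)),
    ("norm", nn.LayerNorm(32)),
    ("proj_out", nn.Linear(32, 4)),
]))

# add_weight_decay() splits biases and norm weights into their own parameter group.
param_groups = add_weight_decay(model, lr=1e-4, weight_decay=0.1)
optimizer = torch.optim.AdamW(param_groups)

metric_logger = MetricLogger(delimiter="  ")
loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(20)]

for x, y in metric_logger.log_every(loader, print_freq=5, header="Epoch: [0]"):
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update() accepts scalars, tensors, lists, and dicts; tensors are converted via .item().
    metric_logger.update(loss=loss, lr=optimizer.param_groups[0]["lr"])

print("Averaged stats:", metric_logger)

In a multi-process run, MetricLogger.synchronize_between_processes() additionally all-reduces each meter's count and total across ranks (see SmoothedValue.synchronize_between_processes above), so it should only be called once torch.distributed has been initialized.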