├── Pigeon ├── models │ ├── openai_clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── simple_tokenizer.py │ │ ├── clip.py │ │ └── model.py │ ├── transform.py │ ├── trainer │ │ ├── trainer_callback.py │ │ └── optimization.py │ ├── __init__.py │ ├── mutual_mask_generator.py │ ├── modeling_decoder.py │ ├── lavit_for_understanding.py │ └── modeling_visual_tokenzier.py ├── run_sft.sh ├── run_dpo.sh ├── run_inf.sh ├── data_utils.py ├── finetune_sft.py ├── data_utils_dpo.py └── finetune_dpo.py ├── figures └── overview.png ├── Evaluation ├── run_cal_scores.sh ├── run_eval_pigeon.sh ├── select_scores.ipynb ├── select_scores_dpo.ipynb └── cal_scores.py ├── env ├── env_dpo.yml └── env_sft.yml └── README.md /Pigeon/models/openai_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .model import * -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiyanXu/Pigeon/HEAD/figures/overview.png -------------------------------------------------------------------------------- /Pigeon/models/openai_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiyanXu/Pigeon/HEAD/Pigeon/models/openai_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /Evaluation/run_cal_scores.sh: -------------------------------------------------------------------------------- 1 | LOG_DIR="./log/" 2 | 3 | for scenario in sticker 4 | do 5 | for mode in test 6 | do 7 | for ckpt in 375 8 | do 9 | for scale_llama in 1.0 10 | do 11 | for scale_dm in 7.0 12 | do 13 | echo "scenario:$scenario, mode:$mode, ckpt:$ckpt, scale_llama:$scale_llama, scale_dm:$scale_dm," 14 | LOG_FILE="$LOG_DIR/cal_scores_DPO_${scenario}_${mode}_ckpt${ckpt}_scale_llama${scale_llama}_scale_dm${scale_dm}.txt" 15 | CUDA_VISIBLE_DEVICES=3 python cal_scores.py \ 16 | --output_dir /checkpoints/sticker/DPO/ \ 17 | --data_path /datasets/SER-30K/processed_seq/ \ 18 | --img_folder_path /datasets/SER-30K/Images/ \ 19 | --batch_size 50 \ 20 | --mode $mode \ 21 | --ckpt $ckpt \ 22 | --scale_for_llm $scale_llama \ 23 | --scale_for_dm $scale_dm \ 24 | --seed 123 \ 25 | > "$LOG_FILE" 2>&1 26 | done 27 | done 28 | done 29 | done 30 | done 31 | 32 | # nohup sh run_cal_scores.sh >log_cal_scores.txt 2>&1 & 33 | -------------------------------------------------------------------------------- /Pigeon/models/transform.py: -------------------------------------------------------------------------------- 1 | import re 2 | from torchvision import transforms 3 | from torchvision.transforms.functional import InterpolationMode 4 | 5 | 6 | class LaVITImageProcessor: 7 | def __init__(self, image_size=224): 8 | mean = (0.48145466, 0.4578275, 0.40821073) 9 | std = (0.26862954, 0.26130258, 0.27577711) 10 | 11 | transform_list = [ 12 | transforms.Resize( 13 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 14 | ), 15 | transforms.ToTensor(), 16 | transforms.Normalize(mean, std), 17 | ] 18 | 19 | self.transform = transforms.Compose(transform_list) 20 | 21 | def __call__(self, item): 22 | return self.transform(item) 23 | 24 | 25 | class LaVITQuestionProcessor: 26 | """ 27 | Adapting from BLIP2, for processing the question in VQA tasks 28 | """ 29 | def __init__(self, 
max_words=50): 30 | self.max_words = max_words 31 | 32 | def __call__(self, question): 33 | return self.pre_question(question) 34 | 35 | def pre_question(self, question): 36 | question = re.sub( 37 | r"([.!\"()*#:;~])", 38 | "", 39 | question.lower(), 40 | ) 41 | question = question.rstrip(" ") 42 | 43 | # truncate question 44 | question_words = question.split(" ") 45 | if len(question_words) > self.max_words: 46 | question = " ".join(question_words[: self.max_words]) 47 | 48 | return question -------------------------------------------------------------------------------- /Pigeon/run_sft.sh: -------------------------------------------------------------------------------- 1 | LOG_DIR="./log" 2 | 3 | for seed in 123 4 | do 5 | for lr in 1e-5 6 | do 7 | for dropout in 0.05 8 | do 9 | echo "lr: $lr, dropout: $dropout , seed: $seed," 10 | LOG_FILE="$LOG_DIR/finetune_sft_seed${seed}_lr${lr}_dropout${dropout}.txt" 11 | CUDA_VISIBLE_DEVICES=4 python finetune_sft.py \ 12 | --model_path /path/to/LaVIT-7B-v2/ \ 13 | --model_dtype bf16 \ 14 | --output_dir /checkpoints/sticker/SFT/ \ 15 | --use_xformers \ 16 | --load_in_8bit \ 17 | --pixel_decoding highres \ 18 | --mask_type mutual \ 19 | --num_heads 4 \ 20 | --num_layers 1 \ 21 | --drop_prob 0.2 \ 22 | --hist_mask_ratio 0.2 \ 23 | --scenario sticker \ 24 | --data_path /dataset/SER-30K/processed_seq/ \ 25 | --img_folder_path /dataset/SER-30K/Images/ \ 26 | --batch_size 16 \ 27 | --micro_batch_size 16 \ 28 | --num_epochs 50 \ 29 | --learning_rate $lr \ 30 | --lr_schedule_type linear \ 31 | --min_learning_rate 1e-6 \ 32 | --lora_r 8 \ 33 | --lora_alpha 16\ 34 | --lora_dropout $dropout \ 35 | --lora_target_modules '[q_proj,v_proj]' \ 36 | --seed $seed \ 37 | --resume_from_checkpoint \ 38 | --logging_steps 25 \ 39 | --eval_steps 25 \ 40 | --save_steps 25 \ 41 | --eval_num 1000 \ 42 | > "$LOG_FILE" 2>&1 43 | done 44 | done 45 | done 46 | 47 | # nohup sh run_sft.sh >log_sft.txt 2>&1 & 48 | -------------------------------------------------------------------------------- /Evaluation/run_eval_pigeon.sh: -------------------------------------------------------------------------------- 1 | LOG_DIR="./log/" 2 | 3 | for scenario in sticker 4 | do 5 | for mode in test 6 | do 7 | for ckpt in 375 8 | do 9 | for mask_ratio in 2.0 # mask_ratio=2.0 denotes the selected images 10 | do 11 | for scale_llama in 1.0 12 | do 13 | for scale_dm in 7.0 14 | do 15 | echo "scenario:$scenario, mode:$mode, ckpt:$ckpt, mask_ratio:$mask_ratio, scale_llama:$scale_llama, scale_dm:$scale_dm," 16 | LOG_FILE="$LOG_DIR/eval_DPO_${scenario}_${mode}_ckpt${ckpt}_histmask${mask_ratio}_scale_llama${scale_llama}_scale_dm${scale_dm}.txt" 17 | CUDA_VISIBLE_DEVICES=3 python evaluate_pigeon.py \ 18 | --output_dir /checkpoints/sticker/DPO/ \ 19 | --data_path /datasets/SER-30K/processed_seq/ \ 20 | --img_folder_path /datasets/SER-30K/Images/ \ 21 | --dino_model_path /path/to/DINO/models--facebook--dinov2-large \ 22 | --batch_size 50 \ 23 | --clip_batch_size 512 \ 24 | --dataset $dataset \ 25 | --mode $mode \ 26 | --ckpt $ckpt \ 27 | --with_mask \ 28 | --scale_for_llama $scale_llama \ 29 | --scale_for_dm $scale_dm \ 30 | --mask_ratio $mask_ratio \ 31 | --seed 123 \ 32 | > "$LOG_FILE" 2>&1 33 | done 34 | done 35 | done 36 | done 37 | done 38 | done 39 | 40 | # nohup sh run_eval_pigeon.sh >log_eval_pigeon.txt 2>&1 & 41 | -------------------------------------------------------------------------------- /Pigeon/run_dpo.sh: -------------------------------------------------------------------------------- 1 | 
LOG_DIR="./log" 2 | 3 | for seed in 123 4 | do 5 | for lr in 5e-6 6 | do 7 | for dropout in 0.05 8 | do 9 | echo "lr: $lr, dropout: $dropout , seed: $seed," 10 | LOG_FILE="$LOG_DIR/finetune_dpo_seed${seed}_lr${lr}_dropout${dropout}.txt" 11 | CUDA_VISIBLE_DEVICES=4 python finetune_dpo.py \ 12 | --model_path /path/to/LaVIT-7B-v2/ \ 13 | --model_dtype bf16 \ 14 | --output_dir /checkpoints/sticker/DPO/ \ 15 | --pre_ckpt /checkpoints/sticker/SFT/checkpoint-2975/ \ 16 | --use_xformers \ 17 | --load_in_8bit \ 18 | --pixel_decoding highres \ 19 | --mask_type mutual \ 20 | --num_heads 4 \ 21 | --num_layers 1 \ 22 | --drop_prob 0.2 \ 23 | --hist_mask_ratio 0.2 \ 24 | --scenario sticker \ 25 | --data_path /dataset/SER-30K/processed_seq/ \ 26 | --img_folder_path /dataset/SER-30K/Images/ \ 27 | --mode dpo \ 28 | --batch_size 8 \ 29 | --micro_batch_size 8 \ 30 | --num_epochs 50 \ 31 | --learning_rate $lr \ 32 | --lr_schedule_type linear \ 33 | --min_learning_rate 1e-6 \ 34 | --lora_r 8 \ 35 | --lora_alpha 16\ 36 | --lora_dropout $dropout \ 37 | --lora_target_modules '[q_proj,v_proj]' \ 38 | --seed $seed \ 39 | --resume_from_checkpoint \ 40 | --logging_steps 25 \ 41 | --eval_steps 25 \ 42 | --save_steps 25 \ 43 | --eval_num 200 \ 44 | > "$LOG_FILE" 2>&1 45 | done 46 | done 47 | done 48 | 49 | # nohup sh run_dpo.sh >log_dpo.txt 2>&1 & 50 | -------------------------------------------------------------------------------- /Pigeon/models/trainer/trainer_callback.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from transformers.utils import logging 3 | from transformers.trainer_callback import TrainerCallback 4 | from transformers.trainer_utils import has_length 5 | 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | class CustomProgressCallback(TrainerCallback): 10 | """ 11 | A [`TrainerCallback`] that displays the progress of training or evaluation. 12 | """ 13 | 14 | def __init__(self): 15 | self.training_bar = None 16 | self.prediction_bar = None 17 | 18 | def on_train_begin(self, args, state, control, **kwargs): 19 | if state.is_local_process_zero: 20 | self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True) 21 | self.current_step = 0 22 | 23 | def on_step_end(self, args, state, control, **kwargs): 24 | if state.is_local_process_zero: 25 | self.training_bar.update(state.global_step - self.current_step) 26 | self.current_step = state.global_step 27 | 28 | def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): 29 | if state.is_local_process_zero and has_length(eval_dataloader): 30 | if self.prediction_bar is None: 31 | self.prediction_bar = tqdm( 32 | total=len(eval_dataloader), leave=True, dynamic_ncols=True 33 | ) # Only set `leave=True` for better view of progress bar in log. 
34 | self.prediction_bar.update(1) 35 | 36 | def on_evaluate(self, args, state, control, **kwargs): 37 | if state.is_local_process_zero: 38 | if self.prediction_bar is not None: 39 | self.prediction_bar.close() 40 | self.prediction_bar = None 41 | 42 | def on_predict(self, args, state, control, **kwargs): 43 | if state.is_local_process_zero: 44 | if self.prediction_bar is not None: 45 | self.prediction_bar.close() 46 | self.prediction_bar = None 47 | 48 | def on_log(self, args, state, control, logs=None, **kwargs): 49 | if state.is_local_process_zero and self.training_bar is not None: 50 | _ = logs.pop("total_flos", None) 51 | self.training_bar.write(str(logs)) 52 | 53 | def on_train_end(self, args, state, control, **kwargs): 54 | if state.is_local_process_zero: 55 | self.training_bar.close() 56 | self.training_bar = None 57 | -------------------------------------------------------------------------------- /Pigeon/run_inf.sh: -------------------------------------------------------------------------------- 1 | LOG_DIR="./log" 2 | 3 | for mode in test 4 | do 5 | for checkpoint in 375 6 | do 7 | for mask_ratio in 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 8 | do 9 | for llama_scale in 1.0 10 | do 11 | for dm_scale in 7.0 12 | do 13 | echo "Running Inference on checkpoint-$checkpoint, mask_ratio: $mask_ratio, llama_scale: $llama_scale , dm_scale: $dm_scale," 14 | LOG_FILE="$LOG_DIR/inference_${mode}_llama_ckpt${checkpoint}_mask_ratio${mask_ratio}_llama_scale${llama_scale}_dm_scale${dm_scale}.txt" 15 | CUDA_VISIBLE_DEVICES=3 python inference.py \ 16 | --model_path /path/to/LaVIT-7B-v2/ \ 17 | --model_dtype bf16 \ 18 | --output_dir /checkpoints/sticker/DPO/ \ 19 | --pre_ckpt /checkpoints/sticker/SFT/checkpoint-2975/ \ 20 | --mode $mode \ 21 | --use_xformers \ 22 | --load_in_8bit \ 23 | --pixel_decoding highres \ 24 | --mask_type mutual \ 25 | --mask_ratio $mask_ratio \ 26 | --hist_mask_ratio 0.2 \ 27 | --with_mask \ 28 | --scenario sticker \ 29 | --data_path /dataset/SER-30K/processed_seq/ \ 30 | --img_folder_path /dataset/SER-30K/Images/ \ 31 | --batch_size 3 \ 32 | --dm_batch_size 4 \ 33 | --seed 123 \ 34 | --resume_from_checkpoint $checkpoint \ 35 | --use_nucleus_sampling \ 36 | --top_p 1.0 \ 37 | --top_k 50 \ 38 | --temperature 1 \ 39 | --num_beams 4 \ 40 | --min_length 20 \ 41 | --length_penalty 1 \ 42 | --num_return_sequences 1 \ 43 | --guidance_scale_for_llm $llama_scale \ 44 | --ratio 1:1 \ 45 | --guidance_scale_for_dm $dm_scale \ 46 | --num_inference_steps 25 \ 47 | --num_return_images 1 \ 48 | > "$LOG_FILE" 2>&1 49 | done 50 | done 51 | done 52 | done 53 | done 54 | 55 | # nohup sh run_inf.sh >log_inf.txt 2>&1 & 56 | -------------------------------------------------------------------------------- /Pigeon/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .lavit_for_generation import LaVITforGeneration 2 | from .lavit_for_understanding import LaVITforUnderstanding 3 | from .transform import LaVITImageProcessor, LaVITQuestionProcessor 4 | from models.lavit_utils import convert_weights_to_bf16, convert_weights_to_fp16 5 | from huggingface_hub import snapshot_download 6 | 7 | 8 | # Building the Model 9 | def build_model( 10 | model_path='./', 11 | model_dtype='bf16', 12 | device_id=None, 13 | image_size=224, 14 | use_xformers=False, 15 | understanding=True, 16 | load_tokenizer=True, 17 | pixel_decoding='highres', 18 | check_safety=False, 19 | local_files_only=False, 20 | ): 21 | """ 22 | model_path (str): The local directory for 
the saving the model weight 23 | model_dtype (str): The precision dtype of the model in inference, bf16 or fp16 24 | device_id (int): Specifying the GPU ID to loading the model 25 | use_xformers (bool): default=False, If set True, use xformers to save the GPU memory in the eva clip 26 | understanding (bool): If set True, use LaVIT for multi-modal understanding, else used for generation 27 | load_tokenizer (bool): Whether to load the tokenizer encoder during the image generation. For text-to-image generation, 28 | The visual tokenizer is not needed, set it to `False` for saving the GPU memory. When using for the 29 | multi-modal synthesis (the input image needs to be tokenizd to dircrete ids), the load_tokenizer must be set to True. 30 | pixel_decoding (str): [highres | lowres]: default is `highres`: using the high resolution decoding 31 | for generating high-quality images, if set to `lowres`, using the origin decoder to generate 512 x 512 image 32 | check_safety (bool): Should be set to True to enable the image generation safety check 33 | local_files_only (bool): If you have already downloaded the LaVIT checkpoint to the model_path, 34 | set the local_files_only=True to avoid loading from remote 35 | """ 36 | # Downloading the model checkpoint from the huggingface remote 37 | print("Downloading the LaVIT checkpoint from huggingface") 38 | 39 | if not local_files_only: 40 | snapshot_download("rain1011/LaVIT-7B-v2", local_dir=model_path, 41 | local_files_only=local_files_only, local_dir_use_symlinks=False) 42 | 43 | if understanding: 44 | lavit = LaVITforUnderstanding(model_path=model_path, model_dtype=model_dtype, 45 | device_id=device_id, use_xformers=use_xformers) 46 | else: 47 | lavit = LaVITforGeneration(model_path=model_path, model_dtype=model_dtype, device_id=device_id, 48 | use_xformers=use_xformers, check_safety=check_safety, load_tokenizer=load_tokenizer, pixel_decoding=pixel_decoding) 49 | 50 | # Convert the model parameters to the defined precision 51 | if model_dtype == 'bf16': 52 | convert_weights_to_bf16(lavit) 53 | if model_dtype == 'fp16': 54 | convert_weights_to_fp16(lavit) 55 | 56 | lavit = lavit.eval() 57 | 58 | return lavit -------------------------------------------------------------------------------- /env/env_dpo.yml: -------------------------------------------------------------------------------- 1 | name: dpo 2 | channels: 3 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/main 4 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/cloud/conda-forge 5 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/main/ 6 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/free/ 7 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/cloud/conda-forge/ 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_kmp_llvm 11 | - libgcc-ng=12.2.0=h65d4601_19 12 | - libstdcxx-ng=12.2.0=h46fd767_19 13 | - ca-certificates=2024.7.2=h06a4308_0 14 | - ld_impl_linux-64=2.38=h1181459_1 15 | - libffi=3.3=he6710b0_2 16 | - llvm-openmp=14.0.6=h9e868ea_0 17 | - ncurses=6.4=h6a678d5_0 18 | - openssl=1.1.1w=h7f8727e_0 19 | - pip=24.0=py38h06a4308_0 20 | - python=3.8.13=haa1d7c7_1 21 | - readline=8.2=h5eee18b_0 22 | - setuptools=69.5.1=py38h06a4308_0 23 | - sqlite=3.45.3=h5eee18b_0 24 | - tk=8.6.14=h39e8969_0 25 | - wheel=0.43.0=py38h06a4308_0 26 | - xz=5.4.6=h5eee18b_1 27 | - zlib=1.2.13=h5eee18b_1 28 | - pip: 29 | - accelerate==0.33.0 30 | - aiohappyeyeballs==2.4.0 31 | - aiohttp==3.10.5 32 | - aiosignal==1.3.1 33 | - asttokens==2.4.1 34 | - async-timeout==4.0.3 35 | - 
attrs==24.2.0 36 | - backcall==0.2.0 37 | - bitsandbytes==0.42.0 38 | - certifi==2024.7.4 39 | - charset-normalizer==3.3.2 40 | - comm==0.2.2 41 | - datasets==3.0.0 42 | - debugpy==1.8.5 43 | - decorator==5.1.1 44 | - diffusers==0.27.2 45 | - dill==0.3.8 46 | - docstring_parser==0.16 47 | - eval_type_backport==0.2.0 48 | - executing==2.0.1 49 | - filelock==3.15.4 50 | - fire==0.6.0 51 | - frozenlist==1.4.1 52 | - fsspec==2024.6.1 53 | - huggingface-hub==0.24.5 54 | - idna==3.7 55 | - importlib_metadata==8.2.0 56 | - ipykernel==6.29.5 57 | - ipython==8.12.3 58 | - jedi==0.19.1 59 | - jinja2==3.1.3 60 | - jupyter_client==8.6.2 61 | - jupyter_core==5.7.2 62 | - markdown-it-py==3.0.0 63 | - markupsafe==2.1.5 64 | - matplotlib-inline==0.1.7 65 | - mdurl==0.1.2 66 | - mpmath==1.3.0 67 | - multidict==6.1.0 68 | - multiprocess==0.70.16 69 | - nest-asyncio==1.6.0 70 | - networkx==3.0 71 | - numpy==1.24.4 72 | - nvidia-cublas-cu11==11.11.3.6 73 | - nvidia-cuda-cupti-cu11==11.8.87 74 | - nvidia-cuda-nvrtc-cu11==11.8.89 75 | - nvidia-cuda-runtime-cu11==11.8.89 76 | - nvidia-cudnn-cu11==8.7.0.84 77 | - nvidia-cufft-cu11==10.9.0.58 78 | - nvidia-curand-cu11==10.3.0.86 79 | - nvidia-cusolver-cu11==11.4.1.48 80 | - nvidia-cusparse-cu11==11.7.5.86 81 | - nvidia-nccl-cu11==2.20.5 82 | - nvidia-nvtx-cu11==11.8.86 83 | - packaging==24.1 84 | - pandas==2.0.3 85 | - parso==0.8.4 86 | - peft==0.7.1 87 | - pexpect==4.9.0 88 | - pickleshare==0.7.5 89 | - pillow==10.2.0 90 | - platformdirs==4.2.2 91 | - prompt_toolkit==3.0.47 92 | - psutil==6.0.0 93 | - ptyprocess==0.7.0 94 | - pure_eval==0.2.3 95 | - pyarrow==17.0.0 96 | - pygments==2.18.0 97 | - python-dateutil==2.9.0.post0 98 | - pytz==2024.2 99 | - pyyaml==6.0.1 100 | - pyzmq==26.1.0 101 | - regex==2024.7.24 102 | - requests==2.32.3 103 | - rich==13.8.1 104 | - safetensors==0.4.3 105 | - scipy==1.10.1 106 | - sentencepiece==0.2.0 107 | - shtab==1.7.1 108 | - six==1.16.0 109 | - stack-data==0.6.3 110 | - sympy==1.12 111 | - termcolor==2.4.0 112 | - timm==1.0.9 113 | - tokenizers==0.19.1 114 | - torch==2.3.0+cu118 115 | - torchvision==0.18.0+cu118 116 | - tornado==6.4.1 117 | - tqdm==4.66.5 118 | - traitlets==5.14.3 119 | - transformers==4.43.3 120 | - triton==2.3.0 121 | - trl==0.10.1 122 | - typing_extensions==4.12.2 123 | - tyro==0.8.10 124 | - tzdata==2024.1 125 | - urllib3==2.2.2 126 | - wcwidth==0.2.13 127 | - xformers==0.0.26.post1 128 | - xxhash==3.5.0 129 | - yarl==1.11.1 130 | - zipp==3.19.2 131 | prefix: /home/xuyiy/anaconda3/envs/dpo 132 | -------------------------------------------------------------------------------- /Pigeon/models/openai_clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in 
text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /Pigeon/models/mutual_mask_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class PositionalEncoding(nn.Module): 11 | 12 | def __init__(self, d_model, max_len=257): 13 | super(PositionalEncoding, self).__init__() 14 | 15 | pe = torch.zeros(max_len, d_model) 16 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | self.register_buffer('pe', pe) 21 | 22 | def forward(self, pos): 23 | return self.pe[pos] 24 | 25 | class MutualMaskGenerator(nn.Module): 26 | def __init__(self, emb_dim, num_heads, num_layers, drop_prob=0.2, eps=1e-5): 27 | super().__init__() 28 | 29 | self.emb_dim = emb_dim 30 | self.num_heads = num_heads 31 | self.num_layers = num_layers 32 | encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, batch_first=True) 33 | self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 34 | self.position_embedding = PositionalEncoding(emb_dim, max_len=257) 35 | self.type_embedding = nn.Embedding(3, emb_dim) # three types: history, target, pad 36 | self.dense = nn.Linear(emb_dim, emb_dim) 37 | self.activation = nn.Tanh() 38 | 39 | self.LayerNorm = nn.LayerNorm(emb_dim, eps=eps) 40 | self.dropout = nn.Dropout(drop_prob) 41 | 42 | def get_input_embeds( 43 | self, 44 | token_embeds: Optional[torch.FloatTensor] = None, 45 | position_ids: Optional[torch.LongTensor] = None, 46 | type_ids: Optional[torch.LongTensor] = None, 47 | ): 48 | position_embeds = self.position_embedding(position_ids) 49 | type_embeds = self.type_embedding(type_ids) 50 | embeds = token_embeds + position_embeds + type_embeds 51 | embeds = self.LayerNorm(embeds) 52 | embeds = self.dropout(embeds) 53 | 54 | return embeds 55 | 56 | def get_output_embeds( 57 | self, 58 | input_embeds: Optional[torch.FloatTensor] = None, 59 | attention_mask: Optional[torch.LongTensor] = None, 60 | ): 61 | output_embeds = self.encoder(input_embeds, src_key_padding_mask=attention_mask) 62 | output_embeds = self.LayerNorm(output_embeds) 63 | output_embeds = self.dense(output_embeds) 64 | output_embeds = self.activation(output_embeds) 65 | output_embeds = self.dropout(output_embeds) 66 | 67 | return output_embeds 68 | 69 | def forward(self, token_embeds, attention_mask, position_ids, type_ids, mask_ratio, hist_mask_ratio, inference): 70 | input_embeds = self.get_input_embeds(token_embeds, position_ids, type_ids) 71 | output_embeds = self.get_output_embeds(input_embeds, attention_mask) 72 | 73 | # Compute the cosine similarity between token_embeds and output_embeds 74 | cos_sim = F.cosine_similarity(token_embeds, output_embeds, dim=-1) 75 | 76 | # Keep history tokens with higher cosine similarity as target-relevant user preference 77 | hist_mask = (type_ids == 1) 78 | hist_indices = Grouped_indices(hist_mask) 79 | hist_keep_prob = Masked_softmax(cos_sim, hist_mask, dim=-1) 80 | hist_hard_keep_decision = Gumbel_softmax_topk(hist_keep_prob, hist_indices, hist_mask_ratio) 81 | 82 | 83 | target_mask = 
(type_ids == 2) 84 | target_indices = Grouped_indices(target_mask) 85 | if inference: # Mask target tokens with lower cosine similarity for inference 86 | target_keep_prob = Masked_softmax(cos_sim, target_mask, dim=-1) 87 | else: # Mask target tokens with higher cosine similarity for training 88 | target_keep_prob = 1 - Masked_softmax(cos_sim, target_mask, dim=-1) 89 | target_hard_keep_decision = Gumbel_softmax_topk(target_keep_prob, target_indices, mask_ratio) 90 | 91 | return hist_hard_keep_decision + target_hard_keep_decision 92 | 93 | def Gumbel_softmax_topk(keep_prob, valid_pos, mask_ratio, tau=1.0): 94 | bsz, _ = keep_prob.shape 95 | mask = torch.zeros_like(keep_prob) 96 | 97 | for i in range(bsz): 98 | prob = keep_prob[i] 99 | pos = valid_pos[i] 100 | n = len(pos) 101 | if n == 0: 102 | continue 103 | keep_num = max(1, int(n * (1 - mask_ratio))) 104 | 105 | logits = prob[pos].log() 106 | gumbels = -torch.empty_like(logits).exponential_().log() # ~Gumbel(0,1) 107 | gumbels = (logits + gumbels) / tau 108 | y_soft = F.softmax(gumbels, dim=-1) 109 | 110 | # Return topk hard keep decision 111 | topk_indices = y_soft.topk(keep_num, dim=-1).indices 112 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) 113 | y_hard[topk_indices] = 1.0 114 | ret = y_hard - y_soft.detach() + y_soft 115 | 116 | mask[i, pos] = ret 117 | 118 | return mask 119 | 120 | def Masked_softmax(score, mask, dim=-1): 121 | score = torch.exp(score) 122 | mask = mask.to(dtype=score.dtype) 123 | masked_score = score * mask 124 | 125 | return masked_score / masked_score.sum(dim, keepdim=True) 126 | 127 | def Grouped_indices(mask): 128 | indices = mask.nonzero(as_tuple=False) 129 | grouped_indices = [] 130 | for i in range(mask.size(0)): 131 | row_indices = indices[indices[:, 0] == i][:, 1] 132 | grouped_indices.append(row_indices) 133 | 134 | return grouped_indices 135 | 136 | class MLP(nn.Module): 137 | def __init__(self, in_dim, hid_dim, out_dim, drop_prob=0.1): 138 | super().__init__() 139 | 140 | self.encoder = nn.Linear(in_dim, hid_dim) 141 | self.drop = nn.Dropout(drop_prob) 142 | self.decoder = nn.Linear(hid_dim, out_dim) 143 | self.activation = nn.Tanh() 144 | 145 | def forward(self, x): 146 | x = self.encoder(x) 147 | x = self.drop(x) 148 | x = self.decoder(x) 149 | x = self.activation(x) 150 | x = self.drop(x) 151 | 152 | return x 153 | -------------------------------------------------------------------------------- /Pigeon/models/trainer/optimization.py: -------------------------------------------------------------------------------- 1 | ### Copy from https://github.com/AILab-CVC/SEED/MultiModalLLM/src/train/optimization.py 2 | import math 3 | import warnings 4 | from functools import partial 5 | from typing import Callable, Iterable, Optional, Tuple, Union 6 | 7 | import torch 8 | from torch import nn 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau 11 | from transformers.trainer_utils import SchedulerType 12 | from transformers.utils import logging 13 | 14 | from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, get_constant_schedule, get_constant_schedule_with_warmup, get_inverse_sqrt_schedule, get_reduce_on_plateau_schedule 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | def _get_cosine_schedule_with_warmup_lr_lambda(current_step: int, 20 | *, 21 | num_warmup_steps: int, 22 | num_training_steps: 
int, 23 | num_cycles: float, 24 | min_lr_ratio: float = 0.0): 25 | if current_step < num_warmup_steps: 26 | return float(current_step) / float(max(1, num_warmup_steps)) 27 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 28 | 29 | return max(0.0, 30 | 0.5 * ((1.0 + min_lr_ratio) + (1.0 - min_lr_ratio) * math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 31 | 32 | 33 | def get_cosine_schedule_with_warmup(optimizer: Optimizer, 34 | num_warmup_steps: int, 35 | num_training_steps: int, 36 | num_cycles: float = 0.5, 37 | last_epoch: int = -1, 38 | min_lr_ratio: float = 0.0): 39 | """ 40 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 41 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 42 | initial lr set in the optimizer. 43 | 44 | Args: 45 | optimizer ([`~torch.optim.Optimizer`]): 46 | The optimizer for which to schedule the learning rate. 47 | num_warmup_steps (`int`): 48 | The number of steps for the warmup phase. 49 | num_training_steps (`int`): 50 | The total number of training steps. 51 | num_cycles (`float`, *optional*, defaults to 0.5): 52 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 53 | following a half-cosine). 54 | last_epoch (`int`, *optional*, defaults to -1): 55 | The index of the last epoch when resuming training. 56 | 57 | Return: 58 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 59 | """ 60 | 61 | lr_lambda = partial( 62 | _get_cosine_schedule_with_warmup_lr_lambda, 63 | num_warmup_steps=num_warmup_steps, 64 | num_training_steps=num_training_steps, 65 | num_cycles=num_cycles, 66 | min_lr_ratio=min_lr_ratio, 67 | ) 68 | return LambdaLR(optimizer, lr_lambda, last_epoch) 69 | 70 | 71 | TYPE_TO_SCHEDULER_FUNCTION = { 72 | SchedulerType.LINEAR: get_linear_schedule_with_warmup, 73 | SchedulerType.COSINE: get_cosine_schedule_with_warmup, 74 | SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, 75 | SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, 76 | SchedulerType.CONSTANT: get_constant_schedule, 77 | SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, 78 | SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule, 79 | SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule, 80 | } 81 | 82 | 83 | def get_scheduler( 84 | name: Union[str, SchedulerType], 85 | optimizer: Optimizer, 86 | num_warmup_steps: Optional[int] = None, 87 | num_training_steps: Optional[int] = None, 88 | min_lr_ratio: Optional[float] = 0.0, 89 | ): 90 | """ 91 | Unified API to get any scheduler from its name. 92 | 93 | Args: 94 | name (`str` or `SchedulerType`): 95 | The name of the scheduler to use. 96 | optimizer (`torch.optim.Optimizer`): 97 | The optimizer that will be used during training. 98 | num_warmup_steps (`int`, *optional*): 99 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 100 | optional), the function will raise an error if it's unset and the scheduler type requires it. 101 | num_training_steps (`int``, *optional*): 102 | The number of training steps to do. This is not required by all schedulers (hence the argument being 103 | optional), the function will raise an error if it's unset and the scheduler type requires it. 
104 | """ 105 | name = SchedulerType(name) 106 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 107 | if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU: 108 | return schedule_func(optimizer) 109 | 110 | # All other schedulers require `num_warmup_steps` 111 | if num_warmup_steps is None: 112 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 113 | 114 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 115 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) 116 | 117 | if name == SchedulerType.INVERSE_SQRT: 118 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) 119 | 120 | # All other schedulers require `num_training_steps` 121 | if num_training_steps is None: 122 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 123 | 124 | if name == SchedulerType.COSINE: 125 | logger.info(f'Initialize lr scheduler with min_lr_ratio: {min_lr_ratio}') 126 | return schedule_func(optimizer, 127 | num_warmup_steps=num_warmup_steps, 128 | num_training_steps=num_training_steps, 129 | min_lr_ratio=min_lr_ratio) 130 | else: 131 | return schedule_func(optimizer, 132 | num_warmup_steps=num_warmup_steps, 133 | num_training_steps=num_training_steps) 134 | -------------------------------------------------------------------------------- /env/env_sft.yml: -------------------------------------------------------------------------------- 1 | name: pigeon 2 | channels: 3 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/cloud/conda-forge 4 | - 5 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/main 6 | - defaults 7 | dependencies: 8 | - cudatoolkit=11.8.0=h6a678d5_0 9 | - cudnn=8.9.2.26=cuda11_0 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - libgcc-ng=12.2.0=h65d4601_19 13 | - libstdcxx-ng=12.2.0=h46fd767_19 14 | - libuuid=2.32.1=h7f98852_1000 15 | - libzlib=1.2.13=h166bdaf_4 16 | - python=3.8.13=ha86cf86_0_cpython 17 | - zlib=1.2.13=h166bdaf_4 18 | - bzip2=1.0.8=h5eee18b_6 19 | - ca-certificates=2024.3.11=h06a4308_0 20 | - ld_impl_linux-64=2.38=h1181459_1 21 | - libffi=3.4.4=h6a678d5_1 22 | - libnsl=2.0.0=h5eee18b_0 23 | - llvm-openmp=14.0.6=h9e868ea_0 24 | - ncurses=6.4=h6a678d5_0 25 | - ninja=1.10.2=h06a4308_5 26 | - ninja-base=1.10.2=hd09550d_5 27 | - openssl=3.0.13=h7f8727e_2 28 | - pip=23.3.1=py38h06a4308_0 29 | - readline=8.2=h5eee18b_0 30 | - setuptools=68.2.2=py38h06a4308_0 31 | - sqlite=3.41.2=h5eee18b_0 32 | - tk=8.6.12=h1ccaba5_0 33 | - wheel=0.41.2=py38h06a4308_0 34 | - xz=5.4.6=h5eee18b_0 35 | - pip: 36 | - accelerate==0.28.0 37 | - aiofiles==23.2.1 38 | - aiohttp==3.9.3 39 | - aiosignal==1.3.1 40 | - altair==5.2.0 41 | - annotated-types==0.6.0 42 | - antlr4-python3-runtime==4.9.3 43 | - anyio==4.3.0 44 | - anykeystore==0.2 45 | - apex==0.1 46 | - appdirs==1.4.4 47 | - argon2-cffi==23.1.0 48 | - argon2-cffi-bindings==21.2.0 49 | - arrow==1.3.0 50 | - asttokens==2.4.1 51 | - async-lru==2.0.4 52 | - async-timeout==4.0.3 53 | - attrs==23.2.0 54 | - babel==2.14.0 55 | - backcall==0.2.0 56 | - beautifulsoup4==4.12.3 57 | - bitsandbytes==0.42.0 58 | - black==24.3.0 59 | - bleach==6.1.0 60 | - blinker==1.7.0 61 | - certifi==2024.2.2 62 | - cffi==1.16.0 63 | - charset-normalizer==3.3.2 64 | - click==8.1.7 65 | - clip-interrogator==0.6.0 66 | - cmake==3.28.4 67 | - colorama==0.4.6 68 | - comm==0.2.2 69 | - contourpy==1.1.1 70 | - cryptacular==1.6.2 71 | - cxxfilt==0.3.0 72 | - cycler==0.12.1 73 | - datasets==2.18.0 74 | - debugpy==1.8.1 75 | - 
decorator==5.1.1 76 | - defusedxml==0.7.1 77 | - diffusers==0.27.2 78 | - dill==0.3.8 79 | - einops==0.7.0 80 | - exceptiongroup==1.2.0 81 | - executing==2.0.1 82 | - fastapi==0.110.0 83 | - fastjsonschema==2.19.1 84 | - ffmpy==0.3.2 85 | - filelock==3.13.1 86 | - fire==0.6.0 87 | - flask==3.0.3 88 | - fonttools==4.50.0 89 | - fqdn==1.5.1 90 | - frozenlist==1.4.1 91 | - fsspec==2024.2.0 92 | - ftfy==6.2.0 93 | - gradio==4.22.0 94 | - gradio_client==0.13.0 95 | - greenlet==3.0.3 96 | - h11==0.14.0 97 | - httpcore==1.0.4 98 | - httpx==0.27.0 99 | - huggingface-hub==0.21.4 100 | - hupper==1.12.1 101 | - hydra-core==1.3.2 102 | - idna==3.6 103 | - importlib_metadata==7.1.0 104 | - importlib_resources==6.4.0 105 | - iniconfig==2.0.0 106 | - ipykernel==6.29.3 107 | - ipython==8.12.3 108 | - ipywidgets==8.1.2 109 | - isoduration==20.11.0 110 | - itsdangerous==2.1.2 111 | - jedi==0.19.1 112 | - jinja2==3.1.3 113 | - json5==0.9.24 114 | - jsonpointer==2.4 115 | - jsonschema==4.21.1 116 | - jsonschema-specifications==2023.12.1 117 | - jupyter==1.0.0 118 | - jupyter-console==6.6.3 119 | - jupyter-events==0.10.0 120 | - jupyter-lsp==2.2.4 121 | - jupyter_client==8.6.1 122 | - jupyter_core==5.7.2 123 | - jupyter_server==2.13.0 124 | - jupyter_server_terminals==0.5.3 125 | - jupyterlab==4.1.5 126 | - jupyterlab_pygments==0.3.0 127 | - jupyterlab_server==2.25.4 128 | - jupyterlab_widgets==3.0.10 129 | - kiwisolver==1.4.5 130 | - lightning-utilities==0.11.2 131 | - lit==18.1.2 132 | - loralib==0.1.2 133 | - lpips==0.1.4 134 | - markdown-it-py==3.0.0 135 | - markupsafe==2.1.5 136 | - matplotlib==3.7.5 137 | - matplotlib-inline==0.1.6 138 | - mdurl==0.1.2 139 | - mistune==3.0.2 140 | - mpmath==1.3.0 141 | - multidict==6.0.5 142 | - multiprocess==0.70.16 143 | - mypy-extensions==1.0.0 144 | - nbclient==0.10.0 145 | - nbconvert==7.16.3 146 | - nbformat==5.10.3 147 | - nest-asyncio==1.6.0 148 | - networkx==3.1 149 | - notebook==7.1.2 150 | - notebook_shim==0.2.4 151 | - numpy==1.24.4 152 | - nvidia-cublas-cu11==11.11.3.6 153 | - nvidia-cublas-cu12==12.1.3.1 154 | - nvidia-cuda-cupti-cu11==11.8.87 155 | - nvidia-cuda-cupti-cu12==12.1.105 156 | - nvidia-cuda-nvrtc-cu11==11.8.89 157 | - nvidia-cuda-nvrtc-cu12==12.1.105 158 | - nvidia-cuda-runtime-cu11==11.8.89 159 | - nvidia-cuda-runtime-cu12==12.1.105 160 | - nvidia-cudnn-cu11==8.7.0.84 161 | - nvidia-cudnn-cu12==9.1.0.70 162 | - nvidia-cufft-cu11==10.9.0.58 163 | - nvidia-cufft-cu12==11.0.2.54 164 | - nvidia-curand-cu11==10.3.0.86 165 | - nvidia-curand-cu12==10.3.2.106 166 | - nvidia-cusolver-cu11==11.4.1.48 167 | - nvidia-cusolver-cu12==11.4.5.107 168 | - nvidia-cusparse-cu11==11.7.5.86 169 | - nvidia-cusparse-cu12==12.1.0.106 170 | - nvidia-nccl-cu11==2.20.5 171 | - nvidia-nccl-cu12==2.20.5 172 | - nvidia-nvjitlink-cu12==12.1.105 173 | - nvidia-nvtx-cu11==11.8.86 174 | - nvidia-nvtx-cu12==12.1.105 175 | - oauthlib==3.2.2 176 | - omegaconf==2.3.0 177 | - open-clip-torch==2.24.0 178 | - opencv-python==4.10.0.84 179 | - orjson==3.9.15 180 | - overrides==7.7.0 181 | - packaging==24.0 182 | - pandas==2.0.3 183 | - pandocfilters==1.5.1 184 | - parso==0.8.3 185 | - pastedeploy==3.1.0 186 | - pathspec==0.12.1 187 | - pbkdf2==1.3 188 | - peft==0.7.1 189 | - pexpect==4.9.0 190 | - pickleshare==0.7.5 191 | - pillow==10.2.0 192 | - pkgutil_resolve_name==1.3.10 193 | - plaster==1.1.2 194 | - plaster-pastedeploy==1.0.1 195 | - platformdirs==4.2.0 196 | - pluggy==1.5.0 197 | - prometheus_client==0.20.0 198 | - prompt-toolkit==3.0.43 199 | - protobuf==5.26.0 200 | - 
psutil==5.9.8 201 | - ptyprocess==0.7.0 202 | - pure-eval==0.2.2 203 | - pyarrow==15.0.2 204 | - pyarrow-hotfix==0.6 205 | - pycparser==2.21 206 | - pydantic==2.6.4 207 | - pydantic_core==2.16.3 208 | - pydub==0.25.1 209 | - pygments==2.17.2 210 | - pyparsing==3.1.2 211 | - pyproject==1.3.1 212 | - pyramid==2.0.2 213 | - pyramid-mailer==0.15.1 214 | - pyrootutils==1.0.4 215 | - pytest==8.1.1 216 | - python-dateutil==2.9.0.post0 217 | - python-dotenv==1.0.1 218 | - python-json-logger==2.0.7 219 | - python-multipart==0.0.9 220 | - python3-openid==3.2.0 221 | - pytorch-extension==0.2 222 | - pytorch-fid==0.3.0 223 | - pytorch-msssim==1.0.0 224 | - pytz==2024.1 225 | - pyyaml==6.0.1 226 | - pyzmq==25.1.2 227 | - qtconsole==5.5.1 228 | - qtpy==2.4.1 229 | - referencing==0.34.0 230 | - regex==2023.12.25 231 | - repoze.sendmail==4.4.1 232 | - requests==2.31.0 233 | - requests-oauthlib==2.0.0 234 | - rfc3339-validator==0.1.4 235 | - rfc3986-validator==0.1.1 236 | - rich==13.7.1 237 | - rpds-py==0.18.0 238 | - ruff==0.3.4 239 | - safetensors==0.4.2 240 | - scipy==1.10.1 241 | - semantic-version==2.10.0 242 | - send2trash==1.8.2 243 | - sentencepiece==0.2.0 244 | - shellingham==1.5.4 245 | - six==1.16.0 246 | - sniffio==1.3.1 247 | - soupsieve==2.5 248 | - sqlalchemy==2.0.29 249 | - stack-data==0.6.3 250 | - starlette==0.36.3 251 | - sympy==1.12 252 | - termcolor==2.4.0 253 | - terminado==0.18.1 254 | - timm==0.4.12 255 | - tinycss2==1.2.1 256 | - tokenize-rt==5.2.0 257 | - tokenizers==0.13.3 258 | - tomli==2.0.1 259 | - tomlkit==0.12.0 260 | - toolz==0.12.1 261 | - torch==2.3.0+cu118 262 | - torch-fidelity==0.3.0 263 | - torchmetrics==1.3.2 264 | - torchvision==0.18.0+cu118 265 | - tornado==6.4 266 | - tqdm==4.66.4 267 | - traitlets==5.14.2 268 | - transaction==4.0 269 | - transformers==4.33.2 270 | - translationstring==1.4 271 | - triton==2.3.0 272 | - typer==0.9.0 273 | - types-python-dateutil==2.9.0.20240316 274 | - typing_extensions==4.10.0 275 | - tzdata==2024.1 276 | - uri-template==1.3.0 277 | - urllib3==2.2.1 278 | - uvicorn==0.29.0 279 | - velruse==1.1.1 280 | - venusian==3.1.0 281 | - wcwidth==0.2.13 282 | - webcolors==1.13 283 | - webencodings==0.5.1 284 | - webob==1.8.7 285 | - websocket-client==1.7.0 286 | - websockets==11.0.3 287 | - werkzeug==3.0.2 288 | - widgetsnbextension==4.0.10 289 | - wtforms==3.1.2 290 | - wtforms-recaptcha==0.3.2 291 | - xformers==0.0.27.dev792 292 | - xxhash==3.4.1 293 | - yarl==1.9.4 294 | - zipp==3.18.1 295 | - zope.deprecation==5.0 296 | - zope.interface==6.2 297 | - zope.sqlalchemy==3.1 298 | prefix: /home/xuyiy/anaconda3/envs/pigeon 299 | 300 | -------------------------------------------------------------------------------- /Pigeon/data_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import re 5 | from fileinput import filename 6 | 7 | import numpy as np 8 | import scipy.sparse as sp 9 | from dataclasses import dataclass 10 | from typing import Optional, Union, Mapping, List, Dict, Any 11 | import torch 12 | import torch.utils.data as data 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms 15 | from PIL import Image 16 | from transformers import LlamaTokenizer 17 | from transformers.utils import PaddingStrategy 18 | 19 | class sticker_template: 20 | instruction = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." 
21 |     task_description = "Instruction: You are a helpful personalized assistant. You will be provided with a list of stickers that the user likes to analyze the user's preference over stickers. Based on this analysis, please design a personalized sticker that aligns with target sticker info and the user's sticker tastes."
22 |     prompt_start = "Input: "
23 |     history_start = "The user likes the following stickers: "
24 |     target_start = "The target sticker info: a sticker with the emotion of <emo>. "
25 |     prompt_end = "Response: "
26 | 
27 | class movie_template:
28 |     instruction = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
29 |     task_description = "Instruction: You are a helpful personalized assistant. You will be provided with a list of movies that the user likes, along with the movie posters, to analyze the user's preference over posters. Based on this analysis, please design a personalized poster that aligns with target movie and the user's poster tastes."
30 |     prompt_start = "Input: "
31 |     history_start = "The user likes the following movies: "
32 |     target_start = "The target movie is titled '<title>'. "
33 |     prompt_end = "Response: "
34 | 
35 | SEP = "\n"
36 | IMG_TOKEN = "<img>"
37 | EMB_TOKEN = "<emb>"
38 | 
39 | class ImagePathDataset(Dataset):
40 |     def __init__(self, folder_path, paths, trans=None):
41 |         self.folder_path = folder_path
42 |         self.paths = paths
43 |         self.trans = trans
44 | 
45 |     def __len__(self):
46 |         return len(self.paths)
47 | 
48 |     def __getitem__(self, idx):
49 |         path = os.path.join(self.folder_path, self.paths[idx])
50 |         img = Image.open(path).convert('RGB')
51 |         if self.trans is not None:
52 |             img = self.trans(img)
53 | 
54 |         return img.to(memory_format=torch.contiguous_format)
55 | 
56 | class DiffusionImageDataset(Dataset):
57 |     def __init__(self, start, end):
58 |         self.data = list(range(start, end))
59 | 
60 |     def __len__(self):
61 |         return len(self.data)
62 | 
63 |     def __getitem__(self, idx):
64 |         return self.data[idx]
65 | 
66 | class PigeonDataset(Dataset):
67 |     def __init__(self, data):
68 |         self.data = data
69 | 
70 |     def __len__(self):
71 |         return len(self.data["uids"])
72 | 
73 |     def __getitem__(self, idx):
74 |         text_input_ids = self.data["text_input_ids"][idx]
75 |         text_attn_mask = self.data["text_attn_mask"][idx]
76 |         image = self.data["image"][idx]
77 |         uids = self.data["uids"][idx]
78 |         genres = self.data["genres"][idx]
79 | 
80 |         return {
81 |             "text_input_ids": text_input_ids,
82 |             "text_attn_mask": text_attn_mask,
83 |             "image": image,
84 |             "uids": uids,
85 |             "genres": genres
86 |         }
87 | 
88 |     def shuffle(self, seed=None):
89 |         if seed is not None:
90 |             torch.manual_seed(seed)
91 |         data_len = len(self.data["uids"])
92 |         indices = torch.randperm(data_len).tolist()
93 |         self.data = {key:[self.data[key][i] for i in indices] for key in self.data}
94 | 
95 |         return self
96 | 
97 |     def select(self, indices):
98 |         selected_data = {key:[self.data[key][i] for i in indices] for key in self.data}
99 |         return PigeonDataset(selected_data)
100 | 
101 | @dataclass
102 | class PigeonCollator:
103 |     pad_token_id: int = 2
104 |     padding_side: Optional[str] = "left"
105 |     return_tensors: str = "pt"
106 | 
107 |     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
108 |         # If we have a list of dicts, let's convert it into a dict of lists
109 |         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
110 |         if isinstance(features, (list, tuple)) and isinstance(features[0], Mapping):
111 |             features = {key: [example[key] for example in features] for key in features[0].keys()}
112 | 
113 |         max_length = max(len(input_ids) for input_ids in features["text_input_ids"])
114 |         bsz = len(features["text_input_ids"])
115 |         for i in range(bsz):
116 |             pad_num = max_length - len(features["text_input_ids"][i])
117 | 
118 |             if self.padding_side == "left":
119 |                 features["text_input_ids"][i] = [self.pad_token_id] * pad_num + features["text_input_ids"][i]
120 |                 features["text_attn_mask"][i] = [0] * pad_num + features["text_attn_mask"][i]
121 |             else:
122 |                 features["text_input_ids"][i] = features["text_input_ids"][i] + [self.pad_token_id] * pad_num
123 |                 features["text_attn_mask"][i] = features["text_attn_mask"][i] + [0] * pad_num
124 | 
125 |         if self.return_tensors == "pt":
126 |             features = {key: torch.tensor(features[key], dtype=torch.long) for key in features.keys()}
127 | 
128 |         return features
129 | 
130 | def process_data(
131 |     scenario: str="sticker",
132 |     data: dict=None,
133 |     item_info: dict=None,
134 |     semantics: list=None,
135 |     tokenizer: LlamaTokenizer=None,
136 | ):
137 |     text_input_ids = []
138 |     text_attn_mask = []
139 |     image = []
140 |     uids = []
141 |     genres = []
142 |     for uid in data:
143 |         for genre in data[uid]:
144 |             hist_seqs = data[uid][genre]
145 |             for hist_seq in hist_seqs:
146 |                 if scenario == "sticker":
147 |                     text_prompt = generate_prompt_for_stickers(sticker_template, hist_seq, item_info, semantics)
148 |                 elif scenario == "movie":
149 |                     text_prompt = generate_prompt_for_movies(movie_template, hist_seq, item_info, semantics)
150 |                 prompt_tokens = tokenizer(
151 |                     text_prompt, padding="longest", return_tensors="pt", add_special_tokens=False
152 |                 )
153 |                 text_input_ids.append(prompt_tokens.input_ids.squeeze(0).tolist())
154 |                 text_attn_mask.append(prompt_tokens.attention_mask.squeeze(0).tolist())
155 |                 # the last image is the target image, serving as both the input and the output
156 |                 image.append(hist_seq)
157 |                 uids.append(uid)
158 |                 genres.append(genre)
159 | 
160 |     return {"text_input_ids": text_input_ids, "text_attn_mask": text_attn_mask,
161 |             "image": image, "uids": uids, "genres": genres}
162 | 
163 | def generate_prompt_for_movies(
164 |     template,
165 |     hist_seq: dict=None,
166 |     movies_info: dict=None,
167 |     semantics: list=None,
168 | ):
169 |     hist_prompt = template.history_start
170 |     for iid in hist_seq[:-1]:
171 |         title = movies_info[iid]["title"]
172 |         try:
173 |             title = re.findall(r'^(.*) \(\d+\) *$', title)[0]
174 |         except:
175 |             title = title
176 | 
177 |         hist_prompt += SEP
178 |         hist_prompt += title
179 |         hist_prompt += " "
180 |         hist_prompt += IMG_TOKEN
181 | 
182 |     tgt_iid = hist_seq[-1]
183 |     target_title = movies_info[tgt_iid]["title"]
184 |     try:
185 |         target_title = re.findall(r'^(.*) \(\d+\) *$', target_title)[0]
186 |     except:
187 |         target_title = target_title
188 | 
189 |     target_prompt = template.target_start.replace("<title>", target_title)
190 |     target_prompt = target_prompt + semantics[tgt_iid] + " " + EMB_TOKEN
191 | 
192 |     text_prompt = template.instruction + SEP + template.task_description + SEP + template.prompt_start + hist_prompt + SEP + target_prompt + SEP + template.prompt_end
193 |     text_prompt += IMG_TOKEN
194 | 
195 |     return text_prompt
196 | 
197 | def generate_prompt_for_stickers(
198 |     template,
199 |     hist_seq: dict=None,
200 |     anno_dict: dict=None,
201 |     semantics: list=None,
202 | ):
203 |     hist_prompt = template.history_start
204 |     for iid in hist_seq[:-1]:
205 |         hist_prompt += SEP
206
| hist_prompt += IMG_TOKEN 207 | 208 | tgt_iid = hist_seq[-1] 209 | emo = anno_dict[tgt_iid]["emo"] 210 | 211 | target_prompt = template.target_start.replace("<emo>", emo) 212 | target_prompt = target_prompt + semantics[tgt_iid] + " " + EMB_TOKEN 213 | 214 | text_prompt = template.instruction + SEP + template.task_description + SEP + template.prompt_start + hist_prompt + SEP + target_prompt + SEP + template.prompt_end 215 | text_prompt += IMG_TOKEN 216 | 217 | return text_prompt 218 | 219 | 220 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Personalized Image Generation with Large Multimodal Models 2 | 3 | This is the pytorch implementation of our paper at WWW 2025: 4 | > [Personalized Image Generation with Large Multimodal Models](https://arxiv.org/abs/2410.14170) 5 | > 6 | > Yiyan Xu, Wenjie Wang, Yang Zhang, Biao Tang, Yan Peng, Fuli Feng, Xiangnan He 7 | 8 | ## Pigeon Overview 9 | Pigeon consists of three key modules: 1) mask generation module creates token-level masks for history and reference images, 2) personalized module encodes multimodal instructions and integrates them with masked history to generate personalized tokens, and 3) image generation module utilizes these tokens to produce personalized images. 10 | ![Pigeon Overview](./figures/overview.png) 11 | 12 | ## Environment 13 | - Anaconda 3 14 | - python 3.8.13 15 | - cuda 12.2 16 | - peft 0.7.1 17 | - torch 2.3.0 18 | - torchvision 0.18.0 19 | - Pillow 10.2.0 20 | - numpy 1.24.4 21 | - transformers 4.33.2 22 | - open-clip-torch 2.24.0 23 | - accelerate 0.28.0 24 | - diffusers 0.27.2 25 | - xformers 0.0.27.dev792 26 | - pytorch-fid 0.3.0 27 | - lpips 0.1.4 28 | 29 | ## Usage 30 | ### Dataset 31 | Download the datasets into the "./dataset" folder from [here](https://drive.google.com/drive/folders/1MpvkQ_DCimfcXBZJpPmliNf8s4reUA9B?usp=drive_link), including SER-30K and MovieLens-Latest-small for sticker and movie poster scenarios, respectively. 32 | 33 | #### Sticker 34 | * **data_path:** `./dataset/SER-30K/processed_seq` 35 | 36 | * **`mapped_anno_dict.npy`:** dict of sticker info. 37 | ```python 38 | { 39 | 40 | iid1: { # image id 41 | 'emo': 'Neutral', # sticker emotion, like sadness 42 | 'text': '' # sticker text, most are empty 43 | }, 44 | iid2: ..., 45 | ... 46 | } 47 | ``` 48 | 49 | * **`mapped_user_dict.npy`:** dict of user-interacted seqs, where each user interacts with a single theme. 50 | ```python 51 | { 52 | uid1: { # user id 53 | gid1: [iid1, iid2, ...], # user historically interacted iids 54 | }, 55 | uid2: { 56 | gid2: ..., 57 | }, 58 | ... 59 | } 60 | ``` 61 | 62 | * **`train.npy`, `valid.npy` & `test.npy`:** We apply a sliding window of six interactions, moving one step at a time to create data samples for each user in both scenarios. Each sample treats the first five interactions as the user history images and the last as the target image. We split the samples into training, validation, and testing sets with a ratio of 8:1:1. 63 | ```python 64 | { 65 | uid1: { # user id 66 | gid1: [ # genre id 67 | [iid1, iid2, iid3, iid4, iid5, iid6], # interaction sequence, where the last one serves as the target image and the preceding ones are history images 68 | [...], 69 | ... 70 | ], 71 | }, 72 | uid2: {...}, 73 | ... 74 | } 75 | ``` 76 | 77 | * **`all_image_paths.npy`:** list of image paths arranged according to iids. 
78 | * **image_folder_path:** `./dataset/SER30K/Images` 79 | ```python 80 | import numpy as np 81 | from PIL import Image 82 | import os 83 | 84 | img_folder_path = "./dataset/SER-30K/Images" 85 | all_image_paths = np.load("./dataset/SER30K/processed_seq/all_image_paths.npy", allow_pickle=True) 86 | 87 | iid = 100 88 | img_path = os.path.join(img_folder_path, all_image_paths[iid]) 89 | im = Image.open(img_path) 90 | display(im) 91 | ``` 92 | 93 | * **`sticker_semantics.npy`:** list of textual descriptions of each image with the same format as `all_image_paths.npy`. 94 | 95 | * **`all_clip_feats.npy` & `all_dino_feats.npy`:** list of extracted CLIP and DINO features for evaluation with the same format as `all_image_paths.npy`. 96 | 97 | * **`/map`:** maps of image id and user id saved during data preprocessing. 98 | 99 | * **`/data_ready`:** data processed for Pigeon training and evaluation. 100 | * **`train_ready.npy`, `valid_ready.npy` & `test_ready.npy`** 101 | ```python 102 | { 103 | 'uids': [uid1, uid2, ...], 104 | 'genres': [gid1, gid2, ...], 105 | 'image': [ 106 | [iid1, iid2, iid3, iid4, iid5, iid6], # the interaction seq of uid1 107 | ... 108 | ] 109 | 'text_input_ids': [...], 110 | 'text_attn_mask': [...], 111 | } 112 | ``` 113 | 114 | * **`train_hist_embeds.npy`,`valid_hist_embeds.npy` & `test_hist_embeds.npy`:** for evaluation. 115 | ```python 116 | { 117 | uid1: { 118 | target_iid1: { 119 | gid1: ..., # clip features of the target_iid image 120 | }, 121 | target_iid2: {...}, 122 | ... 123 | }, 124 | uid2: ..., 125 | ... 126 | } 127 | ``` 128 | 129 | * **`/dpo`:** preference dataset for the second-stage preference alignment. 130 | * **`train_preference_data.npy` & `valid_preference_data.npy`** 131 | ```python 132 | { 133 | uid1: { 134 | target_iid1: { 135 | gid1: { 136 | 'chosen': ..., # chosen token sequence for <uid, target_iid1> 137 | 'rejected': ... # rejected token sequence for <uid, target_iid1> 138 | } 139 | }, 140 | target_iid2: {...}, 141 | ... 142 | }, 143 | uid2: ..., 144 | ... 145 | } 146 | ``` 147 | 148 | * **`train.npy` & `valid.npy`**: randomly sampled subset of the above `train.npy` & `valid.npy` for the second-stage training. 149 | 150 | * **`/data_ready_dpo`:** data processed for the second-stage DPO. 151 | * **`train_ready.npy` & `valid_ready.npy`** 152 | ```python 153 | { 154 | 'uids': [uid1, uid2, ...], 155 | 'genres': [gid1, gid2, ...], 156 | 'image': [ 157 | [iid1, iid2, iid3, iid4, iid5, iid6], # the interaction seq of uid1 158 | ... 159 | ] 160 | 'text_input_ids': [...], 161 | 'text_attn_mask': [...], 162 | 'chosen_tokens': [...], 163 | 'rejected_tokens': [...] 164 | } 165 | ``` 166 | 167 | #### Movie Poster 168 | * **data_path**: `./dataset/ml-latest-small/processed_seq` 169 | 170 | * **image_folder_path:** `./dataset/ml-latest-small/poster` 171 | 172 | * **`mapped_movies.npy`:** dict of movie info. 173 | ```python 174 | { 175 | iid1: { # image id / movie id 176 | 'title': 'Star Wars: Episode VI - Return of the Jedi (1983)', # movie title 177 | 'genres': 'Action|Adventure|Fantasy', # movie genres 178 | 'intro': ... # movie introduction 179 | }, 180 | iid2: ..., 181 | ... 182 | } 183 | ``` 184 | 185 | * **`user_seq_dict.npy`:** dict of user interaction history with the same format as `mapped_user_dict.npy` in the sticker dataset. 186 | 187 | * **Other data are the same as the sticker dataset.** 188 | 189 | ### Pre-trained Models 190 | Please download the following pre-trained models for fine-tuning and evaluation. 
191 | - [LaVIT](https://huggingface.co/rain1011/LaVIT-7B-v2) 192 | - [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 193 | - [DINO-v2](https://huggingface.co/facebook/dinov2-large) (for evaluation) 194 | 195 | ### Training 196 | #### Stage-1: Masked Preference Reconstruction 197 | Please configure the correct paths of all data in `finetune_sft.py` and `run_sft.sh`. 198 | ``` 199 | cd ./Pigeon 200 | sh run_sft.sh 201 | ``` 202 | 203 | #### Stage-2: Pairwise Preference Optimization 204 | Please configure the correct paths of all data in `finetune_dpo.py` and `run_dpo.sh`, and set the appropriate checkpoint from the first-stage alignment for the second-stage fine-tuning. 205 | ``` 206 | cd ./Pigeon 207 | sh run_dpo.sh 208 | ``` 209 | 210 | ### Inference 211 | 1. Download the checkpoint released by us from [here](https://drive.google.com/drive/folders/1Hax0ZubvHqaGvUROVzD32_iKPpi8mYh1?usp=drive_link). 212 | 2. Put the checkpoint into the appropriate folder. 213 | 3. Configure the correct paths of all data in `inference.py`. 214 | 4. Run `inference.py` via the provided script: 215 | ``` 216 | cd ./Pigeon 217 | sh run_inf.sh 218 | ``` 219 | 5. Calculate the history CIS and reference CS for each generated target image. 220 | ``` 221 | cd ./Evaluation 222 | sh run_cal_scores.sh 223 | ``` 224 | 6. Run the `./Evaluation/select_scores.ipynb` file to select the optimal reference mask ratio $\alpha_r$ (i.e., the selected generated target image). 225 | 226 | ### Evaluation 227 | Run the evaluation code through the provided `.sh` files. 228 | ``` 229 | cd ./Evaluation 230 | sh run_eval_pigeon.sh 231 | ``` 232 | 233 | ### Preference Dataset Construction 234 | After completing the first stage of fine-tuning, you can follow these steps to construct your own preference dataset for DPO: 235 | 1. Using the first-stage checkpoint, run `./Pigeon/run_inf.sh` to generate multiple personalized target images for both the training and validation sets: 236 |    - For the training set: `--mode train` `--eval_num 1000` 237 |    - For the validation set: `--mode valid` `--eval_num 200` 238 | 2. Calculate the history CIS and reference CS for each generated target image by executing `./Evaluation/run_cal_scores.sh`. 239 | 3. Run the `./Evaluation/select_scores_dpo.ipynb` file to construct the preference pairs for DPO training (a quick sanity-check sketch is given at the end of this README). 240 | 241 | ### Acknowledgments 242 | Our work heavily relies on the excellent contributions of [LaVIT](https://github.com/jy0205/LaVIT). We sincerely thank the team for their efforts. 
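To double-check that a constructed preference dataset matches the format expected by `finetune_dpo.py`, the following minimal sketch (assuming the default sticker layout under `./dataset/SER-30K/processed_seq`; adjust `data_path` to your own setup) loads the files written by `select_scores_dpo.ipynb` and inspects one chosen/rejected pair:
```python
import numpy as np

data_path = "./dataset/SER-30K/processed_seq"  # assumed default; change as needed

# uid -> gid -> list of interaction sequences (the last iid of each sequence is the target)
train_seq = np.load(f"{data_path}/dpo/train.npy", allow_pickle=True).item()
# uid -> target_iid -> gid -> {'chosen': ..., 'rejected': ...} token sequences
preference = np.load(f"{data_path}/dpo/train_preference_data.npy", allow_pickle=True).item()

uid = next(iter(preference))
tgt_iid = next(iter(preference[uid]))
gid = next(iter(preference[uid][tgt_iid]))
pair = preference[uid][tgt_iid][gid]
print(f"uid={uid}, target_iid={tgt_iid}, gid={gid}, "
      f"chosen length={len(pair['chosen'])}, rejected length={len(pair['rejected'])}")
```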
243 | -------------------------------------------------------------------------------- /Pigeon/models/openai_clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 70 | 71 | return download_target 72 | 73 | 74 | def _convert_image_to_rgb(image): 75 | 
return image.convert("RGB") 76 | 77 | 78 | def _transform(n_px): 79 | return Compose([ 80 | Resize(n_px, interpolation=BICUBIC), 81 | CenterCrop(n_px), 82 | _convert_image_to_rgb, 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | 88 | def available_models() -> List[str]: 89 | """Returns the names of available CLIP models""" 90 | return list(_MODELS.keys()) 91 | 92 | 93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 94 | """Load a CLIP model 95 | 96 | Parameters 97 | ---------- 98 | name : str 99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 100 | 101 | device : Union[str, torch.device] 102 | The device to put the loaded model 103 | 104 | jit : bool 105 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 106 | 107 | download_root: str 108 | path to download the model files; by default, it uses "~/.cache/clip" 109 | 110 | Returns 111 | ------- 112 | model : torch.nn.Module 113 | The CLIP model 114 | 115 | preprocess : Callable[[PIL.Image], torch.Tensor] 116 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 117 | """ 118 | print(f"Loading clip model from {name}") 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | try: 127 | # loading JIT archive 128 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 129 | state_dict = None 130 | except RuntimeError: 131 | # loading saved state dict 132 | if jit: 133 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 134 | jit = False 135 | state_dict = torch.load(model_path, map_location="cpu") 136 | 137 | if not jit: 138 | model = build_model(state_dict or model.state_dict()).to(device) 139 | if str(device) == "cpu": 140 | model.float() 141 | return model, _transform(model.visual.input_resolution) 142 | 143 | # patch the device names 144 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 145 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 146 | 147 | def patch_device(module): 148 | try: 149 | graphs = [module.graph] if hasattr(module, "graph") else [] 150 | except RuntimeError: 151 | graphs = [] 152 | 153 | if hasattr(module, "forward1"): 154 | graphs.append(module.forward1.graph) 155 | 156 | for graph in graphs: 157 | for node in graph.findAllNodes("prim::Constant"): 158 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 159 | node.copyAttributes(device_node) 160 | 161 | model.apply(patch_device) 162 | patch_device(model.encode_image) 163 | patch_device(model.encode_text) 164 | 165 | # patch dtype to float32 on CPU 166 | if str(device) == "cpu": 167 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 168 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 169 | float_node = float_input.node() 170 | 171 | def patch_float(module): 172 | try: 173 | graphs = [module.graph] if hasattr(module, "graph") else [] 174 | except RuntimeError: 175 | graphs = [] 176 | 177 | if hasattr(module, "forward1"): 178 | graphs.append(module.forward1.graph) 179 | 180 | for graph in graphs: 181 | for node in graph.findAllNodes("aten::to"): 182 | inputs = list(node.inputs()) 183 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 184 | if inputs[i].node()["value"] == 5: 185 | inputs[i].node().copyAttributes(float_node) 186 | 187 | model.apply(patch_float) 188 | patch_float(model.encode_image) 189 | patch_float(model.encode_text) 190 | 191 | model.float() 192 | 193 | return model, _transform(model.input_resolution.item()) 194 | 195 | 196 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 197 | """ 198 | Returns the tokenized representation of given input string(s) 199 | 200 | Parameters 201 | ---------- 202 | texts : Union[str, List[str]] 203 | An input string or a list of input strings to tokenize 204 | 205 | context_length : int 206 | The context length to use; all CLIP models use 77 as the context length 207 | 208 | truncate: bool 209 | Whether to truncate the text in case its encoding is longer than the context length 210 | 211 | Returns 212 | ------- 213 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 214 | """ 215 | if isinstance(texts, str): 216 | texts = [texts] 217 | 218 | sot_token = _tokenizer.encoder["<|startoftext|>"] 219 | eot_token = _tokenizer.encoder["<|endoftext|>"] 220 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 221 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 222 | 223 | for i, tokens in enumerate(all_tokens): 224 | if len(tokens) > context_length: 225 | if truncate: 226 | tokens = tokens[:context_length] 227 | tokens[-1] = eot_token 228 | else: 229 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 230 | result[i, 
:len(tokens)] = torch.tensor(tokens) 231 | 232 | return result 233 | -------------------------------------------------------------------------------- /Pigeon/finetune_sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | import logging 5 | import random 6 | 7 | import inspect 8 | import numpy as np 9 | import fire 10 | import torch 11 | import transformers 12 | from datasets import load_dataset, concatenate_datasets 13 | from transformers import EarlyStoppingCallback 14 | from typing import Optional, Union 15 | """ 16 | Unused imports:` 17 | import torch.nn as nn 18 | import bitsandbytes as bnb 19 | """ 20 | 21 | from peft import ( 22 | LoraConfig, 23 | get_peft_model, 24 | get_peft_model_state_dict, 25 | prepare_model_for_int8_training, 26 | set_peft_model_state_dict, 27 | ) 28 | from transformers import LlamaForCausalLM, LlamaTokenizer 29 | from transformers.data.data_collator import default_data_collator 30 | from models.lavit_utils import get_rank 31 | from models.lavit_for_pigeon import LaVITforPigeon 32 | from models.trainer.trainer_for_pigeon import PigeonTrainer, PigeonTrainingArguments 33 | from models.lavit_utils import convert_weights_to_bf16, convert_weights_to_fp16 34 | import data_utils 35 | 36 | class InfoFilter(logging.Filter): 37 | def filter(self, record): 38 | return record.levelno <= logging.INFO 39 | 40 | def setup_logger(): 41 | logger = logging.getLogger() 42 | logger.setLevel(logging.INFO) 43 | 44 | console_handler = logging.StreamHandler() 45 | console_handler.setLevel(logging.INFO) 46 | 47 | info_filter = InfoFilter() 48 | console_handler.addFilter(info_filter) 49 | 50 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 51 | console_handler.setFormatter(formatter) 52 | 53 | if not logger.handlers: 54 | logger.addHandler(console_handler) 55 | 56 | setup_logger() 57 | logger = logging.getLogger(__name__) 58 | 59 | def train( 60 | model_path: str = "/path/to/pre-trained/LaVIT/", 61 | model_dtype: str = "bf16", 62 | output_dir: str = "/path/to/output/", 63 | # model-specific hyperparams 64 | use_xformers: bool = True, 65 | load_in_8bit: bool = True, 66 | check_safety: bool = False, 67 | pixel_decoding: str = "highres", 68 | # mask generator hyperparams 69 | mask_type: str = "random", 70 | num_heads: int = 4, 71 | num_layers: int = 1, 72 | drop_prob: int = 0.2, 73 | hist_mask_ratio: float = 0.2, 74 | add_special: bool = False, 75 | # dataset info 76 | scenario: str = "sticker", 77 | img_folder_path: str = "/path/to/data/", 78 | data_path: str = "/path/to/data/", 79 | # training hyperparams 80 | batch_size: int = 128, 81 | micro_batch_size: int = 4, 82 | group_by_length: bool = False, 83 | num_epochs: int = 3, 84 | learning_rate: float = 3e-4, 85 | lr_schedule_type: str = "cosine", 86 | min_learning_rate: float = 1e-6, 87 | seed: int = 123, 88 | logging_steps: int = 25, 89 | save_steps: int = 25, 90 | eval_steps: int = 25, 91 | eval_num: int = 1000, 92 | # lora hyperparams 93 | lora_r: int = 8, 94 | lora_alpha: int = 16, 95 | lora_dropout: float = 0.05, 96 | lora_target_modules: List[str] = [ 97 | "q_proj", 98 | "v_proj", 99 | ], 100 | resume_from_checkpoint: Optional[Union[str, bool]] = None, # either training checkpoint or final adapter 101 | ): 102 | frame = inspect.currentframe() 103 | args, _, _, values = inspect.getargvalues(frame) 104 | params = {arg:values[arg] for arg in args} 105 | 106 | os.makedirs(output_dir, exist_ok=True) 107 
| if resume_from_checkpoint: 108 | ckpts = os.listdir(output_dir) 109 | ckpts = [ckpt for ckpt in ckpts if ckpt.startswith("checkpoint") and os.path.isdir(os.path.join(output_dir, ckpt))] 110 | if len(ckpts) == 0: 111 | resume_from_checkpoint = False 112 | params["resume_from_checkpoint"] = False 113 | 114 | logger.info(f"***** Finetuning Llama model with params:") 115 | for k,v in params.items(): 116 | logger.info(f"{k}: {v}") 117 | 118 | save_params_to_txt(os.path.join(output_dir, "args.txt"), **params) 119 | set_random_seed(seed=seed) 120 | 121 | gradient_accumulation_steps = batch_size // micro_batch_size 122 | device_map = "auto" 123 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 124 | ddp = world_size != 1 125 | if ddp: 126 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 127 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 128 | 129 | all_image_paths = np.load(os.path.join(data_path, "all_image_paths.npy"), allow_pickle=True) 130 | 131 | llama_lora_config = LoraConfig( 132 | r=lora_r, 133 | lora_alpha=lora_alpha, 134 | target_modules=lora_target_modules, 135 | lora_dropout=lora_dropout, 136 | bias="none", 137 | task_type="CAUSAL_LM", 138 | ) 139 | 140 | lavit = LaVITforPigeon( 141 | model_path=model_path, 142 | model_dtype=model_dtype, 143 | lora_config=llama_lora_config, 144 | use_xformers=use_xformers, 145 | check_safety=check_safety, 146 | pixel_decoding=pixel_decoding, 147 | mask_type=mask_type, 148 | num_heads=num_heads, 149 | num_layers=num_layers, 150 | drop_prob=drop_prob, 151 | img_folder_path=img_folder_path, 152 | all_image_paths=all_image_paths, 153 | data_path=data_path, 154 | load_in_8bit=load_in_8bit, 155 | device_map=device_map, 156 | ) 157 | 158 | if not ddp and torch.cuda.device_count() > 1: 159 | lavit.llama_model.is_parallelizable = True 160 | lavit.llama_model.model_parallel = True 161 | 162 | 163 | # Load dataset 164 | train_data_path = os.path.join(data_path, "data_ready", "train_ready.npy") 165 | valid_data_path = os.path.join(data_path, "data_ready", "valid_ready.npy") 166 | if os.path.exists(train_data_path) and os.path.exists(valid_data_path): 167 | logger.info(f"Loading processed train dataset and valid dataset from {os.path.join(data_path, 'data_ready')}") 168 | train_data = np.load(train_data_path, allow_pickle=True).item() 169 | valid_data = np.load(valid_data_path, allow_pickle=True).item() 170 | else: 171 | logger.info(f"Processed datasets don't exist. 
Start data processing...") 172 | if scenario == "sticker": 173 | item_info = np.load(os.path.join(data_path, "mapped_anno_dict.npy"), allow_pickle=True).item() 174 | semantics = np.load(os.path.join(data_path, "sticker_semantics.npy"), allow_pickle=True).tolist() 175 | elif scenario == "movie": 176 | item_info = np.load(os.path.join(data_path, "mapped_movies.npy"), allow_pickle=True).item() 177 | semantics = np.load(os.path.join(data_path, "movie_semantics.npy"), allow_pickle=True).tolist() 178 | 179 | train_seq = np.load(os.path.join(data_path, "train.npy"), allow_pickle=True).item() 180 | valid_seq = np.load(os.path.join(data_path, "valid.npy"), allow_pickle=True).item() 181 | 182 | train_data = data_utils.process_data(scenario, train_seq, item_info, semantics, lavit.llama_tokenizer) 183 | valid_data = data_utils.process_data(scenario, valid_seq, item_info, semantics, lavit.llama_tokenizer) 184 | 185 | if not os.path.exists(os.path.join(data_path, "data_ready")): 186 | os.makedirs(os.path.join(data_path, "data_ready")) 187 | 188 | np.save(train_data_path, np.array(train_data)) 189 | np.save(valid_data_path, np.array(valid_data)) 190 | logger.info(f"Saving processed train dataset and valid dataset into {os.path.join(data_path, 'processed_seq', 'data_ready')}") 191 | 192 | train_dataset = data_utils.PigeonDataset(train_data) 193 | valid_dataset = data_utils.PigeonDataset(valid_data) 194 | 195 | trainer = PigeonTrainer( 196 | model=lavit, 197 | train_dataset=train_dataset, 198 | eval_dataset=valid_dataset, 199 | args=PigeonTrainingArguments( 200 | # model-specific arguments 201 | hist_mask_ratio=hist_mask_ratio, 202 | add_special=add_special, 203 | # training arguments 204 | per_device_train_batch_size=micro_batch_size, 205 | per_device_eval_batch_size=micro_batch_size, 206 | group_by_length=group_by_length, 207 | gradient_accumulation_steps=gradient_accumulation_steps, 208 | warmup_steps=20, 209 | num_train_epochs=num_epochs, 210 | learning_rate=learning_rate, 211 | lr_scheduler_type=lr_schedule_type, 212 | min_lr_ratio=min_learning_rate, 213 | fp16=False, 214 | logging_steps=logging_steps, 215 | optim="adamw_bnb_8bit", 216 | eval_num=eval_num, 217 | evaluation_strategy="steps", 218 | eval_steps=eval_steps, 219 | save_strategy="steps", 220 | save_steps=save_steps, 221 | output_dir=output_dir, 222 | save_total_limit=5, 223 | load_best_model_at_end=True, 224 | ddp_find_unused_parameters=False if ddp else None, 225 | report_to=None, 226 | seed=seed, 227 | log_level="info", 228 | ), 229 | data_collator=data_utils.PigeonCollator(pad_token_id=lavit.llama_tokenizer.pad_token_id, padding_side=lavit.llama_tokenizer.padding_side, return_tensors="pt"), 230 | callbacks = [EarlyStoppingCallback(early_stopping_patience=10)] 231 | ) 232 | 233 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 234 | logger.info(f"BEST MODEL PATH: {trainer.state.best_model_checkpoint}") 235 | 236 | def save_params_to_txt(file_path, **kwargs): 237 | with open(file_path, 'w') as file: 238 | for key, value in kwargs.items(): 239 | file.write(f'{key}: {value}\n') 240 | 241 | def set_random_seed(seed=0): 242 | random.seed(seed) 243 | np.random.seed(seed) 244 | torch.manual_seed(seed) 245 | torch.cuda.manual_seed(seed) 246 | torch.cuda.manual_seed_all(seed) 247 | torch.backends.cudnn.deterministic = True 248 | torch.backends.cudnn.benchmark = False 249 | 250 | if __name__ == "__main__": 251 | fire.Fire(train) 252 | 253 | -------------------------------------------------------------------------------- 
/Evaluation/select_scores.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 31, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 32, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "eval_folder = \"/checkpoints/sticker/DPO/eval-test\"\n", 20 | "res_folder = []\n", 21 | "for subf in os.listdir(eval_folder):\n", 22 | " if subf.startswith(\"checkpoint\"):\n", 23 | " res_folder.append(subf)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 33, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "score_name = \"select_scores_dm_scale7.0.npy\"" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 34, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "all_select_scores = {}\n", 42 | "for folder in res_folder:\n", 43 | " path = os.path.join(eval_folder, folder, score_name)\n", 44 | " select_scores = np.load(path, allow_pickle=True).item()\n", 45 | " all_select_scores[folder] = select_scores" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 35, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "example_key = list(sorted(all_select_scores.keys(), key=lambda x:float(x.split(\"-\")[-1][4:])))[0]" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 37, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "'checkpoint-375-scale1.0-mask0.0'" 66 | ] 67 | }, 68 | "execution_count": 37, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "example_key" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 38, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "select_ratio = 0.5" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 39, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "re_select_scores = {}\n", 93 | "for uid in example_scores:\n", 94 | " if uid not in re_select_scores:\n", 95 | " re_select_scores[uid] = {}\n", 96 | "\n", 97 | " for iid in example_scores[uid]:\n", 98 | " if iid not in re_select_scores[uid]:\n", 99 | " re_select_scores[uid][iid] = {}\n", 100 | "\n", 101 | " for gid in example_scores[uid][iid]:\n", 102 | " if gid not in re_select_scores[uid][iid]:\n", 103 | " re_select_scores[uid][iid][gid] = {}\n", 104 | " re_select_scores[uid][iid][gid][\"folder\"] = []\n", 105 | " re_select_scores[uid][iid][gid][\"seq_id\"] = []\n", 106 | " re_select_scores[uid][iid][gid][\"img_id\"] = []\n", 107 | " re_select_scores[uid][iid][gid][\"score\"] = []\n", 108 | " \n", 109 | " for folder in all_select_scores:\n", 110 | " cur_scores = all_select_scores[folder][uid][iid][gid]\n", 111 | " all_seq_id = []\n", 112 | " all_img_id = []\n", 113 | " all_seq_score = []\n", 114 | " for seq_id in cur_scores:\n", 115 | " all_seq_id.append(seq_id)\n", 116 | " tgt_scores = cur_scores[seq_id][\"tgt_score\"]\n", 117 | " hist_scores = cur_scores[seq_id][\"hist_score\"]\n", 118 | " seq_scores = [select_ratio * tgt_score + (1 - select_ratio) * hist_score for tgt_score,hist_score in zip(tgt_scores, hist_scores)]\n", 119 | " max_score = max(seq_scores)\n", 120 | " max_img_idx = seq_scores.index(max_score)\n", 121 | " all_img_id.append(max_img_idx)\n", 122 | " all_seq_score.append(max_score)\n", 123 | " \n", 124 | " 
max_score = max(all_seq_score)\n", 125 | " max_idx = all_seq_score.index(max_score)\n", 126 | " max_seq_id = all_seq_id[max_idx]\n", 127 | " max_img_id = all_img_id[max_idx]\n", 128 | "\n", 129 | " re_select_scores[uid][iid][gid][\"folder\"].append(folder)\n", 130 | " re_select_scores[uid][iid][gid][\"seq_id\"].append(max_seq_id)\n", 131 | " re_select_scores[uid][iid][gid][\"img_id\"].append(max_img_id)\n", 132 | " re_select_scores[uid][iid][gid][\"score\"].append(max_score)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 40, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "from copy import deepcopy" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 41, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "final_select_scores = deepcopy(re_select_scores)\n", 151 | "\n", 152 | "for uid in re_select_scores:\n", 153 | " for iid in re_select_scores[uid]:\n", 154 | " for gid in re_select_scores[uid][iid]:\n", 155 | " scores = re_select_scores[uid][iid][gid][\"score\"]\n", 156 | " best_score = max(scores)\n", 157 | " best_idx = scores.index(best_score)\n", 158 | " best_folder = re_select_scores[uid][iid][gid][\"folder\"][best_idx]\n", 159 | " best_seq_id = re_select_scores[uid][iid][gid][\"seq_id\"][best_idx]\n", 160 | " best_img_id = re_select_scores[uid][iid][gid][\"img_id\"][best_idx]\n", 161 | "\n", 162 | " final_select_scores[uid][iid][gid][\"folder\"] = best_folder\n", 163 | " final_select_scores[uid][iid][gid][\"seq_id\"] = best_seq_id\n", 164 | " final_select_scores[uid][iid][gid][\"img_id\"] = best_img_id\n", 165 | " final_select_scores[uid][iid][gid][\"score\"] = best_score" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 42, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "np.save(os.path.join(eval_folder, \"final_select_scores.npy\"), np.array(final_select_scores))" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 43, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "all_res = {}\n", 184 | "res_name = \"all_res_dm_scale7.0.npy\"\n", 185 | "for folder in res_folder:\n", 186 | " path = os.path.join(eval_folder, folder, res_name)\n", 187 | " res = np.load(path, allow_pickle=True).item()\n", 188 | " all_res[folder] = res" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 44, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "import shutil\n", 198 | "final_folder = example_key.split(\"-\")\n", 199 | "final_folder[-1] = \"mask2.0\" # mask2.0 denotes the selected images\n", 200 | "final_folder = \"-\".join(final_folder)\n", 201 | "\n", 202 | "final_res = deepcopy(all_res[example_key])\n", 203 | "for uid in final_select_scores:\n", 204 | " for iid in final_select_scores[uid]:\n", 205 | " for gid in final_select_scores[uid][iid]:\n", 206 | " folder = final_select_scores[uid][iid][gid][\"folder\"]\n", 207 | " seq_id = final_select_scores[uid][iid][gid][\"seq_id\"]\n", 208 | " img_id = final_select_scores[uid][iid][gid][\"img_id\"]\n", 209 | " score = final_select_scores[uid][iid][gid][\"score\"]\n", 210 | "\n", 211 | " seqs_info = {}\n", 212 | " seqs_info[0] = deepcopy(all_res[folder][uid][iid][gid][\"seqs\"][seq_id])\n", 213 | " seqs_info[0][\"src_images\"] = seqs_info[0][\"images\"][img_id:img_id+1]\n", 214 | " seqs_info[0][\"images\"] = []\n", 215 | "\n", 216 | " for src_img_path in seqs_info[0][\"src_images\"]:\n", 217 | " src_img_folder = 
\"/\".join(src_img_path.split(\"/\")[:-2])\n", 218 | " img_path_list = [final_folder] + src_img_path.split(\"/\")[1:]\n", 219 | " tgt_img_folder = \"/\".join(img_path_list[:-2])\n", 220 | " os.makedirs(os.path.join(eval_folder, tgt_img_folder), exist_ok=True)\n", 221 | "\n", 222 | " for item in os.listdir(os.path.join(eval_folder, src_img_folder)):\n", 223 | " src_path = os.path.join(eval_folder, src_img_folder, item)\n", 224 | " if os.path.isfile(src_path):\n", 225 | " shutil.copy(src_path, os.path.join(eval_folder, tgt_img_folder))\n", 226 | "\n", 227 | " tgt_img_folder = \"/\".join(img_path_list[:-2] + [\"0\"])\n", 228 | " os.makedirs(os.path.join(eval_folder, tgt_img_folder), exist_ok=True)\n", 229 | " tgt_img_path = \"/\".join(img_path_list[:-2] + [\"0\", \"0.jpg\"])\n", 230 | " shutil.copy(os.path.join(eval_folder, src_img_path), os.path.join(eval_folder, tgt_img_path)) \n", 231 | "\n", 232 | " seqs_info[0][\"images\"].append(tgt_img_path)\n", 233 | "\n", 234 | " final_res[uid][iid][gid][\"seqs\"] = seqs_info" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 45, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "np.save(os.path.join(eval_folder, final_folder, res_name), np.array(final_res))" 244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "pigeon", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.8.13" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 2 268 | } 269 | -------------------------------------------------------------------------------- /Pigeon/data_utils_dpo.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import re 5 | from fileinput import filename 6 | 7 | import numpy as np 8 | import scipy.sparse as sp 9 | from dataclasses import dataclass 10 | from typing import Optional, Union, Mapping, List, Dict, Any 11 | import torch 12 | import torch.utils.data as data 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms 15 | from PIL import Image 16 | from transformers import LlamaTokenizer 17 | from transformers.utils import PaddingStrategy 18 | 19 | class sticker_template: 20 | instruction = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." 21 | task_description = "Instruction: You are a helpful personalized assistant. You will be provided with a list of stickers that the user likes to analyze the user's preference over stickers. Based on this analysis, please design a personalized sticker that aligns with target sticker info and the user's sticker tastes." 22 | prompt_start = "Input: " 23 | history_start = "The user likes the following stickers: " 24 | target_start = "The target sticker info: a sticker with the emotion of <emo>. " 25 | prompt_end = "Response: " 26 | 27 | class movie_template: 28 | instruction = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." 29 | task_description = "Instruction: You are a helpful personalized assistant. 
You will be provided with a list of movies that the user likes, along with the movie posters, to analyze the user's preference over posters. Based on this analysis, please design a personalized poster that aligns with target movie and the user's poster tastes." 30 | prompt_start = "Input: " 31 | history_start = "The user likes the following movies: " 32 | target_start = "The target movie is titled '<title>'. " 33 | prompt_end = "Response: " 34 | 35 | SEP = "\n" 36 | IMG_TOKEN = "<img>" 37 | EMB_TOKEN = "<emb>" 38 | 39 | class ImagePathDataset(Dataset): 40 | def __init__(self, folder_path, paths, trans=None): 41 | self.folder_path = folder_path 42 | self.paths = paths 43 | self.trans = trans 44 | 45 | def __len__(self): 46 | return len(self.paths) 47 | 48 | def __getitem__(self, idx): 49 | path = os.path.join(self.folder_path, self.paths[idx]) 50 | img = Image.open(path).convert('RGB') 51 | if self.trans is not None: 52 | img = self.trans(img) 53 | 54 | return img.to(memory_format=torch.contiguous_format) 55 | 56 | class DiffusionImageDataset(Dataset): 57 | def __init__(self, start, end): 58 | self.data = list(range(start, end)) 59 | 60 | def __len__(self): 61 | return len(self.data) 62 | 63 | def __getitem__(self, idx): 64 | return self.data[idx] 65 | 66 | class DPODataset(Dataset): 67 | def __init__(self, data): 68 | self.data = data 69 | 70 | def __len__(self): 71 | return len(self.data["uids"]) 72 | 73 | def __getitem__(self, idx): 74 | text_input_ids = self.data["text_input_ids"][idx] 75 | text_attn_mask = self.data["text_attn_mask"][idx] 76 | chosen_tokens = self.data["chosen_tokens"][idx] 77 | rejected_tokens = self.data["rejected_tokens"][idx] 78 | image = self.data["image"][idx] 79 | uids = self.data["uids"][idx] 80 | genres = self.data["genres"][idx] 81 | 82 | return { 83 | "text_input_ids": text_input_ids, 84 | "text_attn_mask": text_attn_mask, 85 | "chosen_tokens": chosen_tokens, 86 | "rejected_tokens": rejected_tokens, 87 | "image": image, 88 | "uids": uids, 89 | "genres": genres 90 | } 91 | 92 | def shuffle(self, seed=None): 93 | if seed is not None: 94 | torch.manual_seed(seed) 95 | data_len = len(self.data["uids"]) 96 | indices = torch.randperm(data_len).tolist() 97 | self.data = {key:[self.data[key][i] for i in indices] for key in self.data} 98 | 99 | return self 100 | 101 | def select(self, indices): 102 | selected_data = {key:[self.data[key][i] for i in indices] for key in self.data} 103 | return DPODataset(selected_data) 104 | 105 | @dataclass 106 | class DPOCollator: 107 | pad_token_id: int = 2 108 | padding_side: Optional[str] = "left" 109 | return_tensors: str = "pt" 110 | 111 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 112 | # If we have a list of dicts, let's convert it in a dict of lists 113 | # We do this to allow using this method as a collate_fn function in PyTorch Dataloader 114 | if isinstance(features, (list, tuple)) and isinstance(features[0], Mapping): 115 | features = {key: [example[key] for example in features] for key in features[0].keys()} 116 | 117 | max_length = max(len(input_ids) for input_ids in features["text_input_ids"]) 118 | bsz = len(features["text_input_ids"]) 119 | for i in range(bsz): 120 | pad_num = max_length - len(features["text_input_ids"][i]) 121 | 122 | if self.padding_side == "left": 123 | features["text_input_ids"][i] = [self.pad_token_id] * pad_num + features["text_input_ids"][i] 124 | features["text_attn_mask"][i] = [0] * pad_num + features["text_attn_mask"][i] 125 | else: 126 | 
features["text_input_ids"][i] = features["text_input_ids"][i] + [self.pad_token_id] * pad_num 127 |                 features["text_attn_mask"][i] = features["text_attn_mask"][i] + [0] * pad_num 128 | 129 |         if self.return_tensors == "pt": 130 |             for key in features.keys(): 131 |                 if key == "chosen_tokens" or key == "rejected_tokens": 132 |                     # keep these output tokens as list 133 |                     continue 134 |                 features[key] = torch.tensor(features[key], dtype=torch.long) 135 |         # features = {key: torch.tensor(features[key], dtype=torch.long) for key in features.keys()} 136 | 137 |         return features 138 | 139 | def process_dpo_data( 140 |     scenario: str="sticker", 141 |     data: dict=None, 142 |     preference_data: dict=None, 143 |     tokenizer: LlamaTokenizer=None, 144 |     item_info: dict=None, 145 |     semantics: list=None, 146 | ): 147 |     text_input_ids = [] 148 |     text_attn_mask = [] 149 |     all_chosen_tokens = [] 150 |     all_rejected_tokens = [] 151 |     image = [] 152 |     uids = [] 153 |     genres = [] 154 |     for uid in data: 155 |         for genre in data[uid]: 156 |             hist_seqs = data[uid][genre] 157 |             for hist_seq in hist_seqs: 158 |                 iid = hist_seq[-1] 159 |                 chosen_tokens = preference_data[uid][iid][genre]["chosen"] 160 |                 rejected_tokens = preference_data[uid][iid][genre]["rejected"] 161 | 162 |                 if scenario == "sticker": 163 |                     text_prompt = generate_prompt_for_stickers(sticker_template, hist_seq, item_info, semantics) 164 |                 elif scenario == "movie": 165 |                     text_prompt = generate_prompt_for_movies(movie_template, hist_seq, item_info, semantics) 166 | 167 |                 prompt_tokens = tokenizer( 168 |                     text_prompt, padding="longest", return_tensors="pt", add_special_tokens=False 169 |                 ) 170 | 171 |                 all_chosen_tokens.append(chosen_tokens) 172 |                 all_rejected_tokens.append(rejected_tokens) 173 |                 text_input_ids.append(prompt_tokens.input_ids.squeeze(0).tolist()) 174 |                 text_attn_mask.append(prompt_tokens.attention_mask.squeeze(0).tolist()) 175 |                 # the last image is the target image, serving as both input and output 176 |                 image.append(hist_seq) 177 |                 uids.append(uid) 178 |                 genres.append(genre) 179 | 180 |     return {"text_input_ids": text_input_ids, "text_attn_mask": text_attn_mask, 181 |             "chosen_tokens": all_chosen_tokens, "rejected_tokens": all_rejected_tokens, 182 |             "image": image, "uids": uids, "genres": genres} 183 | 184 | def generate_prompt_for_movies( 185 |     template, 186 |     hist_seq: dict=None, 187 |     movies_info: dict=None, 188 |     semantics: list=None, 189 | ): 190 |     hist_prompt = template.history_start 191 |     for iid in hist_seq[:-1]: 192 |         title = movies_info[iid]["title"] 193 |         try: 194 |             title = re.findall(r'^(.*) \(\d+\) *$', title)[0] 195 |         except: 196 |             title = title 197 | 198 |         hist_prompt += SEP 199 |         hist_prompt += title 200 |         hist_prompt += " " 201 |         hist_prompt += IMG_TOKEN 202 | 203 |     tgt_iid = hist_seq[-1] 204 |     target_title = movies_info[tgt_iid]["title"] 205 |     try: 206 |         target_title = re.findall(r'^(.*) \(\d+\) *$', target_title)[0] 207 |     except: 208 |         target_title = target_title 209 | 210 |     target_prompt = template.target_start.replace("<title>", target_title) 211 |     target_prompt = target_prompt + semantics[tgt_iid] + " " + EMB_TOKEN 212 | 213 |     text_prompt = template.instruction + SEP + template.task_description + SEP + template.prompt_start + hist_prompt + SEP + target_prompt + SEP + template.prompt_end 214 |     text_prompt += IMG_TOKEN 215 | 216 |     return text_prompt 217 | 218 | def generate_prompt_for_stickers( 219 |     template, 220 |     hist_seq: dict=None, 221 |     anno_dict: dict=None, 222 |     semantics: list=None, 223 | ): 224 |     hist_prompt = 
template.history_start 225 | for iid in hist_seq[:-1]: 226 | hist_prompt += SEP 227 | hist_prompt += IMG_TOKEN 228 | 229 | tgt_iid = hist_seq[-1] 230 | emo = anno_dict[tgt_iid]["emo"] 231 | 232 | target_prompt = template.target_start.replace("<emo>", emo) 233 | target_prompt = target_prompt + semantics[tgt_iid] + " " + EMB_TOKEN 234 | 235 | text_prompt = template.instruction + SEP + template.task_description + SEP + template.prompt_start + hist_prompt + SEP + target_prompt + SEP + template.prompt_end 236 | text_prompt += IMG_TOKEN 237 | 238 | return text_prompt 239 | -------------------------------------------------------------------------------- /Evaluation/select_scores_dpo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 31, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 32, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "mode = \"train\"\n", 20 | "eval_folder = f\"/checkpoints/sticker/SFT/eval-{mode}\"\n", 21 | "res_folder = []\n", 22 | "for subf in os.listdir(eval_folder):\n", 23 | " if subf.startswith(\"checkpoint\"):\n", 24 | " res_folder.append(subf)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 33, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "score_name = \"select_scores_dm_scale7.0.npy\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 34, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "all_select_scores = {}\n", 43 | "for folder in res_folder:\n", 44 | " path = os.path.join(eval_folder, folder, score_name)\n", 45 | " select_scores = np.load(path, allow_pickle=True).item()\n", 46 | " all_select_scores[folder] = select_scores" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 35, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "example_key = list(sorted(all_select_scores.keys(), key=lambda x:float(x.split(\"-\")[-1][4:])))[0]" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 37, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "'checkpoint-375-scale1.0-mask0.0'" 67 | ] 68 | }, 69 | "execution_count": 37, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "example_key" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 39, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "re_select_scores = {}\n", 85 | "for uid in example_scores:\n", 86 | " if uid not in re_select_scores:\n", 87 | " re_select_scores[uid] = {}\n", 88 | "\n", 89 | " for iid in example_scores[uid]:\n", 90 | " if iid not in re_select_scores[uid]:\n", 91 | " re_select_scores[uid][iid] = {}\n", 92 | "\n", 93 | " for gid in example_scores[uid][iid]:\n", 94 | " if gid not in re_select_scores[uid][iid]:\n", 95 | " re_select_scores[uid][iid][gid] = {}\n", 96 | " re_select_scores[uid][iid][gid][\"folder\"] = []\n", 97 | " re_select_scores[uid][iid][gid][\"seq_id\"] = []\n", 98 | " re_select_scores[uid][iid][gid][\"img_id\"] = []\n", 99 | " re_select_scores[uid][iid][gid][\"score\"] = []\n", 100 | " \n", 101 | " for folder in all_select_scores:\n", 102 | " cur_scores = all_select_scores[folder][uid][iid][gid]\n", 103 | " all_seq_id = []\n", 104 | " all_img_id = []\n", 105 | " all_seq_score = []\n", 106 | " for seq_id 
in cur_scores:\n", 107 | " all_seq_id.append(seq_id)\n", 108 | " seq_scores = cur_scores[seq_id][\"hist_score\"]\n", 109 | " max_score = max(seq_scores)\n", 110 | " max_img_idx = seq_scores.index(max_score)\n", 111 | " all_img_id.append(max_img_idx)\n", 112 | " all_seq_score.append(max_score)\n", 113 | " \n", 114 | " max_score = max(all_seq_score)\n", 115 | " max_idx = all_seq_score.index(max_score)\n", 116 | " max_seq_id = all_seq_id[max_idx]\n", 117 | " max_img_id = all_img_id[max_idx]\n", 118 | "\n", 119 | " re_select_scores[uid][iid][gid][\"folder\"].append(folder)\n", 120 | " re_select_scores[uid][iid][gid][\"seq_id\"].append(max_seq_id)\n", 121 | " re_select_scores[uid][iid][gid][\"img_id\"].append(max_img_id)\n", 122 | " re_select_scores[uid][iid][gid][\"score\"].append(max_score)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 40, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "from copy import deepcopy" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 41, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "chosen_scores = deepcopy(re_select_scores)\n", 141 | "rejected_scores = deepcopy(re_select_scores)\n", 142 | "\n", 143 | "for uid in re_select_scores:\n", 144 | " for iid in re_select_scores[uid]:\n", 145 | " for gid in re_select_scores[uid][iid]:\n", 146 | " scores = re_select_scores[uid][iid][gid][\"score\"]\n", 147 | " best_score = max(scores)\n", 148 | " best_idx = scores.index(best_score)\n", 149 | " best_folder = re_select_scores[uid][iid][gid][\"folder\"][best_idx]\n", 150 | " best_seq_id = re_select_scores[uid][iid][gid][\"seq_id\"][best_idx]\n", 151 | " best_img_id = re_select_scores[uid][iid][gid][\"img_id\"][best_idx]\n", 152 | "\n", 153 | " chosen_scores[uid][iid][gid][\"folder\"] = best_folder\n", 154 | " chosen_scores[uid][iid][gid][\"seq_id\"] = best_seq_id\n", 155 | " chosen_scores[uid][iid][gid][\"img_id\"] = best_img_id\n", 156 | " chosen_scores[uid][iid][gid][\"score\"] = best_score\n", 157 | "\n", 158 | " worst_score = min(scores)\n", 159 | " worst_idx = scores.index(worst_score)\n", 160 | " rejected_scores[uid][iid][gid][\"folder\"] = re_select_scores[uid][iid][gid][\"folder\"][worst_idx]\n", 161 | " rejected_scores[uid][iid][gid][\"seq_id\"] = re_select_scores[uid][iid][gid][\"seq_id\"][worst_idx]\n", 162 | " rejected_scores[uid][iid][gid][\"img_id\"] = re_select_scores[uid][iid][gid][\"img_id\"][worst_idx]\n", 163 | " rejected_scores[uid][iid][gid][\"score\"] = worst_score" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 42, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "np.save(os.path.join(eval_folder, \"chosen_scores.npy\"), np.array(chosen_scores))\n", 173 | "np.save(os.path.join(eval_folder, \"rejected_scores.npy\"), np.array(rejected_scores))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 43, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "all_res = {}\n", 183 | "res_name = \"all_res_dm_scale7.0.npy\"\n", 184 | "for folder in res_folder:\n", 185 | " path = os.path.join(eval_folder, folder, res_name)\n", 186 | " res = np.load(path, allow_pickle=True).item()\n", 187 | " all_res[folder] = res" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "import torch\n", 197 | "\n", 198 | "data_seq = {}\n", 199 | "preference_data = {}\n", 200 | "seq_num = 0\n", 201 | "for 
uid in chosen_scores:\n", 202 | " if uid not in data_seq:\n", 203 | " data_seq[uid] = {}\n", 204 | " preference_data[uid] = {}\n", 205 | "\n", 206 | " for iid in chosen_scores[uid]:\n", 207 | " if iid not in preference_data[uid]:\n", 208 | " preference_data[uid][iid] = {}\n", 209 | "\n", 210 | " for gid in chosen_scores[uid][iid]:\n", 211 | " if gid not in preference_data[uid][iid]:\n", 212 | " preference_data[uid][iid][gid] = {}\n", 213 | " \n", 214 | " if gid not in data_seq[uid]:\n", 215 | " data_seq[uid][gid] = []\n", 216 | "\n", 217 | " chosen_folder = chosen_scores[uid][iid][gid][\"folder\"]\n", 218 | " chosen_seq_id = chosen_scores[uid][iid][gid][\"seq_id\"]\n", 219 | "\n", 220 | " history = all_res[chosen_folder][uid][iid][gid][\"history\"].tolist()\n", 221 | " chosen = all_res[chosen_folder][uid][iid][gid][\"seqs\"][chosen_seq_id][\"tokens\"]\n", 222 | " non_pad_idx = torch.nonzero(chosen != 2, as_tuple=False)[-1].item()\n", 223 | " chosen = chosen[:non_pad_idx + 1]\n", 224 | "\n", 225 | " rejected_folder = rejected_scores[uid][iid][gid][\"folder\"]\n", 226 | " rejected_seq_id = rejected_scores[uid][iid][gid][\"seq_id\"]\n", 227 | " rejected = all_res[rejected_folder][uid][iid][gid][\"seqs\"][rejected_seq_id][\"tokens\"]\n", 228 | " non_pad_idx = torch.nonzero(rejected != 2, as_tuple=False)[-1].item()\n", 229 | " rejected = rejected[:non_pad_idx + 1]\n", 230 | " \n", 231 | " seq = history + [iid]\n", 232 | " data_seq[uid][gid].append(seq)\n", 233 | " preference_data[uid][iid][gid][\"chosen\"] = chosen\n", 234 | " preference_data[uid][iid][gid][\"rejected\"] = rejected\n", 235 | "\n", 236 | " seq_num += 1" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 45, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "np.save(f\"{data_path}/dpo/{mode}.npy\", np.array(data_seq))\n", 246 | "np.save(f\"{data_path}/dpo/{mode}_preference_data.npy\", np.array(preference_data))" 247 | ] 248 | } 249 | ], 250 | "metadata": { 251 | "kernelspec": { 252 | "display_name": "pigeon", 253 | "language": "python", 254 | "name": "python3" 255 | }, 256 | "language_info": { 257 | "codemirror_mode": { 258 | "name": "ipython", 259 | "version": 3 260 | }, 261 | "file_extension": ".py", 262 | "mimetype": "text/x-python", 263 | "name": "python", 264 | "nbconvert_exporter": "python", 265 | "pygments_lexer": "ipython3", 266 | "version": "3.8.13" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 2 271 | } 272 | -------------------------------------------------------------------------------- /Evaluation/cal_scores.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | import pathlib 4 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 5 | import logging 6 | import fire 7 | import inspect 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader 12 | from torchvision import transforms 13 | from PIL import Image 14 | import random 15 | import open_clip 16 | 17 | try: 18 | from tqdm import tqdm 19 | except ImportError: 20 | # If tqdm is not available, provide a mock version of it 21 | def tqdm(x): 22 | return x 23 | 24 | import eval_utils 25 | 26 | def setup_logger(): 27 | logger = logging.getLogger() 28 | logger.setLevel(logging.INFO) 29 | 30 | console_handler = logging.StreamHandler() 31 | console_handler.setLevel(logging.INFO) 32 | 33 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 34 | 
console_handler.setFormatter(formatter) 35 | 36 | if not logger.handlers: 37 | logger.addHandler(console_handler) 38 | 39 | setup_logger() 40 | logger = logging.getLogger(__name__) 41 | 42 | def main( 43 | output_dir: str = "", 44 | scenario: str = "sticker", 45 | data_path: str = "", 46 | img_folder_path: str = "", 47 | batch_size: int = 64, 48 | mode: str = "test", 49 | ckpt: int = None, 50 | height: int = 512, 51 | width: int = 512, 52 | scale_for_llm: float = 1.0, 53 | scale_for_dm: float = 7.0, 54 | seed: int = 123, 55 | ): 56 | frame = inspect.currentframe() 57 | args, _, _, values = inspect.getargvalues(frame) 58 | params = {arg:values[arg] for arg in args} 59 | 60 | set_random_seed(seed=seed) 61 | 62 | logger.info(f"##### Calculate preference score and semantic alignment score #####") 63 | for k,v in params.items(): 64 | logger.info(f"{k}: {v}") 65 | 66 | if torch.cuda.is_available(): 67 | device = torch.device('cuda') 68 | logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}") 69 | else: 70 | device = torch.device('cpu') 71 | logger.info("Using CPU") 72 | 73 | if mode == "valid": 74 | hist_path = os.path.join(data_path, "data_ready", "valid_hist_embeds.npy") 75 | else: 76 | hist_path = os.path.join(data_path, "data_ready", "test_hist_embeds.npy") 77 | if os.path.exists(hist_path): 78 | logger.info(f"Loading history embeds for {mode} from {hist_path}...") 79 | history = np.load(hist_path, allow_pickle=True).item() 80 | else: 81 | all_clip_feats_path = os.path.join(data_path, "all_clip_feats.npy") 82 | if os.path.exists(all_clip_feats_path): 83 | logger.info(f"Loading clip feats for all images from {all_clip_feats_path}...") 84 | all_clip_feats = np.load(all_clip_feats_path, allow_pickle=True) 85 | else: 86 | logger.info("Extracting clip feats for all images...") 87 | all_clip_feats = eval_utils.extract_clip_feats( 88 | args.img_folder_path, all_image_paths, args.clip_batch_size, 89 | num_workers, device 90 | ) 91 | np.save(all_clip_feats_path, np.array(all_clip_feats)) 92 | logger.info(f"Successfully saved clip feats for all images into {all_clip_feats_path}.") 93 | 94 | all_clip_feats = torch.tensor(all_clip_feats) 95 | logger.info(f"clip feats shape: {all_clip_feats.shape}") 96 | logger.info("Processing history embeds for evaluation...") 97 | history = eval_utils.process_hist_embs(all_res, all_clip_feats) 98 | np.save(hist_path, np.array(history)) 99 | logger.info(f"Successfully saved ori history embeds for {mode} set into {hist_path}.") 100 | 101 | all_image_paths = np.load(os.path.join(data_path, "all_image_paths.npy"), allow_pickle=True) 102 | 103 | if scenario == "sticker": 104 | semantics = np.load(os.path.join(data_path, "sticker_semantics.npy"), allow_pickle=True).tolist() 105 | elif scenario == "movie": 106 | semantics = np.load(os.path.join(data_path, "movie_semantics.npy"), allow_pickle=True).tolist() 107 | 108 | default_shape = (height, width) 109 | resize = transforms.Resize(default_shape, interpolation=transforms.InterpolationMode.BILINEAR) 110 | _, _, clip_trans = open_clip.create_model_and_transforms('ViT-H-14', pretrained="laion2b-s32b-b79K") 111 | 112 | mode_name = "eval" if mode == "valid" else "eval-test" 113 | eval_path = os.path.join(output_dir, mode_name) 114 | 115 | mask_ratios = ["0.0", "0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9", "1.0"] 116 | 117 | for i,mask_ratio in enumerate(mask_ratios): 118 | res_folder = os.path.join(eval_path, f"checkpoint-{ckpt}-scale{scale_for_llm}-mask{mask_ratio}") 119 | res_path = os.path.join(res_folder, 
f"all_res_dm_scale{scale_for_dm}.npy") 120 | logger.info(f"##### CURRENT RESULTS: {res_path} #####") 121 | all_res = np.load(res_path, allow_pickle=True).item() 122 | 123 | if i == 0: 124 | text4eval = [] 125 | hist4clip = [] 126 | 127 | seq_num, img_num = None, None 128 | for uid in all_res: 129 | for tgt_iid in all_res[uid]: 130 | for genre_id in all_res[uid][tgt_iid]: 131 | text4eval.append(semantics[tgt_iid]) 132 | for i in range(5): # hist_num=5 133 | hist4clip.append(history[uid][tgt_iid][genre_id][i]) 134 | 135 | if seq_num is None: 136 | seq_num = len(all_res[uid][tgt_iid][genre_id]["seqs"]) 137 | img_num = len(all_res[uid][tgt_iid][genre_id]["seqs"][0]["images"]) 138 | 139 | cur_score_save_path = os.path.join(res_folder, f"select_scores_dm_scale{scale_for_dm}.npy") 140 | if os.path.exists(cur_score_save_path): 141 | logger.info(f"The select scores for {res_folder} has already been calculated. Skip.") 142 | continue 143 | 144 | select_scores = {} 145 | for i in range(seq_num): 146 | for j in range(img_num): 147 | gen4clip = [] 148 | for uid in all_res: 149 | for tgt_iid in all_res[uid]: 150 | for genre_id in all_res[uid][tgt_iid]: 151 | img_path = all_res[uid][tgt_iid][genre_id]["seqs"][i]["images"][j] 152 | img_path = os.path.join(eval_path, img_path) 153 | im = Image.open(img_path).convert('RGB') 154 | im = resize(im) 155 | gen4clip.append(clip_trans(im)) 156 | 157 | gen_dataset = EvalDataset(gen4clip) 158 | txt_dataset = EvalDataset(text4eval) 159 | 160 | _, clip_grd_scores = eval_utils.calculate_clip_score_given_data( 161 | gen_dataset, 162 | txt_dataset, 163 | batch_size=batch_size, 164 | device=device, 165 | ) 166 | 167 | gen4clip_hist = [ele for ele in gen4clip for _ in range(5)] 168 | gen4personal_sim = {} 169 | gen4personal_sim["gen"] = gen4clip_hist 170 | gen4personal_sim["hist"] = hist4clip 171 | 172 | gen_dataset_personal_sim = PersonalSimDataset(gen4personal_sim) 173 | _, clip_hist_scores = eval_utils.evaluate_personalization_given_data_sim( 174 | gen4eval=gen_dataset_personal_sim, 175 | batch_size=batch_size, 176 | device=device, 177 | ) 178 | clip_hist_scores = clip_hist_scores.view(-1, 5) 179 | clip_hist_scores = torch.mean(clip_hist_scores, dim=1) 180 | 181 | del gen4clip 182 | del gen4clip_hist 183 | 184 | cur_idx = 0 185 | for uid in all_res: 186 | if uid not in select_scores: 187 | select_scores[uid] = {} 188 | 189 | for tgt_iid in all_res[uid]: 190 | if tgt_iid not in select_scores[uid]: 191 | select_scores[uid][tgt_iid] = {} 192 | 193 | for genre_id in all_res[uid][tgt_iid]: 194 | if genre_id not in select_scores[uid][tgt_iid]: 195 | select_scores[uid][tgt_iid][genre_id] = {} 196 | 197 | if i not in select_scores[uid][tgt_iid][genre_id]: 198 | select_scores[uid][tgt_iid][genre_id][i] = {} 199 | select_scores[uid][tgt_iid][genre_id][i]["tgt_score"] = [] 200 | select_scores[uid][tgt_iid][genre_id][i]["hist_score"] = [] 201 | 202 | select_scores[uid][tgt_iid][genre_id][i]["tgt_score"].append(clip_grd_scores[cur_idx].item()) 203 | select_scores[uid][tgt_iid][genre_id][i]["hist_score"].append(clip_hist_scores[cur_idx].item()) 204 | cur_idx += 1 205 | 206 | np.save(cur_score_save_path, np.array(select_scores)) 207 | logger.info(f"Select score for {res_folder} has been calculated successfully and saved in {cur_score_save_path}.") 208 | 209 | logger.info(f"Select scores for all the mask_ratio {mask_ratio} has been calculated successfully.") 210 | 211 | class EvalDataset(Dataset): 212 | def __init__(self, data): 213 | self.data = data 214 | 215 | def __len__(self): 216 
| return len(self.data) 217 | 218 | def __getitem__(self, idx): 219 | return self.data[idx] 220 | 221 | class PersonalSimDataset(Dataset): 222 | def __init__(self, data): 223 | self.data = data 224 | 225 | def __len__(self): 226 | return len(self.data["gen"]) 227 | 228 | def __getitem__(self, index): 229 | gen = self.data["gen"][index] 230 | hist = self.data["hist"][index] 231 | 232 | return gen, hist 233 | 234 | def save_params_to_txt(file_path, **kwargs): 235 | with open(file_path, 'w') as file: 236 | for key, value in kwargs.items(): 237 | file.write(f'{key}: {value}\n') 238 | 239 | def set_random_seed(seed=0): 240 | random.seed(seed) 241 | np.random.seed(seed) 242 | torch.manual_seed(seed) 243 | torch.cuda.manual_seed(seed) 244 | torch.cuda.manual_seed_all(seed) 245 | torch.backends.cudnn.deterministic = True 246 | torch.backends.cudnn.benchmark = False 247 | 248 | if __name__ == "__main__": 249 | fire.Fire(main) 250 | -------------------------------------------------------------------------------- /Pigeon/finetune_dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | import logging 5 | import random 6 | 7 | import inspect 8 | import numpy as np 9 | import fire 10 | import torch 11 | import safetensors 12 | import transformers 13 | from datasets import load_dataset, concatenate_datasets 14 | from transformers import EarlyStoppingCallback 15 | from typing import Optional, Union 16 | """ 17 | Unused imports:` 18 | import torch.nn as nn 19 | import bitsandbytes as bnb 20 | """ 21 | 22 | from peft import ( 23 | LoraConfig, 24 | get_peft_model, 25 | get_peft_model_state_dict, 26 | prepare_model_for_int8_training, 27 | set_peft_model_state_dict, 28 | ) 29 | from transformers import LlamaForCausalLM, LlamaTokenizer 30 | from transformers.data.data_collator import default_data_collator 31 | from models.lavit_utils import get_rank 32 | from models.lavit_for_pigeon_dpo import LaVITforPigeonDPO 33 | from models.trainer.dpo_trainer import DPOTrainer, PigeonDPOConfig 34 | from models.lavit_utils import convert_weights_to_bf16, convert_weights_to_fp16 35 | import data_utils_dpo 36 | 37 | class InfoFilter(logging.Filter): 38 | def filter(self, record): 39 | return record.levelno <= logging.INFO 40 | 41 | def setup_logger(): 42 | logger = logging.getLogger() 43 | logger.setLevel(logging.INFO) 44 | 45 | console_handler = logging.StreamHandler() 46 | console_handler.setLevel(logging.INFO) 47 | 48 | info_filter = InfoFilter() 49 | console_handler.addFilter(info_filter) 50 | 51 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 52 | console_handler.setFormatter(formatter) 53 | 54 | if not logger.handlers: 55 | logger.addHandler(console_handler) 56 | 57 | setup_logger() 58 | logger = logging.getLogger(__name__) 59 | 60 | def train( 61 | model_path: str = "/path/to/pre-trained/LaVIT/", 62 | model_dtype: str = "bf16", 63 | output_dir: str = "/path/to/output/", 64 | pre_ckpt: str = None, 65 | # model-specific hyperparams 66 | use_xformers: bool = True, 67 | load_in_8bit: bool = True, 68 | check_safety: bool = False, 69 | pixel_decoding: str = "highres", 70 | # mask generator hyperparams 71 | mask_type: str = "random", 72 | num_heads: int = 4, 73 | num_layers: int = 1, 74 | drop_prob: int = 0.2, 75 | hist_mask_ratio: float = 0.2, 76 | add_special: bool = False, 77 | # dataset info 78 | scenario: str = "sticker", 79 | img_folder_path: str = "/path/to/data/", 80 | data_path: str = 
"/path/to/data/", 81 | # training hyperparams 82 | mode: str = "dpo", 83 | batch_size: int = 128, 84 | micro_batch_size: int = 4, 85 | group_by_length: bool = False, 86 | num_epochs: int = 3, 87 | learning_rate: float = 3e-4, 88 | lr_schedule_type: str = "cosine", 89 | min_learning_rate: float = 1e-6, 90 | seed: int = 123, 91 | logging_steps: int = 25, 92 | save_steps: int = 25, 93 | eval_steps: int = 25, 94 | eval_num: int = 1000, 95 | # lora hyperparams 96 | lora_r: int = 8, 97 | lora_alpha: int = 16, 98 | lora_dropout: float = 0.05, 99 | lora_target_modules: List[str] = [ 100 | "q_proj", 101 | "v_proj", 102 | ], 103 | resume_from_checkpoint: Optional[Union[str, bool]] = None, # either training checkpoint or final adapter 104 | ): 105 | frame = inspect.currentframe() 106 | args, _, _, values = inspect.getargvalues(frame) 107 | params = {arg:values[arg] for arg in args} 108 | 109 | os.makedirs(output_dir, exist_ok=True) 110 | if resume_from_checkpoint: 111 | ckpts = os.listdir(output_dir) 112 | ckpts = [ckpt for ckpt in ckpts if ckpt.startswith("checkpoint") and os.path.isdir(os.path.join(output_dir, ckpt))] 113 | if len(ckpts) == 0: 114 | resume_from_checkpoint = False 115 | params["resume_from_checkpoint"] = False 116 | 117 | logger.info(f"***** Finetuning Llama model with params:") 118 | for k,v in params.items(): 119 | logger.info(f"{k}: {v}") 120 | 121 | save_params_to_txt(os.path.join(output_dir, "args.txt"), **params) 122 | set_random_seed(seed=seed) 123 | 124 | gradient_accumulation_steps = batch_size // micro_batch_size 125 | device_map = "auto" 126 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 127 | ddp = world_size != 1 128 | if ddp: 129 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 130 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 131 | 132 | all_image_paths = np.load(os.path.join(data_path, "all_image_paths.npy"), allow_pickle=True) 133 | 134 | lavit = LaVITforPigeonDPO( 135 | model_path=model_path, 136 | model_dtype=model_dtype, 137 | lora_weights=os.path.join(pre_ckpt, "llama"), 138 | use_xformers=use_xformers, 139 | check_safety=check_safety, 140 | pixel_decoding=pixel_decoding, 141 | mask_type=mask_type, 142 | num_heads=num_heads, 143 | num_layers=num_layers, 144 | drop_prob=drop_prob, 145 | img_folder_path=img_folder_path, 146 | all_image_paths=all_image_paths, 147 | data_path=data_path, 148 | load_in_8bit=load_in_8bit, 149 | device_map=device_map, 150 | mode=mode, 151 | ) 152 | 153 | if lavit.mask_generator is not None: 154 | # Load mask generator for model 155 | mask_generater_path = os.path.join(pre_ckpt, "mask_generator") 156 | mask_generator_ckpt = os.listdir(mask_generater_path)[0] 157 | if mask_generator_ckpt.endswith(".bin"): 158 | state_dict = torch.load(os.path.join(mask_generater_path, mask_generator_ckpt)) 159 | elif mask_generator_ckpt.endswith(".safetensors"): 160 | state_dict = safetensors.torch.load_file(os.path.join(mask_generater_path, mask_generator_ckpt)) 161 | else: 162 | raise ValueError(f"Unexpected mask generator checkpoint {mask_generator_ckpt} in PATH:{mask_generater_path}.") 163 | 164 | lavit.mask_generator.load_state_dict(state_dict, False) 165 | del state_dict 166 | 167 | adapter_path = os.path.join(pre_ckpt, "adapter") 168 | adapter_ckpt = os.listdir(adapter_path)[0] 169 | if adapter_ckpt.endswith(".bin"): 170 | state_dict = torch.load(os.path.join(adapter_path, adapter_ckpt)) 171 | elif adapter_ckpt.endswith(".safetensors"): 172 | state_dict = 
safetensors.torch.load_file(os.path.join(adapter_path, adapter_ckpt)) 173 | else: 174 | raise ValueError(f"Unexpected mask generator checkpoint {adapter_ckpt} in PATH:{adapter_path}.") 175 | 176 | lavit.adapter.load_state_dict(state_dict, False) 177 | del state_dict 178 | 179 | if not ddp and torch.cuda.device_count() > 1: 180 | lavit.llama_model.is_parallelizable = True 181 | lavit.llama_model.model_parallel = True 182 | 183 | # Load Dataset 184 | train_data_path = os.path.join(data_path, "data_ready_dpo", "train_ready.npy") 185 | valid_data_path = os.path.join(data_path, "data_ready_dpo", "valid_ready.npy") 186 | if os.path.exists(train_data_path) and os.path.exists(valid_data_path): 187 | logger.info(f"Loading processed train dataset and valid dataset from {os.path.join(data_path, 'data_ready_dpo')}") 188 | train_data = np.load(train_data_path, allow_pickle=True).item() 189 | valid_data = np.load(valid_data_path, allow_pickle=True).item() 190 | else: 191 | logger.info(f"Processed datasets don't exist. Start data processing...") 192 | if scenario == "sticker": 193 | item_info = np.load(os.path.join(data_path, "mapped_anno_dict.npy"), allow_pickle=True).item() 194 | semantics = np.load(os.path.join(data_path, "sticker_semantics.npy"), allow_pickle=True).tolist() 195 | elif scenario == "movie": 196 | item_info = np.load(os.path.join(data_path, "mapped_movies.npy"), allow_pickle=True).item() 197 | semantics = np.load(os.path.join(data_path, "movie_semantics.npy"), allow_pickle=True).tolist() 198 | 199 | train_seq = np.load(os.path.join(data_path, "dpo", "train.npy"), allow_pickle=True).item() 200 | valid_seq = np.load(os.path.join(data_path, "dpo", "valid.npy"), allow_pickle=True).item() 201 | train_preference_data = np.load(os.path.join(data_path, "dpo", "train_preference_data.npy"), allow_pickle=True).item() 202 | valid_preference_data = np.load(os.path.join(data_path, "dpo", "valid_preference_data.npy"), allow_pickle=True).item() 203 | 204 | train_data = data_utils_dpo.process_dpo_data_for(scenario, train_seq, train_preference_data, lavit.llama_tokenizer, item_info, semantics) 205 | valid_data = data_utils_dpo.process_dpo_data_for(scenario, valid_seq, valid_preference_data, lavit.llama_tokenizer, item_info, semantics) 206 | 207 | if not os.path.exists(os.path.join(data_path, "data_ready_dpo")): 208 | os.makedirs(os.path.join(data_path, "data_ready_dpo")) 209 | 210 | np.save(train_data_path, np.array(train_data)) 211 | np.save(valid_data_path, np.array(valid_data)) 212 | logger.info(f"Saving processed train dataset and valid dataset into {os.path.join(data_path, 'processed_seq', 'data_ready_dpo')}") 213 | 214 | train_dataset = data_utils_dpo.DPODataset(train_data) 215 | valid_dataset = data_utils_dpo.DPODataset(valid_data) 216 | 217 | trainer = DPOTrainer( 218 | model=lavit, 219 | train_dataset=train_dataset, 220 | eval_dataset=valid_dataset, 221 | args=PigeonDPOConfig( 222 | # model-specific arguments 223 | hist_mask_ratio=hist_mask_ratio, 224 | add_special=add_special, 225 | # training arguments 226 | mode=mode, 227 | model_adapter_name="policy", 228 | ref_adapter_name="reference", 229 | per_device_train_batch_size=micro_batch_size, 230 | per_device_eval_batch_size=micro_batch_size, 231 | group_by_length=group_by_length, 232 | gradient_accumulation_steps=gradient_accumulation_steps, 233 | warmup_steps=20, 234 | num_train_epochs=num_epochs, 235 | learning_rate=learning_rate, 236 | lr_scheduler_type=lr_schedule_type, 237 | fp16=False, 238 | logging_steps=logging_steps, 239 | 
optim="adamw_bnb_8bit", 240 | eval_num=eval_num, 241 | evaluation_strategy="steps", 242 | eval_steps=eval_steps, 243 | save_strategy="steps", 244 | save_steps=save_steps, 245 | output_dir=output_dir, 246 | save_total_limit=5, 247 | load_best_model_at_end=True, 248 | ddp_find_unused_parameters=False if ddp else None, 249 | report_to=None, 250 | seed=seed, 251 | log_level="info", 252 | ), 253 | data_collator=data_utils_dpo.DPOCollator(pad_token_id=lavit.llama_tokenizer.pad_token_id, padding_side=lavit.llama_tokenizer.padding_side, return_tensors="pt"), 254 | callbacks = [EarlyStoppingCallback(early_stopping_patience=10)] 255 | ) 256 | 257 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 258 | logger.info(f"BEST MODEL PATH: {trainer.state.best_model_checkpoint}") 259 | 260 | def save_params_to_txt(file_path, **kwargs): 261 | with open(file_path, 'w') as file: 262 | for key, value in kwargs.items(): 263 | file.write(f'{key}: {value}\n') 264 | 265 | def set_random_seed(seed=0): 266 | random.seed(seed) 267 | np.random.seed(seed) 268 | torch.manual_seed(seed) 269 | torch.cuda.manual_seed(seed) 270 | torch.cuda.manual_seed_all(seed) 271 | torch.backends.cudnn.deterministic = True 272 | torch.backends.cudnn.benchmark = False 273 | 274 | if __name__ == "__main__": 275 | fire.Fire(train) 276 | 277 | -------------------------------------------------------------------------------- /Pigeon/models/modeling_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 8 | 9 | # The vqkd decoder to reconstruct the image semantics 10 | 11 | class LayerNorm(nn.LayerNorm): 12 | """Subclass torch's LayerNorm to handle fp16.""" 13 | 14 | def forward(self, x: torch.Tensor): 15 | orig_type = x.dtype 16 | ret = super().forward(x.type(torch.float32)) 17 | return ret.type(orig_type) 18 | 19 | try: 20 | from apex.normalization import FusedLayerNorm 21 | except: 22 | FusedLayerNorm = LayerNorm 23 | 24 | 25 | class Mlp(nn.Module): 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class SelfAttention(nn.Module): 45 | 46 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 47 | super().__init__() 48 | self.num_heads = num_heads 49 | head_dim = dim // num_heads 50 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 51 | self.scale = qk_scale or head_dim ** -0.5 52 | 53 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 54 | self.attn_drop = nn.Dropout(attn_drop) 55 | self.proj = nn.Linear(dim, dim, bias=qkv_bias) 56 | self.proj_drop = nn.Dropout(proj_drop) 57 | 58 | def forward(self, x): 59 | B, N, C = x.shape 60 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 61 | q, k, v = qkv[0], qkv[1], qkv[2] # make 
torchscript happy (cannot use tensor as tuple) 62 | 63 | attn = (q @ k.transpose(-2, -1)) * self.scale 64 | 65 | attn = attn.softmax(dim=-1) 66 | attn = self.attn_drop(attn) 67 | 68 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 69 | x = self.proj(x) 70 | x = self.proj_drop(x) 71 | return x 72 | 73 | 74 | class CrossAttention(nn.Module): 75 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 76 | super().__init__() 77 | self.num_heads = num_heads 78 | head_dim = dim // num_heads 79 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 80 | self.scale = qk_scale or head_dim ** -0.5 81 | 82 | self.query = nn.Linear(dim, dim, bias=qkv_bias) 83 | self.key = nn.Linear(dim, dim, bias=qkv_bias) 84 | self.value = nn.Linear(dim, dim, bias=qkv_bias) 85 | 86 | self.attn_drop = nn.Dropout(attn_drop) 87 | self.proj = nn.Linear(dim, dim, bias=qkv_bias) 88 | self.proj_drop = nn.Dropout(proj_drop) 89 | 90 | def forward(self, x, codebook_embeds, codebook_mask): 91 | B, N, C = codebook_embeds.shape 92 | _, N_x, _ = x.shape 93 | 94 | q = self.query(x).reshape(B, N_x, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 95 | k = self.key(codebook_embeds).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 96 | v = self.value(codebook_embeds).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 97 | 98 | attn = (q @ k.transpose(-2, -1)) * self.scale 99 | 100 | extended_mask = codebook_mask[:, None, None, :] 101 | extended_mask = (1.0 - extended_mask) * -10000.0 102 | attn = attn + extended_mask 103 | 104 | attn = attn.softmax(dim=-1) 105 | attn = self.attn_drop(attn) 106 | x = (attn @ v).transpose(1, 2).reshape(B, N_x, C) 107 | x = self.proj(x) 108 | x = self.proj_drop(x) 109 | return x 110 | 111 | 112 | class Block(nn.Module): 113 | 114 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 115 | attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 116 | super().__init__() 117 | self.norm0 = norm_layer(dim) 118 | self.self_attn = SelfAttention( 119 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 120 | attn_drop=attn_drop, proj_drop=drop 121 | ) 122 | 123 | self.norm1 = norm_layer(dim) 124 | self.cross_attn = CrossAttention( 125 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 126 | attn_drop=attn_drop, proj_drop=drop 127 | ) 128 | 129 | self.norm2 = norm_layer(dim) 130 | mlp_hidden_dim = int(dim * mlp_ratio) 131 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 132 | 133 | def forward(self, x, codebook_embeds, codebook_mask): 134 | x = x + self.self_attn(self.norm0(x)) 135 | x = x + self.cross_attn(self.norm1(x), codebook_embeds, codebook_mask) 136 | x = x + self.mlp(self.norm2(x)) 137 | return x 138 | 139 | 140 | class AttentionPool2d(nn.Module): 141 | def __init__(self, seq_len: int, embed_dim: int, num_heads: int, output_dim: int = None): 142 | super().__init__() 143 | self.positional_embedding = nn.Parameter(torch.randn(seq_len + 1, embed_dim) / embed_dim ** 0.5) 144 | self.k_proj = nn.Linear(embed_dim, embed_dim) 145 | self.q_proj = nn.Linear(embed_dim, embed_dim) 146 | self.v_proj = nn.Linear(embed_dim, embed_dim) 147 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 148 | self.num_heads = num_heads 149 | 150 | def forward(self, x, return_all_tokens=False): 151 | # x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * 
x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 152 | x = x.permute(1, 0, 2) # (N(HW)C) => (HW)NC 153 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 154 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 155 | x, _ = F.multi_head_attention_forward( 156 | query=x, key=x, value=x, 157 | embed_dim_to_check=x.shape[-1], 158 | num_heads=self.num_heads, 159 | q_proj_weight=self.q_proj.weight, 160 | k_proj_weight=self.k_proj.weight, 161 | v_proj_weight=self.v_proj.weight, 162 | in_proj_weight=None, 163 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 164 | bias_k=None, 165 | bias_v=None, 166 | add_zero_attn=False, 167 | dropout_p=0, 168 | out_proj_weight=self.c_proj.weight, 169 | out_proj_bias=self.c_proj.bias, 170 | use_separate_proj_weight=True, 171 | training=self.training, 172 | need_weights=False 173 | ) 174 | if return_all_tokens: 175 | return x 176 | else: 177 | return x[0] 178 | 179 | 180 | class VQDecoder(nn.Module): 181 | def __init__(self, img_size=224, patch_size=14, in_chans=32, embed_dim=1408, 182 | depth=12, num_heads=16, mlp_ratio=4.3637, qkv_bias=True, qk_scale=None, drop_rate=0., 183 | attn_drop_rate=0., norm_layer=partial(FusedLayerNorm, eps=1e-5), **kwargs): 184 | super().__init__() 185 | 186 | self.in_proj = nn.Linear(in_chans, embed_dim) 187 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 188 | num_patches = (img_size // patch_size) * (img_size // patch_size) 189 | self.num_patches = num_patches 190 | 191 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # The postion embedding for the latent code 192 | 193 | self.query_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # The query embedding for reconstruction 194 | 195 | self.pos_drop = nn.Dropout(p=drop_rate) 196 | 197 | self.blocks = nn.ModuleList([ 198 | Block( 199 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 200 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 201 | for i in range(depth)]) 202 | 203 | self.norm = norm_layer(embed_dim) 204 | 205 | # The decoder task layer 206 | self.decoder_out_dim = 1408 207 | self.decode_task_layer = nn.Sequential( 208 | nn.Linear(embed_dim, embed_dim), 209 | nn.Tanh(), 210 | nn.Linear(embed_dim, self.decoder_out_dim), 211 | ) 212 | 213 | self.unet_proj = nn.Linear(self.decoder_out_dim, 768) 214 | 215 | def get_num_layers(self): 216 | return len(self.blocks) 217 | 218 | @torch.jit.ignore 219 | def no_weight_decay(self): 220 | return {'pos_embed', 'query_embed'} 221 | 222 | def forward(self, x, token_num): 223 | # codebook_fea 224 | # B, nc, w, h = codebook_fea.shape 225 | x = self.in_proj(x) 226 | B = len(token_num) 227 | num_tokens, C = x.shape 228 | device = x.device 229 | 230 | x_list = torch.split(x, token_num.tolist(), dim=0) 231 | max_token_num = token_num.max().item() 232 | x_pad = torch.zeros(B, max_token_num, C, dtype=x.dtype).to(device) 233 | mask = torch.zeros(B, max_token_num, dtype=x.dtype).to(device) 234 | 235 | for i, x_tensor in enumerate(x_list): 236 | x_pad[i][:len(x_tensor)] = x_tensor 237 | mask[i][:len(x_tensor)] = 1 238 | 239 | x_pad = x_pad + self.pos_embed[:,:max_token_num] 240 | x_pad = self.pos_drop(x_pad) 241 | 242 | query_embeds = self.query_embed.expand(B, -1, -1) 243 | 244 | for blk in self.blocks: 245 | query_embeds = blk(query_embeds, codebook_embeds=x_pad, 246 | codebook_mask=mask) 247 | 248 | query_embeds = self.norm(query_embeds) # 
To align with the raw vit features 249 | 250 | visual_rec = self.decode_task_layer(query_embeds) 251 | 252 | visual_rec = self.unet_proj(visual_rec) 253 | 254 | return visual_rec 255 | 256 | 257 | class HighresVQDecoder(nn.Module): 258 | def __init__(self, img_size=224, patch_size=14, in_chans=32, embed_dim=1408, 259 | depth=12, num_heads=16, mlp_ratio=4.3637, qkv_bias=True, qk_scale=None, drop_rate=0., 260 | attn_drop_rate=0., norm_layer=partial(FusedLayerNorm, eps=1e-5), **kwargs): 261 | super().__init__() 262 | 263 | self.in_proj = nn.Linear(in_chans, embed_dim) 264 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 265 | num_patches = (img_size // patch_size) * (img_size // patch_size) 266 | self.num_patches = num_patches 267 | 268 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # The postion embedding for the latent code 269 | 270 | self.query_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # The query embedding for reconstruction 271 | 272 | self.pos_drop = nn.Dropout(p=drop_rate) 273 | 274 | self.blocks = nn.ModuleList([ 275 | Block( 276 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 277 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 278 | for i in range(depth)]) 279 | 280 | self.norm = norm_layer(embed_dim) 281 | 282 | # The decoder task layer 283 | self.decoder_out_dim = 1408 284 | self.decode_task_layer = nn.Sequential( 285 | nn.Linear(embed_dim, embed_dim), 286 | nn.Tanh(), 287 | nn.Linear(embed_dim, self.decoder_out_dim), 288 | ) 289 | 290 | # Convert the decoded features to Unet Condition 291 | self.unet_proj_1 = nn.Linear(self.decoder_out_dim, 768) 292 | self.unet_proj_2 = nn.Linear(self.decoder_out_dim, 1280) 293 | self.unet_attnpool = AttentionPool2d(num_patches, self.decoder_out_dim, num_heads, 1280) 294 | 295 | def get_num_layers(self): 296 | return len(self.blocks) 297 | 298 | @torch.jit.ignore 299 | def no_weight_decay(self): 300 | return {'pos_embed', 'query_embed'} 301 | 302 | def forward(self, x, token_num): 303 | # codebook_fea 304 | # B, nc, w, h = codebook_fea.shape 305 | x = self.in_proj(x) # codebook_emb_dim=32, --> 1408. [total_token_num, 1408] 306 | B = len(token_num) 307 | num_tokens, C = x.shape 308 | device = x.device 309 | 310 | x_list = torch.split(x, token_num.tolist(), dim=0) 311 | max_token_num = token_num.max().item() 312 | x_pad = torch.zeros(B, max_token_num, C, dtype=x.dtype).to(device) 313 | mask = torch.zeros(B, max_token_num, dtype=x.dtype).to(device) 314 | 315 | for i, x_tensor in enumerate(x_list): 316 | x_pad[i][:len(x_tensor)] = x_tensor # padding_side: right. 
[bsz, max_token_num, 1408] 317 | mask[i][:len(x_tensor)] = 1 318 | 319 | x_pad = x_pad + self.pos_embed[:,:max_token_num] 320 | x_pad = self.pos_drop(x_pad) 321 | 322 | query_embeds = self.query_embed.expand(B, -1, -1) 323 | 324 | for blk in self.blocks: 325 | query_embeds = blk(query_embeds, codebook_embeds=x_pad, 326 | codebook_mask=mask) 327 | 328 | query_embeds = self.norm(query_embeds) # To align with the raw vit features 329 | 330 | visual_rec = self.decode_task_layer(query_embeds) 331 | 332 | encoder_hidden_1 = self.unet_proj_1(visual_rec) # [bs, 256, 768] 333 | encoder_hidden_2 = self.unet_proj_2(visual_rec) # [bs, 256, 1280] 334 | prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048] 335 | pooled_prompt_embeds = self.unet_attnpool(visual_rec) # [bs, 1280] 336 | 337 | return prompt_embeds, pooled_prompt_embeds 338 | 339 | 340 | def build_tokenizer_decoder(model_path='', pixel_decoding='highres'): 341 | if pixel_decoding == 'lowres': 342 | model = VQDecoder(depth=12) 343 | weight_path = os.path.join(model_path, 'visual_tokenizer', 'tokenizer_decoder.bin') 344 | else: 345 | model = HighresVQDecoder(depth=12) 346 | weight_path = os.path.join(model_path, 'visual_tokenizer', 'highres_tokenizer_decoder.bin') 347 | 348 | print(f"Load visual tokenizer decoder weight from {weight_path}") 349 | state_dict = torch.load(weight_path, map_location="cpu") 350 | model.load_state_dict(state_dict) 351 | return model -------------------------------------------------------------------------------- /Pigeon/models/lavit_for_understanding.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import contextlib 3 | import os 4 | import re 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.cuda.amp import autocast as autocast 9 | from transformers import LlamaForCausalLM, LlamaTokenizer 10 | from utils import get_rank 11 | from models.modeling_visual_tokenzier import build_dynamic_tokenizer 12 | from models.transform import LaVITImageProcessor, LaVITQuestionProcessor 13 | from PIL import Image 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class LaVITforUnderstanding(nn.Module): 19 | """ 20 | The LaVIT Model for Multi-modal Understanding, 21 | this file is used for reading image contents and answering the questions. 22 | """ 23 | def __init__( 24 | self, 25 | img_size=224, 26 | model_path="", 27 | model_dtype="bf16", 28 | device_id=None, 29 | apply_lemmatizer=True, 30 | use_xformers=False, 31 | model_sub_dir='language_model', 32 | ): 33 | """ 34 | img_size: The input image size, should be 224 * 224 35 | model_path: The pre-trained model checkpoint path, the local path for downloaded LaVIT weight 36 | model_dtype: The precision of model weight during inference, should be set bf16 or fp16, default is bf16. 
37 | apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas 38 | """ 39 | super().__init__() 40 | assert img_size == 224, "Input Image Size should be set to 224" 41 | 42 | visual_vocab_size = 16384 # The visual vocab size of LaVIT is 16384 43 | print(f"Loading LaVIT Model Weight from {model_path}, model precision: {model_dtype}") 44 | 45 | if device_id is None: 46 | device_map={"": get_rank() % 8} 47 | else: 48 | device_map={"": device_id} 49 | 50 | self.llama_tokenizer = LlamaTokenizer.from_pretrained(model_path, subfolder=model_sub_dir, use_fast=False) 51 | self.llama_model = LlamaForCausalLM.from_pretrained( 52 | model_path, subfolder=model_sub_dir, torch_dtype=torch.bfloat16 if model_dtype=="bf16" else torch.float16, 53 | device_map=device_map, 54 | ) 55 | for name, param in self.llama_model.named_parameters(): 56 | param.requires_grad = False 57 | 58 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token 59 | self.visual_vocab_size = visual_vocab_size 60 | print(f"The Visual Vocab Size is {self.visual_vocab_size}") 61 | print(f"The llama tokenizer vocab size is {len(self.llama_tokenizer)}") 62 | 63 | self.visual_tokenizer = build_dynamic_tokenizer(model_path, use_xformers=use_xformers, 64 | for_understanding=True, model_sub_dir=model_sub_dir) 65 | self.model_dtype = model_dtype 66 | self.apply_lemmatizer = apply_lemmatizer 67 | self._lemmatizer = None 68 | self.processer = LaVITImageProcessor(image_size=img_size) 69 | 70 | @property 71 | def device(self): 72 | return list(self.parameters())[0].device 73 | 74 | @property 75 | def dtype(self): 76 | if self.model_dtype == 'fp16': 77 | dtype = torch.float16 78 | elif self.model_dtype == 'bf16': 79 | dtype = torch.bfloat16 80 | else: 81 | # The default dtype is fp16 82 | dtype = torch.float16 83 | return dtype 84 | 85 | def maybe_autocast(self, dtype=torch.float16): 86 | # if on cpu, don't use autocast 87 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 88 | enable_autocast = self.device != torch.device("cpu") 89 | dtype = self.dtype 90 | 91 | if enable_autocast: 92 | return torch.cuda.amp.autocast(dtype=dtype) 93 | else: 94 | return contextlib.nullcontext() 95 | 96 | def process_image(self, image_inputs): 97 | if isinstance(image_inputs, torch.Tensor): 98 | assert len(image_inputs.shape) == 4, "Image Tensors should have shape (batch_size, 3, H, W)" 99 | image_inputs = image_inputs.to(self.device) 100 | return image_inputs 101 | 102 | if not isinstance(image_inputs, list): 103 | assert isinstance(image_inputs, str) 104 | image_inputs = [image_inputs] 105 | 106 | image_tensors = [] 107 | for image_path in image_inputs: 108 | image = Image.open(image_path).convert('RGB') 109 | image = self.processer(image) 110 | image_tensors.append(image) 111 | 112 | image_tensors = torch.stack(image_tensors, dim=0) 113 | image_tensors = image_tensors.to(self.device) 114 | return image_tensors 115 | 116 | def compute_dynamic_visual_embeds(self, image): 117 | image_embeds_list = self.visual_tokenizer.encode_features(image) 118 | batch_size = len(image_embeds_list) 119 | # Pad the image start and end tokens 120 | image_pad_token = torch.tensor([32000, 32001], dtype=torch.long).to(image.device) 121 | image_pad_embeds = self.llama_model.get_input_embeddings()(image_pad_token) # [2, embed_dim] 122 | max_token_num = -1 123 | 124 | for i_b in range(batch_size): 125 | image_embeds_list[i_b] = torch.cat([image_pad_embeds[:1], image_embeds_list[i_b], image_pad_embeds[1:]], dim=0) 126 | 
max_token_num = max(max_token_num, len(image_embeds_list[i_b])) 127 | 128 | # Pad with eos embeddings 129 | eos_id = self.llama_tokenizer.eos_token_id 130 | eos_id = torch.tensor([eos_id], dtype=torch.long).to(image.device) 131 | eos_embeds = self.llama_model.get_input_embeddings()(eos_id).unsqueeze(0) # [1, 1, embed_dim] 132 | 133 | image_attns = torch.zeros((batch_size, max_token_num), dtype=torch.long).to(image.device) 134 | image_embeds = eos_embeds.repeat(batch_size, max_token_num, 1) 135 | 136 | # Use the left padding 137 | for i_b in range(batch_size): 138 | image_attns[i_b, -len(image_embeds_list[i_b]):] = 1 139 | image_embeds[i_b, -len(image_embeds_list[i_b]):] = image_embeds_list[i_b] 140 | 141 | return image_embeds, image_attns 142 | 143 | @torch.no_grad() 144 | def generate( 145 | self, 146 | samples, 147 | use_nucleus_sampling=False, 148 | num_beams=2, 149 | max_length=36, 150 | min_length=8, 151 | top_p=1.0, 152 | top_k=50, 153 | repetition_penalty=1, 154 | length_penalty=1, 155 | num_captions=1, 156 | temperature=1, 157 | **kwargs 158 | ): 159 | """ 160 | Usage: 161 | Generate the textual caption of input images 162 | Args: 163 | samples (dict): A dictionary containing the following keys: 164 | - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) 165 | use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. 166 | num_beams (int): Number of beams for beam search. 1 means no beam search. 167 | max_length (int): The maximum length of the sequence to be generated. 168 | min_length (int): The minimum length of the sequence to be generated. 169 | top_p (float): The cumulative probability for nucleus sampling. 170 | repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. 171 | num_captions (int): Number of captions to be generated for each image. 172 | Returns: 173 | captions (list): A list of strings of length batch_size * num_captions. 174 | """ 175 | 176 | image = self.process_image(samples["image"]) 177 | 178 | if "prompt" in samples.keys(): 179 | prompt = samples["prompt"] 180 | else: 181 | prompt = '' 182 | 183 | # Prepare image token ids 184 | with self.maybe_autocast(): 185 | image_embeds, image_attns = self.compute_dynamic_visual_embeds(image) 186 | 187 | if prompt != "": 188 | if isinstance(prompt, str): 189 | prompt = [prompt] * image.size(0) 190 | else: 191 | assert len(prompt) == image.size( 192 | 0 193 | ), "The number of prompts must be equal to the batch size." 
194 | 195 | self.llama_tokenizer.padding_side = "left" 196 | prompt_tokens = self.llama_tokenizer( 197 | prompt, padding="longest", return_tensors="pt", add_special_tokens=False 198 | ).to(image.device) 199 | 200 | with self.maybe_autocast(): 201 | prompt_embeds = self.llama_model.get_input_embeddings()(prompt_tokens.input_ids) 202 | inputs_embeds = torch.cat([image_embeds, prompt_embeds], dim=1) 203 | attention_mask = torch.cat([image_attns, prompt_tokens.attention_mask], dim=1) 204 | 205 | else: 206 | inputs_embeds = image_embeds 207 | attention_mask = image_attns 208 | 209 | # For captioning, supress the token ids > 32000 (Visual Tokens) 210 | supress_range = 32000 + self.visual_vocab_size + 2 211 | suppress_tokens = [x for x in range(32000, supress_range)] 212 | 213 | with self.maybe_autocast(): 214 | outputs = self.llama_model.generate( 215 | inputs_embeds=inputs_embeds, 216 | attention_mask=attention_mask, 217 | do_sample=use_nucleus_sampling, 218 | temperature=temperature, 219 | num_beams=num_beams, 220 | max_new_tokens=max_length, 221 | min_new_tokens=min_length, 222 | suppress_tokens=suppress_tokens, 223 | bos_token_id=self.llama_tokenizer.bos_token_id, 224 | eos_token_id=self.llama_tokenizer.eos_token_id, 225 | pad_token_id=self.llama_tokenizer.pad_token_id, 226 | repetition_penalty=repetition_penalty, 227 | length_penalty=length_penalty, 228 | num_return_sequences=num_captions, 229 | ) 230 | 231 | output_text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True) 232 | output_text = [text.strip() for text in output_text] 233 | output_text = [text.split('.')[0] for text in output_text] 234 | return output_text 235 | 236 | def pad_input_embeds(self, image_embeds, image_attns, prompt_embeds, prompt_attns): 237 | # Concat the image and text embeddings 238 | batch_size = len(image_embeds) 239 | input_embeds, attention_mask = [], [] 240 | 241 | for i_b in range(batch_size): 242 | image_embed = image_embeds[i_b] 243 | image_attn = image_attns[i_b] 244 | prompt_embed = prompt_embeds[i_b] # [seq_len, embed_dim] 245 | prompt_attn = prompt_attns[i_b] # [seq_len] 246 | 247 | prompt_len = prompt_attn.sum().item() 248 | pad_prompt_len = len(prompt_attn) - prompt_len 249 | 250 | if pad_prompt_len == 0: 251 | input_embed = torch.cat([image_embed, prompt_embed], dim=0) 252 | input_attn = torch.cat([image_attn, prompt_attn], dim=0) 253 | else: 254 | assert prompt_attn[:pad_prompt_len].sum() == 0 255 | input_embed = torch.cat([prompt_embed[:pad_prompt_len], image_embed, prompt_embed[-prompt_len:]], dim=0) 256 | input_attn = torch.cat([prompt_attn[:pad_prompt_len], image_attn, prompt_attn[-prompt_len:]], dim=0) 257 | 258 | input_embeds.append(input_embed) 259 | attention_mask.append(input_attn) 260 | 261 | input_embeds = torch.stack(input_embeds, dim=0) 262 | attention_mask = torch.stack(attention_mask, dim=0) 263 | 264 | return input_embeds, attention_mask 265 | 266 | @torch.no_grad() 267 | def predict_answers( 268 | self, 269 | samples, 270 | num_beams=5, 271 | max_len=10, 272 | min_len=2, 273 | prompt="Question: {} Answer:", 274 | temperature=1, 275 | top_p=1.0, 276 | top_k=50, 277 | use_nucleus_sampling=False, 278 | length_penalty=0, 279 | **kwargs 280 | ): 281 | """ 282 | Usage: 283 | Answering the visual questions 284 | Args: 285 | samples (dict): A dictionary containing the following keys: 286 | - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) 287 | use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. 
288 | num_beams (int): Number of beams for beam search. 1 means no beam search. 289 | max_length (int): The maximum length of the sequence to be generated. 290 | min_length (int): The minimum length of the sequence to be generated. 291 | top_p (float): The cumulative probability for nucleus sampling. 292 | repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. 293 | Returns: 294 | answers (list): A list of strings of length batch_size. 295 | """ 296 | image = self.process_image(samples["image"]) 297 | 298 | if isinstance(samples["text_input"], str): 299 | samples["text_input"] = [samples["text_input"]] 300 | 301 | if prompt: 302 | text_input = [prompt.format(question) for question in samples["text_input"]] 303 | else: 304 | text_input = samples["text_input"] 305 | 306 | self.llama_tokenizer.padding_side = "left" 307 | prompt_tokens = self.llama_tokenizer( 308 | text_input, padding="longest", return_tensors="pt", add_special_tokens=False 309 | ).to(image.device) 310 | 311 | with self.maybe_autocast(): 312 | prompt_embeds = self.llama_model.get_input_embeddings()(prompt_tokens.input_ids) 313 | image_embeds, image_attns = self.compute_dynamic_visual_embeds(image) 314 | 315 | # Concat the image and text embeddings to form left padding 316 | inputs_embeds, attention_mask = self.pad_input_embeds(image_embeds, image_attns, prompt_embeds, prompt_tokens.attention_mask) 317 | 318 | supress_range = 32000 + self.visual_vocab_size + 2 319 | suppress_tokens = [x for x in range(32000, supress_range)] 320 | 321 | with self.maybe_autocast(): 322 | outputs = self.llama_model.generate( 323 | inputs_embeds=inputs_embeds, 324 | attention_mask=attention_mask, 325 | do_sample=use_nucleus_sampling, 326 | temperature=temperature, 327 | num_beams=num_beams, 328 | max_new_tokens=max_len, 329 | min_new_tokens=min_len, 330 | suppress_tokens=suppress_tokens, 331 | bos_token_id=self.llama_tokenizer.bos_token_id, 332 | eos_token_id=self.llama_tokenizer.eos_token_id, 333 | pad_token_id=self.llama_tokenizer.pad_token_id, 334 | length_penalty=length_penalty, 335 | early_stopping=True, 336 | ) 337 | 338 | # print("output: ", outputs) 339 | output_text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True) 340 | output_text = [text.strip() for text in output_text] 341 | # The post posting for evaluation 342 | output_text = [text.split('\n')[0] for text in output_text] 343 | output_text = [text.split('question:')[0] for text in output_text] 344 | output_text = [text.split('Long answer:')[0] for text in output_text] 345 | output_text = [text.split(',')[0] for text in output_text] 346 | output_text = [text.split('.')[0] for text in output_text] 347 | 348 | # lemmatize the output 349 | output_text = self._lemmatize(output_text) 350 | 351 | return output_text 352 | 353 | def _lemmatize(self, answers): 354 | def apply(answer): 355 | answer = answer.lower() 356 | doc = self.lemmatizer(answer) 357 | 358 | words = [] 359 | for token in doc: 360 | if token.pos_ in ["NOUN", "VERB"]: 361 | words.append(token.lemma_) 362 | else: 363 | words.append(token.text) 364 | answer = " ".join(words) 365 | 366 | return answer 367 | 368 | return [apply(answer) for answer in answers] 369 | 370 | @property 371 | def lemmatizer(self): 372 | if self._lemmatizer is None: 373 | try: 374 | import spacy 375 | self._lemmatizer = spacy.load("en_core_web_sm") 376 | except ImportError: 377 | logging.error( 378 | """ 379 | Please install spacy and en_core_web_sm model to apply lemmatization. 
380 | python -m spacy download en_core_web_sm 381 | OR 382 | import spacy.cli 383 | spacy.cli.download("en_core_web_sm") 384 | """ 385 | ) 386 | exit(1) 387 | 388 | return self._lemmatizer -------------------------------------------------------------------------------- /Pigeon/models/modeling_visual_tokenzier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | from collections import OrderedDict 9 | from functools import partial, reduce 10 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 11 | from models.modeling_visual_encoder import build_eva_clip 12 | from torchvision import transforms as pth_transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | 15 | 16 | class LayerNorm(nn.LayerNorm): 17 | """Subclass torch's LayerNorm to handle fp16.""" 18 | 19 | def forward(self, x: torch.Tensor): 20 | orig_type = x.dtype 21 | ret = super().forward(x.type(torch.float32)) 22 | return ret.type(orig_type) 23 | 24 | 25 | try: 26 | from apex.normalization import FusedLayerNorm 27 | except: 28 | FusedLayerNorm = LayerNorm 29 | 30 | 31 | class Mlp(nn.Module): 32 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 33 | super().__init__() 34 | out_features = out_features or in_features 35 | hidden_features = hidden_features or in_features 36 | self.fc1 = nn.Linear(in_features, hidden_features) 37 | self.act = act_layer() 38 | self.fc2 = nn.Linear(hidden_features, out_features) 39 | self.drop = nn.Dropout(drop) 40 | 41 | def forward(self, x): 42 | x = self.fc1(x) 43 | x = self.act(x) 44 | x = self.drop(x) 45 | x = self.fc2(x) 46 | x = self.drop(x) 47 | return x 48 | 49 | 50 | def l2norm(t): 51 | return F.normalize(t, p = 2, dim = -1) 52 | 53 | 54 | class CodebookEmbedding(nn.Module): 55 | def __init__(self, num_tokens, codebook_dim): 56 | super().__init__() 57 | self.num_tokens = num_tokens 58 | self.codebook_dim = codebook_dim 59 | weight = torch.randn(num_tokens, codebook_dim) 60 | weight = l2norm(weight) 61 | self.weight = nn.Parameter(weight) 62 | 63 | def forward(self, embed_id): 64 | return F.embedding(embed_id, self.weight) 65 | 66 | 67 | class VectorQuantizer(nn.Module): 68 | def __init__(self, n_embed, embedding_dim): 69 | super().__init__() 70 | self.codebook_dim = embedding_dim 71 | self.num_tokens = n_embed 72 | self.embedding = CodebookEmbedding(self.num_tokens, self.codebook_dim) 73 | 74 | def tokenize(self, z): 75 | z = l2norm(z) 76 | z_flattened = z.reshape(-1, self.codebook_dim) 77 | 78 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 79 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 80 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 81 | 82 | encoding_indices = torch.argmin(d, dim=1) 83 | 84 | z_q = self.embedding(encoding_indices) # [np, d] 85 | 86 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) # [np, 16384] 87 | 88 | return z_q, encoding_indices 89 | 90 | def get_quantize_from_id(self, encoding_indices): 91 | z_q = self.embedding(encoding_indices) # [np, d] 92 | 93 | 94 | class TokenCrossAttention(nn.Module): 95 | 96 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 97 | super().__init__() 98 | self.num_heads = num_heads 99 | head_dim = dim // num_heads 100 | self.scale = qk_scale or head_dim ** -0.5 101 | 
102 | self.query = nn.Linear(dim, dim, bias=qkv_bias) 103 | self.key = nn.Linear(dim, dim, bias=qkv_bias) 104 | self.value = nn.Linear(dim, dim, bias=qkv_bias) 105 | 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(dim, dim, bias=qkv_bias) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | def softmax_with_policy(self, attn, policy, eps=1e-6): 111 | B, N = policy.size() 112 | B, H, N, N = attn.size() 113 | fuse_policy = 1 - policy # Each token only attend to the dropped tokens 114 | attn_policy = fuse_policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1) 115 | attn_policy = attn_policy.expand(B, 1, N, N) 116 | max_att = torch.max(attn, dim=-1, keepdim=True)[0] 117 | attn = attn - max_att 118 | 119 | # for stable training 120 | attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32) 121 | attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps) 122 | 123 | return attn.type_as(max_att) 124 | 125 | def forward(self, x, x_origin, decisions): 126 | B, N, C = x.shape 127 | q = self.query(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 128 | k = self.key(x_origin).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 129 | v = self.value(x_origin).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 130 | 131 | attn = (q @ k.transpose(-2, -1)) * self.scale 132 | attn = self.softmax_with_policy(attn, decisions) 133 | attn = self.attn_drop(attn) 134 | 135 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 136 | x = self.proj(x) 137 | x = self.proj_drop(x) 138 | return x 139 | 140 | 141 | class TokenCausalAttention(nn.Module): 142 | 143 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 144 | super().__init__() 145 | self.num_heads = num_heads 146 | head_dim = dim // num_heads 147 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 148 | self.scale = qk_scale or head_dim ** -0.5 149 | 150 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 151 | self.attn_drop = nn.Dropout(attn_drop) 152 | self.proj = nn.Linear(dim, dim, bias=qkv_bias) 153 | self.proj_drop = nn.Dropout(proj_drop) 154 | 155 | def softmax_with_policy(self, attn, policy, eps=1e-6): 156 | B, N = policy.size() 157 | device = attn.device 158 | assert attn.shape[-1] == attn.shape[-2] 159 | assert attn.shape[-2] == N 160 | B, H, N, N = attn.size() 161 | 162 | attn_policy = policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1) 163 | eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N) 164 | attn_policy = attn_policy + (1.0 - attn_policy) * eye 165 | 166 | # Use the causal attention 167 | seq_ids = torch.arange(N, device=device) 168 | causal_mask = ( 169 | seq_ids[None, None, :].repeat(B, N, 1) 170 | <= seq_ids[None, :, None] 171 | ) 172 | causal_mask = causal_mask[:,None,:,:].to(attn_policy.dtype) 173 | attn_policy = attn_policy * causal_mask 174 | 175 | max_att = torch.max(attn, dim=-1, keepdim=True)[0] 176 | attn = attn - max_att 177 | 178 | # for stable training 179 | attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32) 180 | attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps) 181 | return attn.type_as(max_att) 182 | 183 | def forward(self, x, decisions): 184 | B, N, C = x.shape 185 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 186 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as 
tuple) 187 | # v:[B, self.num_heads, N, C // self.num_heads] 188 | attn = (q @ k.transpose(-2, -1)) * self.scale # [B, self.num_heads, N, N] 189 | 190 | if decisions is None: 191 | attn = attn.softmax(dim=-1) 192 | else: 193 | attn = self.softmax_with_policy(attn, decisions) 194 | 195 | attn = self.attn_drop(attn) 196 | 197 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 198 | x = self.proj(x) 199 | x = self.proj_drop(x) 200 | return x 201 | 202 | 203 | class CausalFuserBlock(nn.Module): 204 | 205 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 206 | attn_drop=0., act_layer=nn.GELU, norm_layer=partial(FusedLayerNorm, eps=1e-5)): 207 | super().__init__() 208 | 209 | self.norm0 = norm_layer(dim) 210 | self.token_causal_attn = TokenCausalAttention( 211 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 212 | attn_drop=attn_drop, proj_drop=drop, 213 | ) 214 | 215 | self.norm1 = norm_layer(dim) 216 | self.token_cross_attn = TokenCrossAttention( 217 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 218 | attn_drop=attn_drop, proj_drop=drop 219 | ) 220 | 221 | self.norm2 = norm_layer(dim) 222 | mlp_hidden_dim = int(dim * mlp_ratio) 223 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 224 | 225 | def forward(self, x, x_origin, decisions): 226 | x = x + self.token_causal_attn(self.norm0(x), decisions) 227 | x = x + self.token_cross_attn(self.norm1(x), x_origin, decisions) 228 | x = x + self.mlp(self.norm2(x)) 229 | return x 230 | 231 | 232 | class TokenMerger(nn.Module): 233 | 234 | def __init__(self, dim, num_heads, depth=1, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 235 | attn_drop=0., act_layer=nn.GELU, norm_layer=partial(FusedLayerNorm, eps=1e-5)): 236 | super().__init__() 237 | self.blocks = nn.ModuleList([ 238 | CausalFuserBlock( 239 | dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 240 | drop=drop, attn_drop=attn_drop, act_layer=act_layer, norm_layer=norm_layer) 241 | for i in range(depth)]) 242 | 243 | self.ln_vision = norm_layer(dim) 244 | 245 | self.norm = norm_layer(dim) 246 | 247 | self.apply(self._init_weights) 248 | 249 | def _init_weights(self, m): 250 | if isinstance(m, nn.Linear): 251 | trunc_normal_(m.weight, std=.02) 252 | if isinstance(m, nn.Linear) and m.bias is not None: 253 | nn.init.constant_(m.bias, 0) 254 | elif isinstance(m, nn.LayerNorm): 255 | nn.init.constant_(m.bias, 0) 256 | nn.init.constant_(m.weight, 1.0) 257 | 258 | def forward(self, x, decisions): 259 | x_origin = self.ln_vision(x) # the raw vit features needs layer normalization 260 | 261 | for blk in self.blocks: 262 | x = blk(x, x_origin, decisions) 263 | 264 | x = self.norm(x) # the post norm, for next stage use 265 | 266 | return x 267 | 268 | 269 | class TokenPredictor(nn.Module): 270 | 271 | def __init__(self, embed_dim=384): 272 | super().__init__() 273 | self.in_conv = nn.Sequential( 274 | FusedLayerNorm(embed_dim, eps=1e-5), 275 | nn.Linear(embed_dim, embed_dim), 276 | nn.GELU() 277 | ) 278 | 279 | self.out_conv = nn.Sequential( 280 | nn.Linear(embed_dim, embed_dim // 2), 281 | nn.GELU(), 282 | nn.Linear(embed_dim // 2, embed_dim // 4), 283 | nn.GELU(), 284 | nn.Linear(embed_dim // 4, 2), 285 | nn.LogSoftmax(dim=-1) 286 | ) 287 | 288 | self.apply(self._init_weights) 289 | 290 | def _init_weights(self, m): 291 | if isinstance(m, nn.Linear): 292 | trunc_normal_(m.weight, std=.02) 293 | if isinstance(m, nn.Linear) and m.bias is not 
None: 294 | nn.init.constant_(m.bias, 0) 295 | elif isinstance(m, nn.LayerNorm): 296 | nn.init.constant_(m.bias, 0) 297 | nn.init.constant_(m.weight, 1.0) 298 | 299 | def forward(self, x, policy): 300 | x = self.in_conv(x) 301 | B, N, C = x.size() 302 | local_x = x[:,:, :C//2] 303 | global_x = (x[:,:, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True) 304 | x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1) 305 | return self.out_conv(x) 306 | 307 | 308 | class DynamicVisualTokenizer(nn.Module): 309 | def __init__(self, img_size=224, patch_size=14, width=1408, layers=12, 310 | heads=16, n_code=16384, code_dim=32, model_path='', use_xformers=False): 311 | """ 312 | The dynamic visual tokenizer in LaVIT, it has 12 transformer blocks 313 | """ 314 | super().__init__() 315 | 316 | self.encoder = build_eva_clip(model_path=model_path, use_xformers=use_xformers) 317 | self.encoder.eval() 318 | # Freeze the vit encoder 319 | for param in self.encoder.parameters(): 320 | param.requires_grad = False # fix encoder model 321 | 322 | encoder_config = dict(img_size=224, patch_size=14, in_chans=32, embed_dim=1408, depth=12, num_heads=16, 323 | mlp_ratio=4.3637, qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., norm_layer=partial(FusedLayerNorm, eps=1e-5)) 324 | 325 | encoder_config['img_size'] = img_size 326 | encoder_config['patch_size'] = patch_size 327 | encoder_config['embed_dim'] = width 328 | encoder_config['depth'] = layers 329 | encoder_config['num_heads'] = heads 330 | 331 | # The token predictor 332 | self.token_predictor = TokenPredictor(encoder_config['embed_dim']) 333 | 334 | # The token merger 335 | self.causal_encoder = TokenMerger( 336 | encoder_config['embed_dim'], 337 | num_heads=encoder_config['num_heads'], 338 | depth=encoder_config['depth'], 339 | mlp_ratio=encoder_config['mlp_ratio'], 340 | qkv_bias=encoder_config['qkv_bias'], 341 | qk_scale=encoder_config['qk_scale'], 342 | drop=encoder_config['drop_rate'], 343 | attn_drop=encoder_config['attn_drop_rate'], 344 | ) 345 | 346 | # The code book embeddings 347 | self.quantize = VectorQuantizer(n_embed=n_code, embedding_dim=code_dim) 348 | 349 | # encoder task layer, map the feature to the codebook's dimension 350 | self.encode_task_layer = nn.Sequential( 351 | nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']), 352 | nn.Tanh(), 353 | nn.Linear(encoder_config['embed_dim'], code_dim) # for quantize 354 | ) 355 | 356 | # The vit projection, map the visual feature to LLM's input space 357 | llm_embed_dim = 4096 # LLaMA 7B's embedding dimension: 4096 358 | self.vit_proj = nn.Linear(width, llm_embed_dim) 359 | 360 | self.transform = pth_transforms.Compose([ 361 | pth_transforms.Resize((512, 512), interpolation=InterpolationMode.BICUBIC), 362 | pth_transforms.ToTensor(), 363 | ]) 364 | 365 | def encode_features(self, x): 366 | """ 367 | x: B, 3, H, W 368 | Usage: Given the input image, encode the visual features for the LLM 369 | """ 370 | device = x.device 371 | encoder_features = self.encoder(x, return_all_features=True) # N, 257, D 372 | encoder_features = encoder_features[:,1:,:] 373 | 374 | B, num_patches, _ = encoder_features.shape 375 | mask = torch.ones(B, num_patches, 1, dtype=encoder_features.dtype, device=encoder_features.device) 376 | 377 | # To evalaute the score 378 | pred_score = self.token_predictor(encoder_features.to(torch.float32), mask).reshape(B, -1, 2) 379 | # Sample from the score distribution 380 | hard_keep_decision = F.gumbel_softmax(pred_score, 
hard=True)[:, :, 0] # [N, num_patches] 381 | 382 | # Update the existed features from dropped tokens (To remain the information flow) 383 | updated_features = self.causal_encoder(encoder_features, hard_keep_decision) 384 | updated_features = self.vit_proj(updated_features) # [bs, 256, 4096] 385 | 386 | B, N, C = updated_features.shape 387 | index_select = hard_keep_decision.long() 388 | 389 | token_num = index_select.sum(dim=-1) 390 | index_select = index_select.bool() 391 | 392 | remained_token = torch.masked_select(updated_features, index_select[:,:,None]) 393 | remained_token = remained_token.reshape(-1, C) # [Num Patch] 394 | remained_token_list = torch.split(remained_token, token_num.tolist()) # [bs] 395 | remained_token_list = list(remained_token_list) 396 | 397 | return remained_token_list 398 | 399 | def tokenize_image(self, x_tensor, add_special=False): 400 | # x_tensor: [bs, 3, h, w] 401 | feature_targets = self.encoder(x_tensor, return_all_features=True) # N, 257, D 402 | encoder_features = feature_targets[:,1:,:] 403 | 404 | B, num_patches, _ = encoder_features.shape 405 | mask = torch.ones(B, num_patches, 1, dtype=encoder_features.dtype, device=encoder_features.device) 406 | 407 | pred_score = self.token_predictor(encoder_features.to(torch.float32), mask).reshape(B, -1, 2) 408 | 409 | # Sample from the score distribution 410 | hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0] # [N, num_patches] 411 | 412 | # Update the existed features from dropped tokens (To remain the information flow) 413 | updated_features = self.causal_encoder(encoder_features, hard_keep_decision) 414 | 415 | B, N, C = updated_features.shape 416 | index_select = hard_keep_decision.long() 417 | token_nums = index_select.sum(dim=-1) 418 | index_select = index_select.bool() 419 | remained_token = torch.masked_select(updated_features, index_select[:,:,None]).reshape(-1, C) # [Num Patch] 420 | 421 | to_quantizer_features = self.encode_task_layer(remained_token.type_as(self.encode_task_layer[-1].weight)) 422 | quantize, embed_ind = self.quantize.tokenize(to_quantizer_features) 423 | embed_ind = embed_ind + 32002 424 | embed_ind_list = torch.split(embed_ind, token_nums.tolist(), dim=0) 425 | 426 | if add_special: 427 | # If pad the special image start and end tokens, default is False 428 | output_embed_ind = [] 429 | image_special = torch.as_tensor([32000, 32001], dtype=torch.long).to(x_tensor.device) 430 | for ele in embed_ind_list: 431 | output_embed_ind.append(torch.cat([image_special[:1], ele, image_special[1:]])) 432 | return output_embed_ind 433 | 434 | return embed_ind_list 435 | 436 | 437 | def build_dynamic_tokenizer(model_path='', use_xformers=False, for_understanding=False): 438 | model = DynamicVisualTokenizer(model_path=model_path, use_xformers=use_xformers) 439 | weight_path = os.path.join(model_path, 'visual_tokenizer', 'tokenizer_encoder.bin') 440 | print(f"Load visual tokenizer encoder weight from {weight_path}") 441 | state_dict = torch.load(weight_path, map_location="cpu") 442 | model.load_state_dict(state_dict, strict=False) 443 | 444 | if for_understanding: 445 | # For Understanding, the LaVIT use the continuous visual features, 446 | # so needs to load the token merger weight trained with LLM 447 | visual_weight_path = os.path.join(model_path, 'language_model', 'visual_weight.bin') 448 | print(f"For multi-modal understanding, Load visual tokenizer weight from {visual_weight_path}") 449 | state_dict = torch.load(visual_weight_path, map_location="cpu") 450 | 
model.load_state_dict(state_dict, strict=False) 451 | 452 | return model 453 | -------------------------------------------------------------------------------- /Pigeon/models/openai_clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | import pdb 10 | 11 | class Bottleneck(nn.Module): 12 | expansion = 4 13 | 14 | def __init__(self, inplanes, planes, stride=1): 15 | super().__init__() 16 | 17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 25 | 26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 28 | 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = None 31 | self.stride = stride 32 | 33 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 35 | self.downsample = nn.Sequential(OrderedDict([ 36 | ("-1", nn.AvgPool2d(stride)), 37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 38 | ("1", nn.BatchNorm2d(planes * self.expansion)) 39 | ])) 40 | 41 | def forward(self, x: torch.Tensor): 42 | identity = x 43 | 44 | out = self.relu(self.bn1(self.conv1(x))) 45 | out = self.relu(self.bn2(self.conv2(out))) 46 | out = self.avgpool(out) 47 | out = self.bn3(self.conv3(out)) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu(out) 54 | return out 55 | 56 | 57 | class AttentionPool2d(nn.Module): 58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 59 | super().__init__() 60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 61 | self.k_proj = nn.Linear(embed_dim, embed_dim) 62 | self.q_proj = nn.Linear(embed_dim, embed_dim) 63 | self.v_proj = nn.Linear(embed_dim, embed_dim) 64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 65 | self.num_heads = num_heads 66 | 67 | def forward(self, x, return_all_tokens=False): 68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 71 | x, _ = F.multi_head_attention_forward( 72 | query=x, key=x, value=x, 73 | embed_dim_to_check=x.shape[-1], 74 | num_heads=self.num_heads, 75 | q_proj_weight=self.q_proj.weight, 76 | k_proj_weight=self.k_proj.weight, 77 | v_proj_weight=self.v_proj.weight, 78 | in_proj_weight=None, 79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 80 | bias_k=None, 81 | bias_v=None, 82 | add_zero_attn=False, 83 | dropout_p=0, 84 | out_proj_weight=self.c_proj.weight, 85 | out_proj_bias=self.c_proj.bias, 86 | use_separate_proj_weight=True, 87 | training=self.training, 88 | need_weights=False 89 | ) 90 | if return_all_tokens: 91 | return x 92 | 
else: 93 | return x[0] 94 | 95 | 96 | class ModifiedResNet(nn.Module): 97 | """ 98 | A ResNet class that is similar to torchvision's but contains the following changes: 99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 100 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 101 | - The final pooling layer is a QKV attention instead of an average pool 102 | """ 103 | 104 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 105 | super().__init__() 106 | self.output_dim = output_dim 107 | self.input_resolution = input_resolution 108 | 109 | # the 3-layer stem 110 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(width // 2) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.avgpool = nn.AvgPool2d(2) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x, return_side_out=False, return_all_tokens=False): 139 | def stem(x): 140 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 141 | x = self.relu(bn(conv(x))) 142 | x = self.avgpool(x) 143 | return x 144 | out = [] 145 | x = x.type(self.conv1.weight.dtype) 146 | x = stem(x) 147 | x = self.layer1(x) 148 | if return_side_out: 149 | out.append(x) 150 | x = self.layer2(x) 151 | if return_side_out: 152 | out.append(x) 153 | x = self.layer3(x) 154 | if return_side_out: 155 | out.append(x) 156 | x = self.layer4(x) 157 | if return_side_out: 158 | out.append(x) 159 | x = self.attnpool(x, return_all_tokens) 160 | out.append(x) 161 | if len(out) == 1: 162 | return x 163 | else: 164 | return out 165 | 166 | 167 | class LayerNorm(nn.LayerNorm): 168 | """Subclass torch's LayerNorm to handle fp16.""" 169 | 170 | def forward(self, x: torch.Tensor): 171 | orig_type = x.dtype 172 | ret = super().forward(x.type(torch.float32)) 173 | return ret.type(orig_type) 174 | 175 | 176 | class QuickGELU(nn.Module): 177 | def forward(self, x: torch.Tensor): 178 | return x * torch.sigmoid(1.702 * x) 179 | 180 | 181 | class ResidualAttentionBlock(nn.Module): 182 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 183 | super().__init__() 184 | 185 | self.attn = nn.MultiheadAttention(d_model, n_head) 186 | self.ln_1 = LayerNorm(d_model) 187 | self.mlp = nn.Sequential(OrderedDict([ 188 | ("c_fc", nn.Linear(d_model, d_model * 4)), 189 | 
("gelu", QuickGELU()), 190 | ("c_proj", nn.Linear(d_model * 4, d_model)) 191 | ])) 192 | self.ln_2 = LayerNorm(d_model) 193 | self.attn_mask = attn_mask 194 | 195 | def attention(self, x: torch.Tensor): 196 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 197 | # pdb.set_trace() 198 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 199 | 200 | def forward(self, x: torch.Tensor): 201 | x = x + self.attention(self.ln_1(x)) 202 | x = x + self.mlp(self.ln_2(x)) 203 | return x 204 | 205 | 206 | class Transformer(nn.Module): 207 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 208 | super().__init__() 209 | self.width = width 210 | self.layers = layers 211 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 212 | 213 | def forward(self, x: torch.Tensor, return_intermediate_out: bool = False): 214 | if return_intermediate_out: 215 | output = [] 216 | for block in self.resblocks: 217 | x = block(x) 218 | output.append(x) 219 | return output 220 | 221 | return self.resblocks(x) 222 | 223 | 224 | class VisionTransformer(nn.Module): 225 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 226 | super().__init__() 227 | self.input_resolution = input_resolution 228 | self.patch_size = patch_size 229 | self.output_dim = output_dim 230 | self.width = width 231 | self.heads = heads 232 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 233 | 234 | scale = width ** -0.5 235 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 236 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 237 | self.ln_pre = LayerNorm(width) 238 | 239 | self.transformer = Transformer(width, layers, heads) 240 | 241 | self.ln_post = LayerNorm(width) 242 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 243 | 244 | def forward(self, x: torch.Tensor, return_all_tokens=False, return_all_final_tokens=False, **kwargs): 245 | 246 | B, nc, w, h = x.shape 247 | 248 | x = self.conv1(x) # shape = [*, width, grid, grid] 249 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 250 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 251 | 252 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 253 | 254 | if x.shape[1] != self.positional_embedding.shape[0]: 255 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype) 256 | else: 257 | x = x + self.positional_embedding.to(x.dtype) 258 | 259 | x = self.ln_pre(x) 260 | 261 | x = x.permute(1, 0, 2) # NLD -> LND 262 | x = self.transformer(x) 263 | x = x.permute(1, 0, 2) # LND -> NLD 264 | 265 | if return_all_tokens: 266 | x = self.ln_post(x) 267 | return x[:, 1:, :] 268 | 269 | if return_all_final_tokens: 270 | return self.ln_post(x)[:, 1:, :] @ self.proj 271 | 272 | x = self.ln_post(x[:, 0, :]) 273 | 274 | if self.proj is not None: 275 | x = x @ self.proj 276 | 277 | return x 278 | 279 | def interpolate_pos_encoding(self, x, w, h): 280 | # pdb.set_trace() 281 | npatch = x.shape[1] - 1 282 | N = self.positional_embedding.shape[0] - 1 # 256 for large 283 | if npatch == N and w == h: 284 | return self.positional_embedding 285 | class_pos_embed = 
self.positional_embedding[[0]] 286 | patch_pos_embed = self.positional_embedding[1:] 287 | dim = x.shape[-1] 288 | w0 = w // self.patch_size 289 | h0 = h // self.patch_size 290 | # we add a small number to avoid floating point error in the interpolation 291 | # see discussion at https://github.com/facebookresearch/dino/issues/8 292 | w0, h0 = w0 + 0.1, h0 + 0.1 293 | patch_pos_embed = nn.functional.interpolate( 294 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 295 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 296 | mode='bicubic', 297 | ) 298 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 299 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 300 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 301 | 302 | 303 | class CLIP(nn.Module): 304 | def __init__(self, 305 | embed_dim: int, # 512 306 | # vision 307 | image_resolution: int, # 224 308 | vision_layers: Union[Tuple[int, int, int, int], int], # 12 309 | vision_width: int, # 768 310 | vision_patch_size: int, # 16 311 | # text 312 | context_length: int, # 77 313 | vocab_size: int, # 49408 314 | transformer_width: int, # 512 315 | transformer_heads: int, # 8 316 | transformer_layers: int # 12 317 | ): 318 | super().__init__() 319 | # pdb.set_trace() 320 | self.context_length = context_length 321 | 322 | if isinstance(vision_layers, (tuple, list)): 323 | vision_heads = vision_width * 32 // 64 324 | self.visual = ModifiedResNet( 325 | layers=vision_layers, 326 | output_dim=embed_dim, 327 | heads=vision_heads, 328 | input_resolution=image_resolution, 329 | width=vision_width 330 | ) 331 | else: 332 | vision_heads = vision_width // 64 333 | self.visual = VisionTransformer( 334 | input_resolution=image_resolution, 335 | patch_size=vision_patch_size, 336 | width=vision_width, 337 | layers=vision_layers, 338 | heads=vision_heads, 339 | output_dim=embed_dim 340 | ) 341 | 342 | self.transformer = Transformer( 343 | width=transformer_width, 344 | layers=transformer_layers, 345 | heads=transformer_heads, 346 | attn_mask=self.build_attention_mask() 347 | ) 348 | 349 | self.vocab_size = vocab_size 350 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 351 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 352 | self.ln_final = LayerNorm(transformer_width) 353 | 354 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 355 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 356 | 357 | self.initialize_parameters() 358 | 359 | def initialize_parameters(self): 360 | nn.init.normal_(self.token_embedding.weight, std=0.02) 361 | nn.init.normal_(self.positional_embedding, std=0.01) 362 | 363 | if isinstance(self.visual, ModifiedResNet): 364 | if self.visual.attnpool is not None: 365 | std = self.visual.attnpool.c_proj.in_features ** -0.5 366 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 367 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 368 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 369 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 370 | 371 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 372 | for name, param in resnet_block.named_parameters(): 373 | if name.endswith("bn3.weight"): 374 | nn.init.zeros_(param) 375 | 376 | proj_std = (self.transformer.width ** -0.5) * ((2 * 
self.transformer.layers) ** -0.5) 377 | attn_std = self.transformer.width ** -0.5 378 | fc_std = (2 * self.transformer.width) ** -0.5 379 | for block in self.transformer.resblocks: 380 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 381 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 382 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 383 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 384 | 385 | if self.text_projection is not None: 386 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 387 | 388 | def build_attention_mask(self): 389 | # lazily create causal attention mask, with full attention between the vision tokens 390 | # pytorch uses additive attention mask; fill with -inf 391 | mask = torch.empty(self.context_length, self.context_length) 392 | mask.fill_(float("-inf")) 393 | mask.triu_(1) # zero out the lower diagonal 394 | return mask 395 | 396 | @property 397 | def dtype(self): 398 | return self.visual.conv1.weight.dtype 399 | 400 | def encode_image(self, image, return_side_out=False, return_all_tokens=False, return_all_final_tokens=False, **kwargs): 401 | return self.visual(image.type(self.dtype), return_all_tokens, return_all_final_tokens, **kwargs) 402 | 403 | def encode_text(self, text, return_all_tokens=False, return_patch_tokens=False): 404 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 405 | 406 | x = x + self.positional_embedding.type(self.dtype) 407 | x = x.permute(1, 0, 2) # NLD -> LND 408 | x = self.transformer(x) 409 | x = x.permute(1, 0, 2) # LND -> NLD 410 | x = self.ln_final(x).type(self.dtype) 411 | 412 | if return_patch_tokens: 413 | return x 414 | # x.shape = [batch_size, n_ctx, transformer.width] 415 | # take features from the eot embedding (eot_token is the highest number in each sequence) 416 | if return_all_tokens: 417 | # pdb.set_trace() 418 | x = x @ self.text_projection 419 | else: 420 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 421 | return x 422 | 423 | def forward(self, image, text): 424 | image_features = self.encode_image(image) 425 | text_features = self.encode_text(text) 426 | 427 | # normalized features 428 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 429 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 430 | 431 | # cosine similarity as logits 432 | logit_scale = self.logit_scale.exp() 433 | logits_per_image = logit_scale * image_features @ text_features.t() 434 | logits_per_text = logits_per_image.t() 435 | 436 | # shape = [global_batch_size, global_batch_size] 437 | return logits_per_image, logits_per_text 438 | 439 | 440 | def convert_weights(model: nn.Module): 441 | """Convert applicable model parameters to fp16""" 442 | 443 | def _convert_weights_to_fp16(l): 444 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 445 | l.weight.data = l.weight.data.half() 446 | if l.bias is not None: 447 | l.bias.data = l.bias.data.half() 448 | 449 | if isinstance(l, nn.MultiheadAttention): 450 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 451 | tensor = getattr(l, attr) 452 | if tensor is not None: 453 | tensor.data = tensor.data.half() 454 | 455 | for name in ["text_projection", "proj"]: 456 | if hasattr(l, name): 457 | attr = getattr(l, name) 458 | if attr is not None: 459 | attr.data = attr.data.half() 460 | 461 | model.apply(_convert_weights_to_fp16) 462 | 463 | 464 | def build_model(state_dict: dict): 465 
| vit = "visual.proj" in state_dict 466 | 467 | if vit: 468 | vision_width = state_dict["visual.conv1.weight"].shape[0] 469 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 470 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 471 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 472 | image_resolution = vision_patch_size * grid_size 473 | else: 474 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 475 | vision_layers = tuple(counts) 476 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 477 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 478 | vision_patch_size = None 479 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 480 | image_resolution = output_width * 32 481 | 482 | embed_dim = state_dict["text_projection"].shape[1] 483 | context_length = state_dict["positional_embedding"].shape[0] 484 | vocab_size = state_dict["token_embedding.weight"].shape[0] 485 | transformer_width = state_dict["ln_final.weight"].shape[0] 486 | transformer_heads = transformer_width // 64 487 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 488 | 489 | model = CLIP( 490 | embed_dim, 491 | image_resolution, vision_layers, vision_width, vision_patch_size, 492 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 493 | ) 494 | 495 | for key in ["input_resolution", "context_length", "vocab_size"]: 496 | if key in state_dict: 497 | del state_dict[key] 498 | 499 | convert_weights(model) 500 | model.load_state_dict(state_dict) 501 | return model.eval() 502 | --------------------------------------------------------------------------------