├── GUI_Vid ├── models │ ├── __init__.py │ ├── bert │ │ ├── __init__.py │ │ └── builder.py │ ├── blip2 │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── vit.cpython-39.pyc │ │ │ ├── blip2.cpython-39.pyc │ │ │ ├── Qformer.cpython-39.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── modeling_llama.cpython-39.pyc │ │ ├── builder.py │ │ ├── blip2.py │ │ └── utils.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── videochat2_it.cpython-311.pyc │ │ └── videochat2_it.cpython-39.pyc │ ├── utils.py │ ├── videochat2_pt.py │ └── videochat2_it.py ├── utils │ ├── __pycache__ │ │ ├── config.cpython-39.pyc │ │ ├── config.cpython-311.pyc │ │ ├── easydict.cpython-39.pyc │ │ ├── distributed.cpython-39.pyc │ │ └── easydict.cpython-311.pyc │ ├── config_utils.py │ ├── scheduler.py │ ├── easydict.py │ ├── distributed.py │ ├── optimizer.py │ ├── config.py │ ├── logger.py │ └── basic_utils.py ├── dataset │ ├── __pycache__ │ │ ├── utils.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── dataloader.cpython-39.pyc │ │ ├── it_dataset.cpython-39.pyc │ │ ├── pt_dataset.cpython-39.pyc │ │ ├── video_utils.cpython-39.pyc │ │ ├── base_dataset.cpython-39.pyc │ │ └── video_transforms.cpython-39.pyc │ ├── dataloader.py │ ├── base_dataset.py │ ├── it_dataset.py │ ├── pt_dataset.py │ ├── video_utils.py │ ├── utils.py │ ├── __init__.py │ └── video_transforms.py ├── configs │ ├── model.py │ ├── config_bert.json │ ├── config.json │ ├── data.py │ └── instruction_data.py ├── requirements.txt ├── scripts │ ├── run_7b_stage2.sh │ ├── run_7b_stage3.sh │ ├── run_7b_stage1.sh │ ├── katna.py │ ├── config_7b_stage1.py │ ├── config_7b_stage2.py │ └── config_7b_stage3.py └── demo_local.py ├── Figures ├── radar.jpg ├── Gui_icon.png └── GUI_overview.png └── Readme.md /GUI_Vid/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /GUI_Vid/models/bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /GUI_Vid/models/blip2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Figures/radar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dongping-Chen/GUI-World/HEAD/Figures/radar.jpg -------------------------------------------------------------------------------- /Figures/Gui_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dongping-Chen/GUI-World/HEAD/Figures/Gui_icon.png -------------------------------------------------------------------------------- /Figures/GUI_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dongping-Chen/GUI-World/HEAD/Figures/GUI_overview.png -------------------------------------------------------------------------------- /GUI_Vid/utils/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dongping-Chen/GUI-World/HEAD/GUI_Vid/utils/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- 
/GUI_Vid/configs/model.py: 
-------------------------------------------------------------------------------- 1 | TextEncoders = dict() 2 | TextEncoders["bert"] = dict( 3 | name="bert_base", 4 | pretrained="bert-base-uncased", 5 | config="configs/config_bert.json", 6 | d_model=768, 7 | fusion_layer=9, 8 | ) -------------------------------------------------------------------------------- /GUI_Vid/requirements.txt: -------------------------------------------------------------------------------- 1 | apex==0.9.10dev 2 | av==10.0.0 3 | decord==0.6.0 4 | einops==0.6.1 5 | fvcore==0.1.5.post20221221 6 | gradio==3.35.0 7 | imageio==2.27.0 8 | iopath==0.1.10 9 | mmcv==2.0.0 10 | numpy==1.23.5 11 | omegaconf==2.3.0 12 | opencv_python==4.7.0.72 13 | pandas==1.5.3 14 | Pillow==9.5.0 15 | psutil==5.9.4 16 | PyYAML==6.0 17 | scipy==1.10.1 18 | termcolor==2.3.0 19 | timm==0.6.12 20 | tqdm==4.64.1 21 | transformers==4.28.1 22 | sentencepiece==0.1.99 23 | wandb==0.14.0 24 | peft==0.3.0 25 | # flash_attn==1.0.4 26 | -------------------------------------------------------------------------------- /GUI_Vid/configs/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 9, 20 | "encoder_width": 768, 21 | "cross_module": "ca" 22 | } 23 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/run_7b_stage2.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | 10 | NNODE=4 11 | NUM_GPUS=8 12 | MASTER_NODE='SH-IDC1-10-140-1-1' 13 | 14 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} \ 15 | --rdzv_endpoint=${MASTER_NODE}:10068 \ 16 | --rdzv_backend=c10d \ 17 | tasks/train_pt.py \ 18 | $(dirname $0)/config_7b_stage2.py \ 19 | output_dir ${OUTPUT_DIR} 20 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/run_7b_stage3.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 
8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | 10 | NNODE=4 11 | NUM_GPUS=8 12 | MASTER_NODE='SH-IDC1-10-140-1-1' 13 | 14 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} \ 15 | --rdzv_endpoint=${MASTER_NODE}:10068 \ 16 | --rdzv_backend=c10d \ 17 | tasks/train_it.py \ 18 | $(dirname $0)/config_7b_stage3.py \ 19 | output_dir ${OUTPUT_DIR} 20 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/run_7b_stage1.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | 10 | NNODE=4 11 | NUM_GPUS=8 12 | MASTER_NODE='SH-IDC1-10-140-1-1' 13 | 14 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} \ 15 | --rdzv_endpoint=${MASTER_NODE}:10068 \ 16 | --rdzv_backend=c10d \ 17 | tasks/train_qformer.py \ 18 | $(dirname $0)/config_7b_stage1.py \ 19 | output_dir ${OUTPUT_DIR} 20 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/katna.py: -------------------------------------------------------------------------------- 1 | from Katna.video import Video 2 | from Katna.writer import KeyFrameDiskWriter 3 | import os 4 | import ntpath 5 | 6 | # For windows, the below if condition is must. 7 | if __name__ == "__main__": 8 | 9 | #instantiate the video class 10 | vd = Video() 11 | 12 | #number of key-frame images to be extracted 13 | no_of_frames_to_return = 10 14 | 15 | #Input Video directory path 16 | #All .mp4 and .mov files inside this directory will be used for keyframe extraction) 17 | videos_dir_path = "" 18 | 19 | diskwriter = KeyFrameDiskWriter(location= videos_dir_path + "selectedframes") 20 | 21 | vd.extract_keyframes_from_videos_dir( 22 | no_of_frames=no_of_frames_to_return, dir_path=videos_dir_path, 23 | writer=diskwriter 24 | ) -------------------------------------------------------------------------------- /GUI_Vid/configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "model_cls": "VideoChat2_it", 4 | "vit_blip_model_path": "/media/sata2/cdp/Ask-Anything/umt_l16_qformer.pth", 5 | "llama_model_path": "/media/sata2/cdp/vicuna-7b-v0", 6 | "videochat2_model_path": "/media/sata2/cdp/Ask-Anything/videochat2_7b_stage3.pth", 7 | "freeze_vit": false, 8 | "freeze_qformer": false, 9 | "max_txt_len": 512, 10 | "low_resource": false, 11 | "vision_encoder": { 12 | "name": "vit_l14", 13 | "img_size": 224, 14 | "patch_size": 16, 15 | "d_model": 1024, 16 | "encoder_embed_dim": 1024, 17 | "encoder_depth": 24, 18 | "encoder_num_heads": 16, 19 | "drop_path_rate": 0.0, 20 | "num_frames": 32, 21 | "tubelet_size": 1, 22 | "use_checkpoint": false, 23 | "checkpoint_num": 0, 24 | "pretrained": "", 25 | "return_index": -2, 26 | "vit_add_ln": true, 27 | "ckpt_num_frame": 4 28 | }, 29 | "num_query_token": 32, 30 | "qformer_hidden_dropout_prob": 0.1, 31 | "qformer_attention_probs_dropout_prob": 0.1, 32 | "qformer_drop_path_rate": 0.2, 33 | "extra_num_query_token": 64, 34 | "qformer_text_input": true, 35 | "system": "", 36 | "start_token": "", 38 | "img_start_token": "", 39 | "img_end_token": "", 40 | "random_shuffle": true, 41 | "use_lora": false, 42 | "lora_r": 16, 43 | "lora_alpha": 32, 44 | "lora_dropout": 
0.1 45 | }, 46 | "device": "cuda" 47 | } 48 | -------------------------------------------------------------------------------- /GUI_Vid/models/blip2/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | 5 | 6 | from .Qformer import BertConfig, BertLMHeadModel 7 | from models.utils import load_temp_embed_with_mismatch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def build_qformer(num_query_token, vision_width, 13 | qformer_hidden_dropout_prob=0.1, 14 | qformer_attention_probs_dropout_prob=0.1, 15 | drop_path_rate=0., 16 | ): 17 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True) 18 | encoder_config.encoder_width = vision_width 19 | # insert cross-attention layer every other block 20 | encoder_config.add_cross_attention = True 21 | encoder_config.cross_attention_freq = 2 22 | encoder_config.query_length = num_query_token 23 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob 24 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob 25 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)] 26 | logger.info(f"Drop_path:{encoder_config.drop_path_list}") 27 | logger.info(encoder_config) 28 | Qformer = BertLMHeadModel.from_pretrained( 29 | "bert-base-uncased", config=encoder_config, local_files_only=True 30 | ) 31 | query_tokens = nn.Parameter( 32 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 33 | ) 34 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 35 | return Qformer, query_tokens 36 | 37 | def interpolate_pos_embed_blip(state_dict, new_model): 38 | if "vision_temp_embed" in state_dict: 39 | vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"] 40 | state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch( 41 | state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False 42 | ) 43 | return state_dict 44 | -------------------------------------------------------------------------------- /GUI_Vid/configs/data.py: -------------------------------------------------------------------------------- 1 | import os as __os # add "__" if not want to be exported 2 | from copy import deepcopy as __deepcopy 3 | 4 | data_dir = 'your_annotation_path' 5 | if data_dir is None: 6 | raise ValueError("please set environment `VL_DATA_DIR` before continue") 7 | 8 | data_root = __os.path.join(data_dir, "videos_images") 9 | anno_root_pt = __os.path.join(data_dir, "anno_pretrain") 10 | 11 | # ============== pretraining datasets================= 12 | available_corpus = dict( 13 | # pretraining datasets 14 | cc3m=[ 15 | f"{anno_root_pt}/cc3m_train.json", 16 | f"{data_root}/cc3m", 17 | ], 18 | cc12m=[ 19 | f"{anno_root_pt}/cc12m_train.json", 20 | f"{data_root}/cc12m", 21 | ], 22 | sbu=[ 23 | f"{anno_root_pt}/sbu.json", 24 | f"{data_root}/sbu", 25 | ], 26 | vg=[ 27 | f"{anno_root_pt}/vg.json", 28 | f"{data_root}/vg", 29 | ], 30 | coco=[ 31 | f"{anno_root_pt}/coco.json", 32 | f"{data_root}/coco", 33 | ], 34 | webvid=[ 35 | f"{anno_root_pt}/webvid_train.json", 36 | f"{data_root}/webvid", 37 | "video" 38 | ], 39 | webvid_10m=[ 40 | f"{anno_root_pt}/webvid_10m_train.json", 41 | f"{data_root}/webvid_10m", 42 | "video", 43 | ], 44 | internvid_10m=[ 45 | f"{anno_root_pt}/internvid_10m_train.json", 46 | f"{data_root}/internvid_10m", 47 | "video" 48 | ], 49 | ) 50 | 51 | # composed datasets. 
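# Each corpus entry above is a list of [annotation_json, media_root], with an
# optional third element ("video") marking video corpora; image corpora omit it.
# The composed corpora below bundle several such entries into a single list so a
# whole mixture can be selected at once via `train_corpus` in the stage configs
# (e.g. train_corpus = "webvid10m_cc14m" in scripts/config_7b_stage1.py).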
52 | available_corpus["msrvtt_1k_test"] = [ 53 | f"{anno_root_pt}/msrvtt_test1k.json", 54 | f"{data_root}/MSRVTT_Videos", 55 | "video", 56 | ] 57 | 58 | available_corpus["webvid10m_cc14m"] = [ 59 | available_corpus["webvid_10m"], 60 | available_corpus["cc3m"], 61 | available_corpus["cc12m"], 62 | ] 63 | available_corpus["webvid10m_cc14m_plus"] = [ 64 | available_corpus["webvid_10m"], 65 | available_corpus["cc3m"], 66 | available_corpus["coco"], 67 | available_corpus["vg"], 68 | available_corpus["sbu"], 69 | available_corpus["cc12m"], 70 | available_corpus["internvid_10m"], 71 | ] -------------------------------------------------------------------------------- /GUI_Vid/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from os.path import dirname, join 5 | 6 | from utils.config import Config 7 | from utils.distributed import init_distributed_mode, is_main_process 8 | from utils.logger import setup_logger 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_config(): 14 | """Conbine yaml config and command line config with OmegaConf. 15 | Also converts types, e.g., `'None'` (str) --> `None` (None) 16 | """ 17 | config = Config.get_config() 18 | if config.debug: 19 | config.wandb.enable = False 20 | return config 21 | 22 | 23 | def setup_evaluate_config(config): 24 | """setup evaluation default settings, e.g., disable wandb""" 25 | assert config.evaluate 26 | config.wandb.enable = False 27 | if config.output_dir is None: 28 | config.output_dir = join(dirname(config.pretrained_path), "eval") 29 | return config 30 | 31 | 32 | def setup_output_dir(output_dir, excludes=["code"]): 33 | """ensure not overwritting an exisiting/non-empty output dir""" 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir, exist_ok=False) 36 | else: 37 | existing_dirs_files = os.listdir(output_dir) # list 38 | remaining = set(existing_dirs_files) - set(excludes) 39 | remaining = [e for e in remaining if "slurm" not in e] 40 | remaining = [e for e in remaining if ".out" not in e] 41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" 42 | logger.warn(f"remaining dirs or files: {remaining}") 43 | 44 | 45 | def setup_main(): 46 | """ 47 | Setup config, logger, output_dir, etc. 48 | Shared for pretrain and all downstream tasks. 49 | """ 50 | config = setup_config() 51 | if hasattr(config, "evaluate") and config.evaluate: 52 | config = setup_evaluate_config(config) 53 | init_distributed_mode(config) 54 | 55 | if is_main_process(): 56 | setup_output_dir(config.output_dir, excludes=["code"]) 57 | setup_logger(output=config.output_dir, color=True, name="vindlu") 58 | logger.info(f"config: {Config.pretty_text(config)}") 59 | Config.dump(config, os.path.join(config.output_dir, "config.json")) 60 | return config 61 | -------------------------------------------------------------------------------- /GUI_Vid/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process 4 | import random 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class MetaLoader(object): 11 | """ wraps multiple data loader """ 12 | def __init__(self, name2loader): 13 | """Iterates over multiple dataloaders, it ensures all processes 14 | work on data from the same dataloader. 
This loader will end when 15 | the shorter dataloader raises StopIteration exception. 16 | 17 | loaders: Dict, {name: dataloader} 18 | """ 19 | self.name2loader = name2loader 20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()} 21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} 22 | index2name = {v: k for k, v in name2index.items()} 23 | 24 | iter_order = [] 25 | for n, l in name2loader.items(): 26 | iter_order.extend([name2index[n]]*len(l)) 27 | 28 | random.shuffle(iter_order) 29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 30 | 31 | # sync 32 | if is_dist_avail_and_initialized(): 33 | # make sure all processes have the same order so that 34 | # each step they will have data from the same loader 35 | dist.broadcast(iter_order, src=0) 36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] 37 | 38 | logger.info(str(self)) 39 | 40 | def __str__(self): 41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] 42 | for idx, (name, loader) in enumerate(self.name2loader.items()): 43 | output.append( 44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} " 45 | ) 46 | return "\n".join(output) 47 | 48 | def __len__(self): 49 | return len(self.iter_order) 50 | 51 | def __iter__(self): 52 | """ this iterator will run indefinitely """ 53 | for name in self.iter_order: 54 | _iter = self.name2iter[name] 55 | batch = next(_iter) 56 | yield name, batch 57 | -------------------------------------------------------------------------------- /GUI_Vid/models/bert/builder.py: -------------------------------------------------------------------------------- 1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel 2 | 3 | import logging 4 | logger = logging.getLogger(__name__) 5 | 6 | def build_bert(model_config, pretrain, checkpoint): 7 | """build text encoder. 8 | 9 | Args: 10 | model_config (dict): model config. 11 | pretrain (bool): Whether to do pretrain or finetuning. 12 | checkpoint (bool): whether to do gradient_checkpointing. 13 | 14 | Returns: TODO 15 | 16 | """ 17 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 18 | bert_config.encoder_width = model_config.vision_encoder.d_model 19 | bert_config.gradient_checkpointing = checkpoint 20 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer 21 | 22 | if not model_config.multimodal.enable: 23 | bert_config.fusion_layer = bert_config.num_hidden_layers 24 | 25 | if pretrain: 26 | text_encoder, loading_info = BertForMaskedLM.from_pretrained( 27 | model_config.text_encoder.pretrained, 28 | config=bert_config, 29 | output_loading_info=True, 30 | ) 31 | else: 32 | text_encoder, loading_info = BertModel.from_pretrained( 33 | model_config.text_encoder.pretrained, 34 | config=bert_config, 35 | add_pooling_layer=False, 36 | output_loading_info=True, 37 | ) 38 | 39 | return text_encoder 40 | 41 | 42 | def build_bert_decoder(model_config, checkpoint): 43 | """build text decoder the same as the multimodal encoder. 44 | 45 | Args: 46 | model_config (dict): model config. 47 | pretrain (bool): Whether to do pretrain or finetuning. 48 | checkpoint (bool): whether to do gradient_checkpointing. 
49 | 50 | Returns: TODO 51 | 52 | """ 53 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 54 | bert_config.encoder_width = model_config.vision_encoder.d_model 55 | bert_config.gradient_checkpointing = checkpoint 56 | 57 | bert_config.fusion_layer = 0 58 | bert_config.num_hidden_layers = ( 59 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer 60 | ) 61 | 62 | text_decoder, loading_info = BertLMHeadModel.from_pretrained( 63 | model_config.text_encoder.pretrained, 64 | config=bert_config, 65 | output_loading_info=True, 66 | ) 67 | 68 | return text_decoder 69 | -------------------------------------------------------------------------------- /GUI_Vid/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from torch.optim import Optimizer 5 | import math 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def create_scheduler(args, optimizer): 10 | lr_scheduler = None 11 | if args.sched == 'cosine': 12 | lr_scheduler = get_cosine_schedule_with_warmup( 13 | optimizer, 14 | num_warmup_steps=args.num_warmup_steps, 15 | num_training_steps=args.num_training_steps, 16 | num_cycles=0.5, 17 | min_lr_multi=args.min_lr_multi 18 | ) 19 | return lr_scheduler 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 25 | ): 26 | """ 27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py 28 | 29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 31 | initial lr set in the optimizer. 32 | Args: 33 | optimizer ([`~torch.optim.Optimizer`]): 34 | The optimizer for which to schedule the learning rate. 35 | num_warmup_steps (`int`): 36 | The number of steps for the warmup phase. 37 | num_training_steps (`int`): 38 | The total number of training steps. 39 | num_cycles (`float`, *optional*, defaults to 0.5): 40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 41 | following a half-cosine). 42 | min_lr_multi (`float`, *optional*, defaults to 0): 43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. 44 | last_epoch (`int`, *optional*, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | Return: 47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
48 | """ 49 | 50 | def lr_lambda(current_step): 51 | if current_step < num_warmup_steps: 52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) 53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 55 | 56 | return LambdaLR(optimizer, lr_lambda, last_epoch) 57 | -------------------------------------------------------------------------------- /GUI_Vid/configs/instruction_data.py: -------------------------------------------------------------------------------- 1 | import os as __os # add "__" if not want to be exported 2 | from copy import deepcopy as __deepcopy 3 | 4 | anno_root_it = "" 5 | 6 | # ============== pretraining datasets================= 7 | available_corpus = dict( 8 | # Images 9 | mobile_gui_image_concise_caption=[ 10 | f"/media/sata4/final_videos/meta_gui_image/annotations/train_concise_caption_data.json", 11 | "/media/sata4/final_videos/", 12 | ], 13 | mobile_gui_image_detailed_caption=[ 14 | f"/media/sata4/final_videos/meta_gui_image/annotations/train_detailed_caption_data.json", 15 | "/media/sata4/final_videos/", 16 | ], 17 | omniact_image_concise_caption=[ 18 | "/media/sata4/final_videos/omniact/concise_caption_data.json", 19 | "/media/sata4/", 20 | ], 21 | omniact_image_detailed_caption=[ 22 | "/media/sata4/final_videos/omniact/detailed_caption_data.json", 23 | "/media/sata4/", 24 | ], 25 | gui_image_concise_caption=[ 26 | f"{anno_root_it}/v3_new_concise_caption_data.json", 27 | "/media/sata4/", 28 | ], 29 | gui_image_detailed_caption=[ 30 | f"{anno_root_it}/v3_new_detailed_caption_data.json", 31 | "/media/sata4/", 32 | ], 33 | gui_video_caption=[ 34 | f"{anno_root_it}/v3_new_caption_data.json", 35 | "/media/sata4/", 36 | "video" 37 | ], 38 | # Videos 39 | gui_video_short_caption=[ 40 | f"{anno_root_it}/v3_new_short_caption_data.json", 41 | "/media/sata4/", 42 | "video" 43 | ], 44 | gui_video_reasoning=[ 45 | f"{anno_root_it}/v3_new_reasoning_data.json", 46 | "/media/sata4/", 47 | "video" 48 | ], 49 | gui_video_vqa=[ 50 | f"{anno_root_it}/v3_new_vqa_data.json", 51 | "/media/sata4/", 52 | "video" 53 | ], 54 | gui_video_conversation=[ 55 | f"{anno_root_it}/v3_new_conversation_data.json", 56 | "/media/sata4/", 57 | "video" 58 | ] 59 | ) 60 | 61 | 62 | # add mc for clevrer_qa 63 | available_corpus["videochat2_instruction"] = [ 64 | # Images 65 | # available_corpus['mobile_gui_image_concise_caption'], 66 | # available_corpus['mobile_gui_image_detailed_caption'], 67 | # available_corpus['omniact_image_concise_caption'], 68 | # available_corpus['omniact_image_detailed_caption'], 69 | # available_corpus['gui_image_detailed_caption'], 70 | # available_corpus['gui_image_concise_caption'], 71 | 72 | # Videos 73 | available_corpus['gui_video_caption'], 74 | available_corpus['gui_video_short_caption'], 75 | available_corpus['gui_video_reasoning'], 76 | available_corpus['gui_video_vqa'], 77 | available_corpus['gui_video_conversation'] 78 | ] 79 | -------------------------------------------------------------------------------- /GUI_Vid/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | 6 | from dataset.utils import load_image_from_path 7 | 8 | try: 9 | from petrel_client.client import Client 10 | has_client = True 11 | except ImportError: 12 | has_client = 
False 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ImageVideoBaseDataset(Dataset): 18 | """Base class that implements the image and video loading methods""" 19 | 20 | media_type = "video" 21 | 22 | def __init__(self): 23 | assert self.media_type in ["image", "video", "only_video"] 24 | self.data_root = None 25 | self.anno_list = ( 26 | None # list(dict), each dict contains {"image": str, # image or video path} 27 | ) 28 | self.transform = None 29 | self.video_reader = None 30 | self.num_tries = None 31 | 32 | self.client = None 33 | if has_client: 34 | self.client = Client('~/petreloss.conf') 35 | 36 | def __getitem__(self, index): 37 | raise NotImplementedError 38 | 39 | def __len__(self): 40 | raise NotImplementedError 41 | 42 | def get_anno(self, index): 43 | """obtain the annotation for one media (video or image) 44 | 45 | Args: 46 | index (int): The media index. 47 | 48 | Returns: dict. 49 | - "image": the filename, video also use "image". 50 | - "caption": The caption for this file. 51 | 52 | """ 53 | anno = self.anno_list[index] 54 | if self.data_root is not None: 55 | anno["image"] = os.path.join(self.data_root, anno["image"]) 56 | return anno 57 | 58 | def load_and_transform_media_data(self, index, data_path): 59 | if self.media_type == "image": 60 | return self.load_and_transform_media_data_image(index, data_path) 61 | else: 62 | return self.load_and_transform_media_data_video(index, data_path) 63 | 64 | def load_and_transform_media_data_image(self, index, data_path): 65 | image = load_image_from_path(data_path, client=self.client) 66 | image = self.transform(image) 67 | return image, index 68 | 69 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None): 70 | for _ in range(self.num_tries): 71 | try: 72 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1 73 | frames, frame_indices, fps = self.video_reader( 74 | data_path, self.num_frames, self.sample_type, 75 | max_num_frames=max_num_frames, client=self.client, clip=clip 76 | ) 77 | except Exception as e: 78 | logger.warning( 79 | f"Caught exception {e} when loading video {data_path}, " 80 | f"randomly sample a new video as replacement" 81 | ) 82 | index = random.randint(0, len(self) - 1) 83 | ann = self.get_anno(index) 84 | data_path = ann["image"] 85 | continue 86 | # shared aug for video frames 87 | frames = self.transform(frames) 88 | if return_fps: 89 | sec = [str(round(f / fps, 1)) for f in frame_indices] 90 | return frames, index, sec 91 | else: 92 | return frames, index 93 | else: 94 | raise RuntimeError( 95 | f"Failed to fetch video after {self.num_tries} tries. " 96 | f"This might indicate that you have many corrupted videos." 
97 | ) 98 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/config_7b_stage1.py: -------------------------------------------------------------------------------- 1 | from configs.data import * 2 | from configs.model import * 3 | 4 | # ========================= data ========================== 5 | train_corpus = "webvid10m_cc14m" 6 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 7 | test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"]) 8 | test_types = ["msrvtt_1k_test"] 9 | 10 | num_workers = 6 11 | 12 | stop_key = None 13 | 14 | # ========================= input ========================== 15 | num_frames = 4 16 | num_frames_test = 4 17 | batch_size = 128 18 | max_txt_l = 32 19 | 20 | pre_text = False 21 | 22 | inputs = dict( 23 | image_res=224, 24 | video_input=dict( 25 | num_frames="${num_frames}", 26 | sample_type="rand", 27 | num_frames_test="${num_frames_test}", 28 | sample_type_test="middle", 29 | random_aug=False, 30 | ), 31 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 32 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 33 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 34 | ) 35 | 36 | # ========================= model ========================== 37 | text_enc = "bert" 38 | model = dict( 39 | model_cls="VideoChat2_qformer", 40 | vision_encoder=dict( 41 | name="vit_l14", 42 | img_size=224, 43 | patch_size=16, 44 | d_model=1024, 45 | encoder_embed_dim=1024, 46 | encoder_depth=24, 47 | encoder_num_heads=16, 48 | drop_path_rate=0., 49 | num_frames="${num_frames}", 50 | tubelet_size=1, 51 | use_checkpoint=False, 52 | checkpoint_num=12, 53 | pretrained="your_model_path/l16_25m.pth", 54 | return_index=-2, 55 | ), 56 | text_encoder="${TextEncoders[${text_enc}]}", 57 | vit_add_ln=True, 58 | embed_dim=768, 59 | temp=0.07, 60 | qformer_num_query_tokens=32, 61 | agg_method="mean", 62 | drop_path_rate=0.2, 63 | ) 64 | 65 | criterion = dict( 66 | loss_weight=dict(vtc=1.0, mlm=0.0, vtm=1.0, mvm=0.0, cap=1.0), # 0: disabled. 67 | vtm_hard_neg=True, 68 | vtm_cat_text_cls=True 69 | ) 70 | 71 | optimizer = dict( 72 | opt="adamW", 73 | lr=1e-4, 74 | opt_betas=[0.9, 0.999], # default 75 | weight_decay=0.02, 76 | max_grad_norm=-1, # requires a positive float, use -1 to disable 77 | # use a different lr for some modules, e.g., larger lr for new modules 78 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 79 | ) 80 | 81 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.01, warmup_epochs=0.2) 82 | 83 | evaluate = False 84 | deep_fusion = False 85 | evaluation = dict( 86 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 87 | eval_x_only=False, 88 | k_test=128, 89 | eval_offload=True, # offload gpu tensors to cpu to save memory. 
90 | ) 91 | 92 | fp16 = True 93 | gradient_checkpointing = True 94 | 95 | # ========================= wandb ========================== 96 | wandb = dict( 97 | enable=False, 98 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 99 | project="videochat2", # setup in your command line 100 | ) 101 | dist_url = "env://" 102 | device = "cuda" 103 | mode = "pt" 104 | 105 | # ========================= others ========================== 106 | output_dir = None # output dir 107 | resume = False # if True, load optimizer and scheduler states as well 108 | debug = False 109 | log_freq = 100 110 | seed = 42 111 | 112 | save_latest = True 113 | auto_resume = True 114 | pretrained_path = "" # path to pretrained model weights, for resume only? 115 | -------------------------------------------------------------------------------- /GUI_Vid/utils/easydict.py: -------------------------------------------------------------------------------- 1 | class EasyDict(dict): 2 | """ 3 | Get attributes 4 | 5 | >>> d = EasyDict({'foo':3}) 6 | >>> d['foo'] 7 | 3 8 | >>> d.foo 9 | 3 10 | >>> d.bar 11 | Traceback (most recent call last): 12 | ... 13 | AttributeError: 'EasyDict' object has no attribute 'bar' 14 | 15 | Works recursively 16 | 17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 18 | >>> isinstance(d.bar, dict) 19 | True 20 | >>> d.bar.x 21 | 1 22 | 23 | Bullet-proof 24 | 25 | >>> EasyDict({}) 26 | {} 27 | >>> EasyDict(d={}) 28 | {} 29 | >>> EasyDict(None) 30 | {} 31 | >>> d = {'a': 1} 32 | >>> EasyDict(**d) 33 | {'a': 1} 34 | 35 | Set attributes 36 | 37 | >>> d = EasyDict() 38 | >>> d.foo = 3 39 | >>> d.foo 40 | 3 41 | >>> d.bar = {'prop': 'value'} 42 | >>> d.bar.prop 43 | 'value' 44 | >>> d 45 | {'foo': 3, 'bar': {'prop': 'value'}} 46 | >>> d.bar.prop = 'newer' 47 | >>> d.bar.prop 48 | 'newer' 49 | 50 | 51 | Values extraction 52 | 53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 54 | >>> isinstance(d.bar, list) 55 | True 56 | >>> from operator import attrgetter 57 | >>> map(attrgetter('x'), d.bar) 58 | [1, 3] 59 | >>> map(attrgetter('y'), d.bar) 60 | [2, 4] 61 | >>> d = EasyDict() 62 | >>> d.keys() 63 | [] 64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 65 | >>> d.foo 66 | 3 67 | >>> d.bar.x 68 | 1 69 | 70 | Still like a dict though 71 | 72 | >>> o = EasyDict({'clean':True}) 73 | >>> o.items() 74 | [('clean', True)] 75 | 76 | And like a class 77 | 78 | >>> class Flower(EasyDict): 79 | ... power = 1 80 | ... 81 | >>> f = Flower() 82 | >>> f.power 83 | 1 84 | >>> f = Flower({'height': 12}) 85 | >>> f.height 86 | 12 87 | >>> f['power'] 88 | 1 89 | >>> sorted(f.keys()) 90 | ['height', 'power'] 91 | 92 | update and pop items 93 | >>> d = EasyDict(a=1, b='2') 94 | >>> e = EasyDict(c=3.0, a=9.0) 95 | >>> d.update(e) 96 | >>> d.c 97 | 3.0 98 | >>> d['c'] 99 | 3.0 100 | >>> d.get('c') 101 | 3.0 102 | >>> d.update(a=4, b=4) 103 | >>> d.b 104 | 4 105 | >>> d.pop('a') 106 | 4 107 | >>> d.a 108 | Traceback (most recent call last): 109 | ... 
110 | AttributeError: 'EasyDict' object has no attribute 'a' 111 | """ 112 | 113 | def __init__(self, d=None, **kwargs): 114 | if d is None: 115 | d = {} 116 | if kwargs: 117 | d.update(**kwargs) 118 | for k, v in d.items(): 119 | setattr(self, k, v) 120 | # Class attributes 121 | for k in self.__class__.__dict__.keys(): 122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): 123 | setattr(self, k, getattr(self, k)) 124 | 125 | def __setattr__(self, name, value): 126 | if isinstance(value, (list, tuple)): 127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 128 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 129 | value = self.__class__(value) 130 | super(EasyDict, self).__setattr__(name, value) 131 | super(EasyDict, self).__setitem__(name, value) 132 | 133 | __setitem__ = __setattr__ 134 | 135 | def update(self, e=None, **f): 136 | d = e or dict() 137 | d.update(f) 138 | for k in d: 139 | setattr(self, k, d[k]) 140 | 141 | def pop(self, k, d=None): 142 | if hasattr(self, k): 143 | delattr(self, k) 144 | return super(EasyDict, self).pop(k, d) 145 | 146 | 147 | if __name__ == "__main__": 148 | import doctest 149 | 150 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/config_7b_stage2.py: -------------------------------------------------------------------------------- 1 | from configs.data import * 2 | 3 | # ========================= data ========================== 4 | train_corpus = "webvid10m_cc14m_plus" 5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 6 | test_file = dict() 7 | test_types = [] 8 | num_workers = 6 9 | 10 | stop_key = None 11 | 12 | # ========================= input ========================== 13 | num_frames = 8 14 | num_frames_test = 8 15 | batch_size = 4 16 | max_txt_l = 512 17 | 18 | pre_text = False 19 | 20 | inputs = dict( 21 | image_res=224, 22 | video_input=dict( 23 | num_frames="${num_frames}", 24 | sample_type="rand", 25 | num_frames_test="${num_frames_test}", 26 | sample_type_test="middle", 27 | random_aug=False, 28 | ), 29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 30 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 32 | ) 33 | 34 | # ========================= model ========================== 35 | model = dict( 36 | model_cls="VideoChat2_pt", 37 | vit_blip_model_path="your_model_path/umt_l16_qformer.pth", 38 | llama_model_path="your_model_path/vicuna-7b-v0", 39 | freeze_vit=False, 40 | freeze_qformer=False, 41 | max_txt_len="${max_txt_l}", 42 | # vit 43 | low_resource=False, 44 | vision_encoder=dict( 45 | name="vit_l14", 46 | img_size=224, 47 | patch_size=16, 48 | d_model=1024, 49 | encoder_embed_dim=1024, 50 | encoder_depth=24, 51 | encoder_num_heads=16, 52 | drop_path_rate=0., 53 | num_frames="${num_frames}", 54 | tubelet_size=1, 55 | use_checkpoint=False, 56 | checkpoint_num=0, 57 | pretrained="", 58 | return_index=-2, 59 | vit_add_ln=True, 60 | ), 61 | # prompt 62 | prompt_path="prompts/concise_description.txt", 63 | img_prompt_path="prompts/concise_image_description.txt", 64 | prompt_template="###Human: {} ###Assistant: ", 65 | end_sym="###", 66 | # qformer 67 | num_query_token=32, 68 | qformer_hidden_dropout_prob=0.1, 69 | qformer_attention_probs_dropout_prob=0.1, 70 | qformer_drop_path_rate=0.2, 71 | extra_num_query_token=64, 72 | # debug=True, 73 | ) 74 | 75 | optimizer = dict( 76 | 
opt="adamW", 77 | lr=1e-4, 78 | opt_betas=[0.9, 0.999], # default 79 | weight_decay=0.02, 80 | max_grad_norm=-1, # requires a positive float, use -1 to disable 81 | # use a different lr for some modules, e.g., larger lr for new modules 82 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 83 | ) 84 | 85 | scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2) 86 | 87 | evaluate = False 88 | deep_fusion = False 89 | evaluation = dict( 90 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 91 | eval_x_only=False, 92 | k_test=128, 93 | eval_offload=True, # offload gpu tensors to cpu to save memory. 94 | ) 95 | 96 | fp16 = True 97 | gradient_checkpointing = True 98 | 99 | # ========================= wandb ========================== 100 | wandb = dict( 101 | enable=False, 102 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 103 | project="videochat2", # setup in your command line 104 | ) 105 | dist_url = "env://" 106 | device = "cuda" 107 | mode = "pt" 108 | 109 | # ========================= others ========================== 110 | output_dir = None # output dir 111 | resume = False # if True, load optimizer and scheduler states as well 112 | debug = False 113 | log_freq = 100 114 | seed = 42 115 | 116 | save_latest = True 117 | auto_resume = True 118 | pretrained_path = "" # path to pretrained model weights, for resume only? 119 | -------------------------------------------------------------------------------- /GUI_Vid/models/blip2/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import contextlib 8 | import os 9 | import logging 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from .Qformer import BertConfig, BertLMHeadModel 15 | from .vit import build_vit 16 | from transformers import BertTokenizer 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Blip2Base(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @classmethod 26 | def init_tokenizer(cls, truncation_side="right"): 27 | # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side, local_files_only=True) 28 | tokenizer = BertTokenizer.from_pretrained("/media/sata1/cdp/transformer/bert-base-uncased", truncation_side=truncation_side, local_files_only=True) 29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 30 | return tokenizer 31 | 32 | @property 33 | def device(self): 34 | return list(self.parameters())[0].device 35 | 36 | def maybe_autocast(self, dtype=torch.float16): 37 | # if on cpu, don't use autocast 38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 39 | enable_autocast = self.device != torch.device("cpu") 40 | 41 | if enable_autocast: 42 | return torch.cuda.amp.autocast(dtype=dtype) 43 | else: 44 | return contextlib.nullcontext() 45 | 46 | @classmethod 47 | def init_Qformer( 48 | cls, 49 | num_query_token, vision_width, 50 | qformer_hidden_dropout_prob=0.1, 51 | qformer_attention_probs_dropout_prob=0.1, 52 | qformer_drop_path_rate=0., 53 | ): 54 | # encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True) 55 | encoder_config = 
BertConfig.from_pretrained("/media/sata1/cdp/transformer/bert-base-uncased", local_files_only=True) 56 | encoder_config.encoder_width = vision_width 57 | # insert cross-attention layer every other block 58 | encoder_config.add_cross_attention = True 59 | encoder_config.cross_attention_freq = 2 60 | encoder_config.query_length = num_query_token 61 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob 62 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob 63 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)] 64 | logger.info(f"Drop_path:{encoder_config.drop_path_list}") 65 | logger.info(encoder_config) 66 | Qformer = BertLMHeadModel(config=encoder_config) 67 | query_tokens = nn.Parameter( 68 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 69 | ) 70 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 71 | return Qformer, query_tokens 72 | 73 | @classmethod 74 | def init_vision_encoder_umt(self, config): 75 | """build vision encoder 76 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`. 77 | 78 | """ 79 | vision_encoder = build_vit(config) 80 | 81 | if config.vision_encoder.vit_add_ln: 82 | vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12) 83 | else: 84 | vision_layernorm = nn.Identity() 85 | 86 | return vision_encoder, vision_layernorm 87 | 88 | 89 | def disabled_train(self, mode=True): 90 | """Overwrite model.train with this function to make sure train/eval mode 91 | does not change anymore.""" 92 | return self 93 | 94 | 95 | class LayerNorm(nn.LayerNorm): 96 | """Subclass torch's LayerNorm to handle fp16.""" 97 | 98 | def forward(self, x: torch.Tensor): 99 | orig_type = x.dtype 100 | ret = super().forward(x.type(torch.float32)) 101 | return ret.type(orig_type) 102 | -------------------------------------------------------------------------------- /GUI_Vid/scripts/config_7b_stage3.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | train_corpus = "videochat2_instruction" 5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 6 | test_file = dict() 7 | test_types = [] 8 | num_workers = 6 9 | 10 | stop_key = None 11 | 12 | # ========================= input ========================== 13 | num_frames = 8 14 | num_frames_test = 8 15 | batch_size = 4 16 | max_txt_l = 512 17 | 18 | pre_text = False 19 | 20 | inputs = dict( 21 | image_res=224, 22 | video_input=dict( 23 | num_frames="${num_frames}", 24 | sample_type="rand", 25 | num_frames_test="${num_frames_test}", 26 | sample_type_test="middle", 27 | random_aug=False, 28 | ), 29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 30 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 32 | ) 33 | 34 | # ========================= model ========================== 35 | model = dict( 36 | model_cls="VideoChat2_it", 37 | vit_blip_model_path="your_model_path/umt_l16_qformer.pth", 38 | llama_model_path="your_model_path/vicuna-7b-v0", 39 | videochat2_model_path="your_model_path/videochat2_7b_stage2.pth", 40 | freeze_vit=False, 41 | freeze_qformer=False, 42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 43 | # vit 44 | low_resource=False, 45 | add_temp_embed=False, 46 | 
vision_encoder=dict( 47 | name="vit_l14", 48 | img_size=224, 49 | patch_size=16, 50 | d_model=1024, 51 | encoder_embed_dim=1024, 52 | encoder_depth=24, 53 | encoder_num_heads=16, 54 | drop_path_rate=0., 55 | num_frames="${num_frames}", 56 | tubelet_size=1, 57 | use_checkpoint=False, 58 | checkpoint_num=0, 59 | pretrained="", 60 | return_index=-2, 61 | vit_add_ln=True, 62 | ckpt_num_frame=4, 63 | ), 64 | # qformer 65 | num_query_token=32, 66 | qformer_hidden_dropout_prob=0.1, 67 | qformer_attention_probs_dropout_prob=0.1, 68 | qformer_drop_path_rate=0.2, 69 | extra_num_query_token=64, 70 | qformer_text_input=True, 71 | # prompt 72 | system="", 73 | start_token="", 75 | add_second_msg=True, 76 | img_start_token="", 77 | img_end_token="", 78 | random_shuffle=True, 79 | use_flash_attention=True, 80 | use_lora=True, 81 | lora_r=16, 82 | lora_alpha=32, 83 | lora_dropout=0.1, 84 | # debug=True, 85 | ) 86 | 87 | optimizer = dict( 88 | opt="adamW", 89 | lr=2e-5, 90 | opt_betas=[0.9, 0.999], # default 91 | weight_decay=0.02, 92 | max_grad_norm=-1, # requires a positive float, use -1 to disable 93 | # use a different lr for some modules, e.g., larger lr for new modules 94 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 95 | ) 96 | 97 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 98 | 99 | evaluate = False 100 | deep_fusion = False 101 | evaluation = dict( 102 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 103 | eval_x_only=False, 104 | k_test=128, 105 | eval_offload=True, # offload gpu tensors to cpu to save memory. 106 | ) 107 | 108 | fp16 = True 109 | gradient_checkpointing = True 110 | 111 | # ========================= wandb ========================== 112 | wandb = dict( 113 | enable=False, 114 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 115 | project="videochat2", # setup in your command line 116 | ) 117 | dist_url = "env://" 118 | device = "cuda" 119 | mode = "it" 120 | 121 | # ========================= others ========================== 122 | output_dir = None # output dir 123 | resume = False # if True, load optimizer and scheduler states as well 124 | debug = False 125 | log_freq = 100 126 | seed = 42 127 | 128 | save_latest = True 129 | auto_resume = True 130 | pretrained_path = "" # path to pretrained model weights, for resume only? 
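# This config is consumed by scripts/run_7b_stage3.sh, which passes it to
# tasks/train_it.py and overrides `output_dir` on the command line. A minimal
# single-node launch sketch (assuming one machine with 8 GPUs; the multi-node
# rendezvous flags used in the script are dropped):
#   torchrun --nnodes=1 --nproc_per_node=8 tasks/train_it.py \
#       scripts/config_7b_stage3.py output_dir ${OUTPUT_DIR}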
131 | -------------------------------------------------------------------------------- /GUI_Vid/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import logging 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def setup_for_distributed(is_master): 11 | import warnings 12 | 13 | builtin_warn = warnings.warn 14 | 15 | def warn(*args, **kwargs): 16 | force = kwargs.pop("force", False) 17 | if is_master or force: 18 | builtin_warn(*args, **kwargs) 19 | 20 | # Log warnings only once 21 | warnings.warn = warn 22 | warnings.simplefilter("once", UserWarning) 23 | 24 | if not is_master: 25 | logging.disable() 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | 36 | def get_world_size(): 37 | if not is_dist_avail_and_initialized(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def get_rank(): 43 | if not is_dist_avail_and_initialized(): 44 | return 0 45 | return dist.get_rank() 46 | 47 | 48 | def is_main_process(): 49 | return get_rank() == 0 50 | 51 | 52 | def save_on_master(*args, **kwargs): 53 | if is_main_process(): 54 | torch.save(*args, **kwargs) 55 | 56 | 57 | def is_port_in_use(port): 58 | import socket 59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 60 | return s.connect_ex(('localhost', port)) == 0 61 | 62 | 63 | def init_distributed_mode(args): 64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 65 | # job started by torch.distributed.launch 66 | args.rank = int(os.environ["RANK"]) 67 | args.world_size = int(os.environ['WORLD_SIZE']) 68 | args.gpu = int(os.environ['LOCAL_RANK']) 69 | elif 'SLURM_PROCID' in os.environ: 70 | # local rank on the current node / global rank 71 | local_rank = int(os.environ['SLURM_LOCALID']) 72 | global_rank = int(os.environ['SLURM_PROCID']) 73 | # number of processes / GPUs per node 74 | world_size = int(os.environ["SLURM_NNODES"]) * \ 75 | int(os.environ["SLURM_TASKS_PER_NODE"][0]) 76 | 77 | print(world_size) 78 | 79 | args.rank = global_rank 80 | args.gpu = local_rank 81 | args.world_size = world_size 82 | else: 83 | logger.info('Not using distributed mode') 84 | args.distributed = False 85 | return 86 | 87 | args.distributed = True 88 | 89 | torch.cuda.set_device(args.gpu) 90 | args.dist_backend = 'nccl' 91 | 92 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node 93 | dist_port = int(args.dist_url.split(":")[-1]) 94 | while is_port_in_use(dist_port): 95 | dist_port += 10 96 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)]) 97 | 98 | logger.info('| distributed init (rank {}): {}'.format( 99 | args.rank, args.dist_url)) 100 | if "SLURM_JOB_ID" in os.environ: 101 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}") 102 | torch.distributed.init_process_group( 103 | backend=args.dist_backend, init_method=args.dist_url, 104 | world_size=args.world_size, rank=args.rank) 105 | torch.distributed.barrier() 106 | setup_for_distributed(args.rank == 0) 107 | 108 | 109 | # Copyright (c) Facebook, Inc. and its affiliates. 
110 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py 111 | class GatherLayer(torch.autograd.Function): 112 | """ 113 | Gather tensors from all workers with support for backward propagation: 114 | This implementation does not cut the gradients as torch.distributed.all_gather does. 115 | """ 116 | 117 | @staticmethod 118 | def forward(ctx, x): 119 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 120 | dist.all_gather(output, x) 121 | return tuple(output) 122 | 123 | @staticmethod 124 | def backward(ctx, *grads): 125 | all_gradients = torch.stack(grads) 126 | dist.all_reduce(all_gradients) 127 | return all_gradients[dist.get_rank()] 128 | 129 | 130 | # copied from megavlt 131 | def gather_tensor_along_batch_with_backward(tensor, dim=0): 132 | world_size = get_world_size() 133 | 134 | if world_size < 2: 135 | return tensor 136 | 137 | tensor_list = GatherLayer.apply(tensor) 138 | tensor_list = torch.cat(tensor_list, dim=dim) 139 | return tensor_list 140 | 141 | 142 | @torch.no_grad() 143 | def gather_tensor_along_batch(tensor, dim=0): 144 | """ 145 | Performs all_gather operation on the provided tensors. 146 | *** Warning ***: torch.distributed.all_gather has no gradient. 147 | """ 148 | world_size = get_world_size() 149 | 150 | if world_size < 2: 151 | return tensor 152 | 153 | with torch.no_grad(): 154 | tensor_list = [] 155 | 156 | for _ in range(world_size): 157 | tensor_list.append(torch.zeros_like(tensor)) 158 | 159 | dist.all_gather(tensor_list, tensor) 160 | tensor_list = torch.cat(tensor_list, dim=dim) 161 | return tensor_list 162 | -------------------------------------------------------------------------------- /GUI_Vid/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import re 5 | import torch 6 | from torch import optim as optim 7 | from utils.distributed import is_main_process 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | try: 11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 12 | has_apex = True 13 | except ImportError: 14 | has_apex = False 15 | 16 | 17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True): 18 | named_param_tuples = [] 19 | for name, param in model.named_parameters(): 20 | if not param.requires_grad: 21 | continue # frozen weights 22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")): 23 | named_param_tuples.append([name, param, 0]) 24 | elif name in no_decay_list: 25 | named_param_tuples.append([name, param, 0]) 26 | else: 27 | named_param_tuples.append([name, param, weight_decay]) 28 | return named_param_tuples 29 | 30 | 31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr): 32 | """use lr=diff_lr for modules named found in diff_lr_names, 33 | otherwise use lr=default_lr 34 | 35 | Args: 36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module 37 | diff_lr_names: List(str) 38 | diff_lr: float 39 | default_lr: float 40 | Returns: 41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr]) 42 | """ 43 | named_param_tuples_with_lr = [] 44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}") 45 | for name, p, wd in named_param_tuples_or_model: 46 | use_diff_lr = False 47 | for diff_name in diff_lr_names: 48 | # if diff_name in name: 
49 | if re.search(diff_name, name) is not None: 50 | logger.info(f"param {name} use different_lr: {diff_lr}") 51 | use_diff_lr = True 52 | break 53 | 54 | named_param_tuples_with_lr.append( 55 | [name, p, wd, diff_lr if use_diff_lr else default_lr] 56 | ) 57 | 58 | if is_main_process(): 59 | for name, _, wd, diff_lr in named_param_tuples_with_lr: 60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}") 61 | 62 | return named_param_tuples_with_lr 63 | 64 | 65 | def create_optimizer_params_group(named_param_tuples_with_lr): 66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])""" 67 | group = {} 68 | for name, p, wd, lr in named_param_tuples_with_lr: 69 | if wd not in group: 70 | group[wd] = {} 71 | if lr not in group[wd]: 72 | group[wd][lr] = [] 73 | group[wd][lr].append(p) 74 | 75 | optimizer_params_group = [] 76 | for wd, lr_groups in group.items(): 77 | for lr, p in lr_groups.items(): 78 | optimizer_params_group.append(dict( 79 | params=p, 80 | weight_decay=wd, 81 | lr=lr 82 | )) 83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}") 84 | return optimizer_params_group 85 | 86 | 87 | def create_optimizer(args, model, filter_bias_and_bn=True): 88 | opt_lower = args.opt.lower() 89 | weight_decay = args.weight_decay 90 | # check for modules that requires different lr 91 | if hasattr(args, "different_lr") and args.different_lr.enable: 92 | diff_lr_module_names = args.different_lr.module_names 93 | diff_lr = args.different_lr.lr 94 | else: 95 | diff_lr_module_names = [] 96 | diff_lr = None 97 | 98 | no_decay = {} 99 | if hasattr(model, 'no_weight_decay'): 100 | no_decay = model.no_weight_decay() 101 | named_param_tuples = add_weight_decay( 102 | model, weight_decay, no_decay, filter_bias_and_bn) 103 | named_param_tuples = add_different_lr( 104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr) 105 | parameters = create_optimizer_params_group(named_param_tuples) 106 | 107 | if 'fused' in opt_lower: 108 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 109 | 110 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 111 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 112 | opt_args['eps'] = args.opt_eps 113 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 114 | opt_args['betas'] = args.opt_betas 115 | if hasattr(args, 'opt_args') and args.opt_args is not None: 116 | opt_args.update(args.opt_args) 117 | 118 | opt_split = opt_lower.split('_') 119 | opt_lower = opt_split[-1] 120 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 121 | opt_args.pop('eps', None) 122 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 123 | elif opt_lower == 'momentum': 124 | opt_args.pop('eps', None) 125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 126 | elif opt_lower == 'adam': 127 | optimizer = optim.Adam(parameters, **opt_args) 128 | elif opt_lower == 'adamw': 129 | optimizer = optim.AdamW(parameters, **opt_args) 130 | else: 131 | assert False and "Invalid optimizer" 132 | raise ValueError 133 | return optimizer 134 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 |
2 |

GUI-World: A Dataset for GUI-Oriented Multimodal Large Language Models 3 | 4 | [![Paper](https://img.shields.io/badge/Paper-%F0%9F%8E%93-lightgrey?style=flat-square)](https://arxiv.org/abs/2406.10819) [![Dataset](https://img.shields.io/badge/Dataset-%F0%9F%92%BE-green?style=flat-square)](https://huggingface.co/datasets/shuaishuaicdp/GUI-World) [![Website](https://img.shields.io/badge/Website-%F0%9F%90%BE-green?style=flat-square)](https://gui-world.github.io/) 5 | 6 | 7 | 8 | 9 | 10 |

11 | 12 |

13 |

14 | 15 | ## Updates & News 16 | 17 | **Our benchmark code has been released.** 18 | - [18/10/2024] We released the benchmark code. 19 | - [16/06/2024] 📄 Paper released on [arXiv](https://arxiv.org/abs/2406.10819)! 20 | 21 | ## Contents 22 | 23 | - [Updates \& News](#updates--news) 24 | - [Contents](#contents) 25 | - [Dataset: GUI-World](#dataset-gui-world) 26 | - [Overview](#overview) 27 | - [How to use GUI-World](#how-to-use-gui-world) 28 | - [GUI-Vid: A GUI-Oriented VideoLLM](#gui-vid-a-gui-oriented-videollm) 29 | - [Contribution](#contribution) 30 | - [Acknowledgments](#acknowledgments) 31 | - [Citation](#citation) 32 | 33 | ## Dataset: GUI-World 34 | 35 | ### Overview 36 | 37 | GUI-World introduces a comprehensive benchmark for evaluating MLLMs in dynamic and complex GUI environments. It features extensive annotations covering six GUI scenarios and eight types of GUI-oriented questions. The dataset assesses state-of-the-art ImageLLMs and VideoLLMs, highlighting their limitations in handling dynamic, multi-step tasks, and provides valuable insights and a foundation for future research on enhancing how MLLMs understand and interact with dynamic GUI content. It aims to advance the development of robust GUI agents capable of perceiving and interacting with both static and dynamic GUI elements. 38 | 39 | ### How to use GUI-World 40 | 41 | GUI-World is split into train and test sets, which can be accessed on [Hugging Face](https://huggingface.co/datasets/shuaishuaicdp/GUI-World). 42 | 43 | ## GUI-Vid: A GUI-Oriented VideoLLM 44 | 45 | GUI-Vid is a VideoLLM finetuned from [Videochat2](https://github.com/OpenGVLab/Ask-Anything). You can reproduce our experimental results by following these instructions: 46 | **Prepare the Environment** 47 | 48 | ```shell 49 | git clone https://github.com/Dongping-Chen/GUI-World.git 50 | cd GUI-World/GUI_Vid 51 | conda create -n gui python=3.9 52 | conda activate gui 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | **GUI-Oriented Finetuning** 57 | 58 | - Download [GUI-World](https://huggingface.co/datasets/shuaishuaicdp/GUI-World) and modify the root path in `GUI_Vid/configs/instruction_data.py` so that it points to the directory of your downloaded GUI-World data. 59 | - Set `vit_blip_model_path`, `llama_model_path` and `videochat2_model_path` in `GUI_Vid/scripts/config_7b_stage3.py`; these checkpoints can be downloaded from [GUI-Vid](https://huggingface.co/shuaishuaicdp/GUI-Vid). 60 | 61 | ```shell 62 | # Vicuna 63 | bash GUI_Vid/scripts/run_7b_stage3.sh 64 | ``` 65 | 66 | **Inference with GUI-Vid** 67 | First, download the checkpoint from [Hugging Face](https://huggingface.co/shuaishuaicdp/GUI-Vid). You also need to set up the config according to the guidance in [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). 68 | Then, set the `model_path` in `scripts/demo_local.py` and use the following script to run inference on a GUI video: 69 | 70 | ```shell 71 | python demo_local.py \ 72 | --ckpt_path <path_to_checkpoint> \ 73 | --keyframe 8 \ 74 | --video_path <path_to_video> \ 75 | --qs <your_question> 76 | ``` 77 | 78 | ### How does our video identifier work? 79 | 80 | In our paper, we use five settings to extract keyframes from videos. For `Human` and `Linspace` (uniform sampling of 10 frames per video with equal intervals between frames; this setting was previously called `Random` and is now named `Linspace` to avoid confusion), you can refer to the original annotation files and reproduce the sampling with `np.linspace`, as in the sketch below.
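A minimal sketch of this `Linspace` sampling (illustrative only; the helper name and the use of `decord` for decoding are our own choices, not the exact annotation script):

```python
# Uniform ("Linspace") keyframe sampling: pick equally spaced frame indices with np.linspace.
import numpy as np
from decord import VideoReader


def sample_keyframes_linspace(video_path, num_frames=10):
    vr = VideoReader(video_path)
    indices = np.linspace(0, len(vr) - 1, num=num_frames).astype(int)
    # Returns an array of shape (num_frames, H, W, C) in uint8.
    return vr.get_batch(list(indices)).asnumpy()
```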
For `Program`, we use [Katna](https://github.com/keplerlab/katna) to extract keyframes; our code is in `GUI_Vid/scripts/katna.py`. For [VIP](https://github.com/facebookresearch/vip) and [R3M](https://github.com/facebookresearch/r3m) based on [UVD](https://github.com/zcczhang/UVD), which were used for additional experiments in our [NeurIPS rebuttal](https://openreview.net/forum?id=h8LuywKj6N&noteId=IG1slwXfWC), we extracted the keyframes locally; you can download them from [this link](https://1drv.ms/u/c/32f66c0c65d8cc2b/EUkoaMigq6hAg3GQx54pEz8BG6FMgXohIfnJ1MB5H092Rw?e=p7exRF). 81 | 82 | ## Contribution 83 | 84 | Contributions to this project are welcome. Please consider the following ways to contribute: 85 | 86 | - Proposing new features or improvements 87 | - Benchmarking other mainstream MLLMs 88 | 89 | ## Acknowledgments 90 | 91 | Many thanks to Yinuo Liu, Zhengyan Fu, Shilin Zhang, Yu, Tianhe Gu, Haokuan Yuan, and Junqi Wang for their invaluable effort in this project. This project is based on methodologies and code presented in [Videochat2](https://github.com/OpenGVLab/Ask-Anything). 92 | 93 | ## Citation 94 | 95 | ``` 96 | @article{chen2024gui, 97 | title={GUI-WORLD: A Dataset for GUI-oriented Multimodal LLM-based Agents}, 98 | author={Chen, Dongping and Huang, Yue and Wu, Siyuan and Tang, Jingyu and Chen, Liuyi and Bai, Yilin and He, Zhigang and Wang, Chenlong and Zhou, Huichi and Li, Yiqiang and others}, 99 | journal={arXiv preprint arXiv:2406.10819}, 100 | year={2024} 101 | } 102 | ``` 103 | -------------------------------------------------------------------------------- /GUI_Vid/dataset/it_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import sqlite3 5 | import random 6 | from os.path import basename 7 | 8 | import numpy as np 9 | import datetime 10 | 11 | from dataset.base_dataset import ImageVideoBaseDataset 12 | from dataset.utils import load_anno 13 | from dataset.video_utils import VIDEO_READER_FUNCS 14 | from utils.distributed import is_main_process 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class ITImgTrainDataset(ImageVideoBaseDataset): 20 | media_type = "image" 21 | 22 | def __init__( 23 | self, ann_file, transform, 24 | system="", role=("Human", "Assistant"), 25 | start_token="", end_token="", 26 | random_shuffle=True, # if True, shuffle the QA list 27 | ): 28 | super().__init__() 29 | 30 | if len(ann_file) == 3 and ann_file[2] == "video": 31 | self.media_type = "video" 32 | else: 33 | self.media_type = "image" 34 | self.label_file, self.data_root = ann_file[:2] 35 | 36 | logger.info('Load json file') 37 | with open(self.label_file, 'r') as f: 38 | self.anno = json.load(f) 39 | self.num_examples = len(self.anno) 40 | self.transform = transform 41 | 42 | # prompt parameters 43 | if system: 44 | assert system[-1] == " ", "' ' should be added at the end of system, thus '###' will be tokenized into one token."
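        # Note on the prompt format (a rough summary of process_qa() below): each sample is rendered as
        # "<system><optional instruction>###Human: <start_token><end_token><frame msg>###Human: <question>###Assistant: <answer>###...",
        # where "###" is begin_signal, a trailing space is end_signal, and the media tokens are passed in from the config.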
45 | # currently not support add start_token and end_token in the system, since the msg should be added properly 46 | self.begin_signal = "###" 47 | self.end_signal = " " 48 | self.start_token = start_token 49 | self.end_token = end_token 50 | self.system = system 51 | self.role = role 52 | self.random_shuffle = random_shuffle 53 | # instruction location and number 54 | logger.info(f"Random shuffle: {self.random_shuffle}") 55 | 56 | def get_anno(self, index): 57 | filename = self.anno[index][self.media_type] 58 | qa = self.anno[index]["QA"] 59 | if "start" in self.anno[index] and "end" in self.anno[index]: 60 | anno = { 61 | "image": os.path.join(self.data_root, filename), "qa": qa, 62 | "start": self.anno[index]["start"], "end": self.anno[index]["end"], 63 | } 64 | else: 65 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa} 66 | return anno 67 | 68 | def __len__(self): 69 | return self.num_examples 70 | 71 | def process_qa(self, qa, msg=""): 72 | cur_instruction = "" 73 | # randomly shuffle qa for conversation 74 | if self.random_shuffle and len(qa) > 1: 75 | random.shuffle(qa) 76 | if "i" in qa[0].keys() and qa[0]["i"] != "": 77 | cur_instruction = qa[0]["i"] + self.end_signal 78 | 79 | conversation = self.system 80 | # add instruction as system message 81 | if cur_instruction: 82 | conversation += cur_instruction 83 | 84 | # rstrip() for the extra " " in msg 85 | conversation += ( 86 | self.begin_signal + self.role[0] + ": " + 87 | self.start_token + self.end_token + msg.rstrip() + self.end_signal 88 | ) 89 | 90 | for sentence in qa: 91 | q = sentence["q"] 92 | a = sentence["a"] 93 | if q != "": 94 | conversation += (self.begin_signal + self.role[0] + ": " + q + self.end_signal) 95 | else: 96 | # no question, often in caption dataset 97 | pass 98 | conversation += (self.begin_signal + self.role[1] + ": " + a + self.end_signal) 99 | conversation += self.begin_signal 100 | 101 | if cur_instruction: 102 | cur_instruction += qa[0]["q"] 103 | return conversation, cur_instruction.strip() 104 | 105 | def __getitem__(self, index): 106 | try: 107 | ann = self.get_anno(index) 108 | image, index = self.load_and_transform_media_data_image(index, ann["image"]) 109 | conversation, instruction = self.process_qa(ann["qa"]) 110 | return image, conversation, instruction, index 111 | except Exception as e: 112 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 113 | index = np.random.randint(0, len(self)) 114 | return self.__getitem__(index) 115 | 116 | 117 | class ITVidTrainDataset(ITImgTrainDataset): 118 | media_type = "video" 119 | 120 | def __init__( 121 | self, ann_file, transform, 122 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3, 123 | system="", role=("Human", "Assistant"), 124 | start_token="", 125 | add_second_msg=True, 126 | random_shuffle=True, 127 | ): 128 | super().__init__( 129 | ann_file, transform, 130 | system=system, role=role, 131 | start_token=start_token, end_token=end_token, 132 | random_shuffle=random_shuffle, 133 | ) 134 | self.num_frames = num_frames 135 | self.video_reader_type = video_reader_type 136 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 137 | self.sample_type = sample_type 138 | self.num_tries = num_tries 139 | self.add_second_msg = add_second_msg 140 | 141 | logger.info(f"Use {video_reader_type} for data in {ann_file}") 142 | if add_second_msg: 143 | logger.info(f"Add second message: The video contains X frames sampled at T seconds.") 144 | 145 | def __getitem__(self, index): 146 | 
try: 147 | ann = self.get_anno(index) 148 | msg = "" 149 | clip = None 150 | if "start" in ann and "end" in ann: 151 | clip = [ann["start"], ann["end"]] 152 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip) 153 | if self.add_second_msg: 154 | # " " should be added in the start and end 155 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. " 156 | conversation, instruction = self.process_qa(ann["qa"], msg) 157 | return video, conversation, instruction, index 158 | except Exception as e: 159 | logger.warning(f"Caught exception {e} when loading video {ann['image']}") 160 | index = np.random.randint(0, len(self)) 161 | return self.__getitem__(index) -------------------------------------------------------------------------------- /GUI_Vid/dataset/pt_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import sqlite3 5 | import random 6 | from os.path import basename 7 | 8 | import numpy as np 9 | 10 | from dataset.base_dataset import ImageVideoBaseDataset 11 | from dataset.utils import load_anno, pre_text 12 | from dataset.video_utils import VIDEO_READER_FUNCS 13 | from utils.distributed import is_main_process 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def get_anno_by_id(cur: sqlite3.Cursor, id: int): 19 | """TODO: Docstring for get_anno_by_id. 20 | 21 | Args: 22 | cur (sqlite3.Cursor): The dataset cursor. 23 | id (int): The annotation id. 24 | 25 | Returns: 26 | 27 | """ 28 | pass 29 | 30 | 31 | class PTImgTrainDataset(ImageVideoBaseDataset): 32 | media_type = "image" 33 | 34 | def __init__(self, ann_file, transform, pre_text=True): 35 | super().__init__() 36 | 37 | if len(ann_file) == 3 and ann_file[2] == "video": 38 | self.media_type = "video" 39 | else: 40 | self.media_type = "image" 41 | self.label_file, self.data_root = ann_file[:2] 42 | 43 | logger.info('Load json file') 44 | with open(self.label_file, 'r') as f: 45 | self.anno = json.load(f) 46 | self.num_examples = len(self.anno) 47 | 48 | self.transform = transform 49 | self.pre_text = pre_text 50 | logger.info(f"Pre-process text: {pre_text}") 51 | 52 | def get_anno(self, index): 53 | filename = self.anno[index][self.media_type] 54 | caption = self.anno[index]["caption"] 55 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption} 56 | return anno 57 | 58 | def __len__(self): 59 | return self.num_examples 60 | 61 | def __getitem__(self, index): 62 | try: 63 | ann = self.get_anno(index) 64 | image, index = self.load_and_transform_media_data(index, ann["image"]) 65 | caption = pre_text(ann["caption"], pre_text=self.pre_text) 66 | return image, caption, index 67 | except Exception as e: 68 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 69 | index = np.random.randint(0, len(self)) 70 | return self.__getitem__(index) 71 | 72 | 73 | class PTVidTrainDataset(PTImgTrainDataset): 74 | media_type = "video" 75 | 76 | def __init__( 77 | self, 78 | ann_file, 79 | transform, 80 | num_frames=4, 81 | video_reader_type="decord", 82 | sample_type="rand", 83 | num_tries=3, 84 | pre_text=True 85 | ): 86 | super().__init__(ann_file, transform, pre_text=pre_text) 87 | self.num_frames = num_frames 88 | self.video_reader_type = video_reader_type 89 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 90 | self.sample_type = sample_type 91 | self.num_tries = num_tries 92 | 93 | 94 | class 
PTImgEvalDataset(ImageVideoBaseDataset): 95 | media_type = "image" 96 | 97 | def __init__(self, ann_file, transform, has_multi_vision_gt=False): 98 | super(PTImgEvalDataset, self).__init__() 99 | self.raw_anno_list = load_anno(ann_file) 100 | self.transform = transform 101 | self.has_multi_vision_gt = has_multi_vision_gt # each caption has multiple image as ground_truth 102 | 103 | self.text = None 104 | self.image = None 105 | self.txt2img = None 106 | self.img2txt = None 107 | self.build_data() 108 | 109 | def build_data(self): 110 | self.text = [] 111 | self.image = [] 112 | self.txt2img = {} 113 | self.img2txt = {} 114 | if self.has_multi_vision_gt: 115 | self.build_data_multi_img_gt() 116 | else: 117 | self.build_data_multi_txt_gt() 118 | self.anno_list = [dict(image=e) for e in self.image] 119 | 120 | def build_data_multi_img_gt(self): 121 | """each text may have multiple ground_truth image, e.g., ssv2""" 122 | img_id = 0 123 | for txt_id, ann in enumerate(self.raw_anno_list): 124 | self.text.append(pre_text(ann["caption"])) 125 | self.txt2img[txt_id] = [] 126 | _images = ann["image"] \ 127 | if isinstance(ann["image"], list) else [ann["image"], ] 128 | for i, image in enumerate(_images): 129 | self.image.append(image) 130 | self.txt2img[txt_id].append(img_id) 131 | self.img2txt[img_id] = txt_id 132 | img_id += 1 133 | 134 | def build_data_multi_txt_gt(self): 135 | """each image may have multiple ground_truth text, e.g., COCO and Flickr30K""" 136 | txt_id = 0 137 | for img_id, ann in enumerate(self.raw_anno_list): 138 | self.image.append(ann["image"]) 139 | self.img2txt[img_id] = [] 140 | _captions = ann["caption"] \ 141 | if isinstance(ann["caption"], list) else [ann["caption"], ] 142 | for i, caption in enumerate(_captions): 143 | self.text.append(pre_text(caption)) 144 | self.img2txt[img_id].append(txt_id) 145 | self.txt2img[txt_id] = img_id 146 | txt_id += 1 147 | 148 | def __len__(self): 149 | return len(self.anno_list) 150 | 151 | def __getitem__(self, index): 152 | ann = self.anno_list[index] 153 | image, index = self.load_and_transform_media_data(index, ann["image"]) 154 | return image, index 155 | 156 | 157 | def preprocess_para_retrieval_data(anno_list): 158 | processed_anno_list = [] 159 | for d in anno_list: 160 | d["caption"] = " ".join(d.pop("caption")) 161 | processed_anno_list.append(d) 162 | return processed_anno_list 163 | 164 | 165 | class PTVidEvalDataset(PTImgEvalDataset): 166 | media_type = "video" 167 | 168 | def __init__( 169 | self, ann_file, transform, num_frames=4, 170 | video_reader_type="decord", sample_type="rand", num_tries=1, 171 | is_paragraph_retrieval=False, has_multi_vision_gt=False, 172 | ): 173 | super(PTVidEvalDataset, self).__init__(ann_file, transform, has_multi_vision_gt) 174 | self.num_frames = num_frames 175 | self.video_reader_type = video_reader_type 176 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 177 | self.sample_type = sample_type 178 | self.num_tries = num_tries 179 | self.is_paragraph_retrieval = is_paragraph_retrieval 180 | 181 | if is_paragraph_retrieval: 182 | self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list) 183 | self.build_data() 184 | -------------------------------------------------------------------------------- /GUI_Vid/dataset/video_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py 3 | """ 4 | import random 5 | import 
io 6 | import av 7 | import cv2 8 | import decord 9 | import imageio 10 | from decord import VideoReader 11 | import torch 12 | import numpy as np 13 | import math 14 | decord.bridge.set_bridge("torch") 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: 20 | """ 21 | Converts a present time with the given time base and start_pts offset to seconds. 22 | 23 | Returns: 24 | time_in_seconds (float): The corresponding time in seconds. 25 | 26 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 27 | """ 28 | if pts == math.inf: 29 | return math.inf 30 | 31 | return int(pts - start_pts) * time_base 32 | 33 | 34 | def get_pyav_video_duration(video_reader): 35 | video_stream = video_reader.streams.video[0] 36 | video_duration = pts_to_secs( 37 | video_stream.duration, 38 | video_stream.time_base, 39 | video_stream.start_time 40 | ) 41 | return float(video_duration) 42 | 43 | 44 | def get_frame_indices_by_fps(): 45 | pass 46 | 47 | 48 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): 49 | if sample in ["rand", "middle"]: # uniform sampling 50 | acc_samples = min(num_frames, vlen) 51 | # split the video into `acc_samples` intervals, and sample from each interval. 52 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 53 | ranges = [] 54 | for idx, interv in enumerate(intervals[:-1]): 55 | ranges.append((interv, intervals[idx + 1] - 1)) 56 | if sample == 'rand': 57 | try: 58 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] 59 | except: 60 | frame_indices = np.random.permutation(vlen)[:acc_samples] 61 | frame_indices.sort() 62 | frame_indices = list(frame_indices) 63 | elif fix_start is not None: 64 | frame_indices = [x[0] + fix_start for x in ranges] 65 | elif sample == 'middle': 66 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges] 67 | else: 68 | raise NotImplementedError 69 | 70 | if len(frame_indices) < num_frames: # padded with last frame 71 | padded_frame_indices = [frame_indices[-1]] * num_frames 72 | padded_frame_indices[:len(frame_indices)] = frame_indices 73 | frame_indices = padded_frame_indices 74 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps 75 | output_fps = float(sample[3:]) 76 | duration = float(vlen) / input_fps 77 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents 78 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) 79 | frame_indices = np.around(frame_seconds * input_fps).astype(int) 80 | frame_indices = [e for e in frame_indices if e < vlen] 81 | if max_num_frames > 0 and len(frame_indices) > max_num_frames: 82 | frame_indices = frame_indices[:max_num_frames] 83 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) 84 | else: 85 | raise ValueError 86 | return frame_indices 87 | 88 | 89 | def read_frames_av( 90 | video_path, num_frames, sample='rand', fix_start=None, 91 | max_num_frames=-1, client=None, clip=None, 92 | ): 93 | reader = av.open(video_path) 94 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] 95 | vlen = len(frames) 96 | duration = get_pyav_video_duration(reader) 97 | fps = vlen / float(duration) 98 | frame_indices = get_frame_indices( 99 | num_frames, vlen, sample=sample, fix_start=fix_start, 100 | input_fps=fps, max_num_frames=max_num_frames 
101 | ) 102 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8 103 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 104 | return frames, frame_indices, fps 105 | 106 | 107 | def read_frames_gif( 108 | video_path, num_frames, sample='rand', fix_start=None, 109 | max_num_frames=-1, client=None, clip=None, 110 | ): 111 | if video_path.startswith('s3') or video_path.startswith('p2'): 112 | video_bytes = client.get(video_path) 113 | gif = imageio.get_reader(io.BytesIO(video_bytes)) 114 | else: 115 | gif = imageio.get_reader(video_path) 116 | vlen = len(gif) 117 | frame_indices = get_frame_indices( 118 | num_frames, vlen, sample=sample, fix_start=fix_start, 119 | max_num_frames=max_num_frames 120 | ) 121 | frames = [] 122 | for index, frame in enumerate(gif): 123 | # for index in frame_idxs: 124 | if index in frame_indices: 125 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 126 | frame = torch.from_numpy(frame).byte() 127 | # # (H x W x C) to (C x H x W) 128 | frame = frame.permute(2, 0, 1) 129 | frames.append(frame) 130 | frames = torch.stack(frames) # .float() / 255 131 | return frames, frame_indices, 25. # for tgif 132 | 133 | 134 | def read_frames_decord( 135 | video_path, num_frames, sample='rand', fix_start=None, 136 | max_num_frames=-1, client=None, clip=None 137 | ): 138 | if video_path.startswith('s3') or video_path.startswith('p2'): 139 | video_bytes = client.get(video_path) 140 | video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1) 141 | else: 142 | video_reader = VideoReader(video_path, num_threads=1) 143 | vlen = len(video_reader) 144 | fps = video_reader.get_avg_fps() 145 | duration = vlen / float(fps) 146 | 147 | if clip: 148 | start, end = clip 149 | duration = end - start 150 | vlen = int(duration * fps) 151 | start_index = int(start * fps) 152 | 153 | frame_indices = get_frame_indices( 154 | num_frames, vlen, sample=sample, fix_start=fix_start, 155 | input_fps=fps, max_num_frames=max_num_frames 156 | ) 157 | if clip: 158 | frame_indices = [f + start_index for f in frame_indices] 159 | 160 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 161 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 162 | return frames, frame_indices, float(fps) 163 | 164 | 165 | VIDEO_READER_FUNCS = { 166 | 'av': read_frames_av, 167 | 'decord': read_frames_decord, 168 | 'gif': read_frames_gif, 169 | } 170 | -------------------------------------------------------------------------------- /GUI_Vid/dataset/utils.py: -------------------------------------------------------------------------------- 1 | from utils.distributed import is_main_process, get_rank, get_world_size 2 | import logging 3 | import torch.distributed as dist 4 | import torch 5 | import io 6 | import os 7 | import json 8 | import re 9 | import numpy as np 10 | from os.path import join 11 | from tqdm import trange 12 | from PIL import Image 13 | from PIL import ImageFile 14 | from torchvision.transforms import PILToTensor 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | Image.MAX_IMAGE_PIXELS = None 17 | 18 | 19 | def load_image_from_path(image_path, client): 20 | if image_path.startswith('s3') or image_path.startswith('p2'): 21 | value = client.Get(image_path) 22 | img_bytes = np.frombuffer(value, dtype=np.uint8) 23 | buff = io.BytesIO(img_bytes) 24 | image = Image.open(buff).convert('RGB') 25 | else: 26 | image = Image.open(image_path).convert('RGB') # PIL Image 27 | image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), 
torch.uint8 28 | return image 29 | 30 | 31 | def load_anno(ann_file_list): 32 | """[summary] 33 | 34 | Args: 35 | ann_file_list (List[List[str, str]] or List[str, str]): 36 | the latter will be automatically converted to the former. 37 | Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video']) 38 | which specifies the data type, video or image 39 | 40 | Returns: 41 | List(dict): each dict is { 42 | image: str or List[str], # image_path, 43 | caption: str or List[str] # caption text string 44 | } 45 | """ 46 | if isinstance(ann_file_list[0], str): 47 | ann_file_list = [ann_file_list] 48 | 49 | ann = [] 50 | for d in ann_file_list: 51 | data_root = d[1] 52 | fp = d[0] 53 | is_video = len(d) == 3 and d[2] == "video" 54 | cur_ann = json.load(open(fp, "r")) 55 | iterator = trange(len(cur_ann), desc=f"Loading {fp}") \ 56 | if is_main_process() else range(len(cur_ann)) 57 | for idx in iterator: 58 | key = "video" if is_video else "image" 59 | # unified to have the same key for data path 60 | if isinstance(cur_ann[idx][key], str): 61 | cur_ann[idx]["image"] = join(data_root, cur_ann[idx][key]) 62 | else: # list 63 | cur_ann[idx]["image"] = [join(data_root, e) for e in cur_ann[idx][key]] 64 | ann += cur_ann 65 | return ann 66 | 67 | 68 | def pre_text(text, max_l=None, pre_text=True): 69 | if pre_text: 70 | text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower()) 71 | text = text.replace('-', ' ').replace('/', ' ').replace('', 'person') 72 | 73 | text = re.sub(r"\s{2,}", ' ', text) 74 | text = text.rstrip('\n').strip(' ') 75 | 76 | if max_l: # truncate 77 | words = text.split(' ') 78 | if len(words) > max_l: 79 | text = ' '.join(words[:max_l]) 80 | else: 81 | pass 82 | return text 83 | 84 | 85 | logger = logging.getLogger(__name__) 86 | 87 | 88 | def collect_result(result, result_dir, filename, is_json=True, is_list=True): 89 | if is_json: 90 | result_file = os.path.join( 91 | result_dir, '%s_rank%d.json' % (filename, get_rank())) 92 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 93 | json.dump(result, open(result_file, 'w')) 94 | else: 95 | result_file = os.path.join( 96 | result_dir, '%s_rank%d.pth' % (filename, get_rank())) 97 | final_result_file = os.path.join(result_dir, '%s.pth' % filename) 98 | torch.save(result, result_file) 99 | 100 | dist.barrier() 101 | 102 | result = None 103 | if is_main_process(): 104 | # combine results from all processes 105 | if is_list: 106 | result = [] 107 | else: 108 | result = {} 109 | for rank in range(get_world_size()): 110 | if is_json: 111 | result_file = os.path.join( 112 | result_dir, '%s_rank%d.json' % (filename, rank)) 113 | res = json.load(open(result_file, 'r')) 114 | else: 115 | result_file = os.path.join( 116 | result_dir, '%s_rank%d.pth' % (filename, rank)) 117 | res = torch.load(result_file) 118 | if is_list: 119 | result += res 120 | else: 121 | result.update(res) 122 | 123 | return result 124 | 125 | 126 | def sync_save_result(result, result_dir, filename, is_json=True, is_list=True): 127 | """gather results from multiple GPUs""" 128 | if is_json: 129 | result_file = os.path.join( 130 | result_dir, "dist_res", '%s_rank%d.json' % (filename, get_rank())) 131 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 132 | os.makedirs(os.path.dirname(result_file), exist_ok=True) 133 | json.dump(result, open(result_file, 'w')) 134 | else: 135 | result_file = os.path.join( 136 | result_dir, "dist_res", '%s_rank%d.pth' % (filename, get_rank())) 137 | os.makedirs(os.path.dirname(result_file), 
exist_ok=True) 138 | final_result_file = os.path.join(result_dir, '%s.pth' % filename) 139 | torch.save(result, result_file) 140 | 141 | dist.barrier() 142 | 143 | if is_main_process(): 144 | # combine results from all processes 145 | if is_list: 146 | result = [] 147 | else: 148 | result = {} 149 | for rank in range(get_world_size()): 150 | if is_json: 151 | result_file = os.path.join( 152 | result_dir, "dist_res", '%s_rank%d.json' % (filename, rank)) 153 | res = json.load(open(result_file, 'r')) 154 | else: 155 | result_file = os.path.join( 156 | result_dir, "dist_res", '%s_rank%d.pth' % (filename, rank)) 157 | res = torch.load(result_file) 158 | if is_list: 159 | result += res 160 | else: 161 | result.update(res) 162 | if is_json: 163 | json.dump(result, open(final_result_file, 'w')) 164 | else: 165 | torch.save(result, final_result_file) 166 | 167 | logger.info('result file saved to %s' % final_result_file) 168 | dist.barrier() 169 | return final_result_file, result 170 | 171 | 172 | def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None): 173 | """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray) 174 | into a (n+1)-d array, only allow the first dim has variable lengths. 175 | Args: 176 | sequences: list(n-d tensor or list) 177 | dtype: np.dtype or torch.dtype 178 | device: 179 | fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length. 180 | return will be of shape [len(sequences), fixed_length, ...] 181 | Returns: 182 | padded_seqs: ((n+1)-d tensor) padded with zeros 183 | mask: (2d tensor) of the same shape as the first two dims of padded_seqs, 184 | 1 indicate valid, 0 otherwise 185 | Examples: 186 | >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] 187 | >>> pad_sequences_1d(test_data_list, dtype=torch.long) 188 | >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)] 189 | >>> pad_sequences_1d(test_data_3d, dtype=torch.float) 190 | >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] 191 | >>> pad_sequences_1d(test_data_list, dtype=np.float32) 192 | >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)] 193 | >>> pad_sequences_1d(test_data_3d, dtype=np.float32) 194 | """ 195 | if isinstance(sequences[0], list): 196 | if "torch" in str(dtype): 197 | sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences] 198 | else: 199 | sequences = [np.asarray(s, dtype=dtype) for s in sequences] 200 | 201 | extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements 202 | lengths = [len(seq) for seq in sequences] 203 | if fixed_length is not None: 204 | max_length = fixed_length 205 | else: 206 | max_length = max(lengths) 207 | if isinstance(sequences[0], torch.Tensor): 208 | assert "torch" in str(dtype), "dtype and input type does not match" 209 | padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device) 210 | mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device) 211 | else: # np 212 | assert "numpy" in str(dtype), "dtype and input type does not match" 213 | padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype) 214 | mask = np.zeros((len(sequences), max_length), dtype=np.float32) 215 | 216 | for idx, seq in enumerate(sequences): 217 | end = lengths[idx] 218 | padded_seqs[idx, :end] = seq 219 | mask[idx, :end] = 1 220 | return padded_seqs, mask # , lengths 221 | 222 | 223 | 
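# Worked example for pad_sequences_1d above (values follow directly from the implementation):
#   padded, mask = pad_sequences_1d([[1, 2, 3], [1, 2], [3, 4, 7, 9]], dtype=torch.long)
#   gives padded of shape (3, 4) with zero padding and
#   mask == [[1., 1., 1., 0.], [1., 1., 0., 0.], [1., 1., 1., 1.]]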
-------------------------------------------------------------------------------- /GUI_Vid/utils/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import ast 5 | import json 6 | import os 7 | import os.path as osp 8 | import re 9 | import shutil 10 | import sys 11 | import tempfile 12 | from copy import deepcopy 13 | from importlib import import_module 14 | 15 | import yaml 16 | 17 | from .easydict import EasyDict 18 | 19 | __all__ = ["Config", "pretty_text"] 20 | 21 | 22 | BASE_KEY = "_base_" 23 | # BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"} 24 | BASE_CONFIG = {} 25 | 26 | cfg = None 27 | 28 | 29 | class Config(object): 30 | """config""" 31 | 32 | @classmethod 33 | def pretty_text(cls, cfg: dict, indent=2) -> str: 34 | """format dict to a string 35 | 36 | Args: 37 | cfg (EasyDict): the params. 38 | 39 | Returns: The string to display. 40 | 41 | """ 42 | msg = "{\n" 43 | for i, (k, v) in enumerate(cfg.items()): 44 | if isinstance(v, dict): 45 | v = cls.pretty_text(v, indent + 4) 46 | spaces = " " * indent 47 | msg += spaces + "{}: {}".format(k, v) 48 | if i == len(cfg) - 1: 49 | msg += " }" 50 | else: 51 | msg += "\n" 52 | return msg 53 | 54 | @classmethod 55 | def dump(cls, cfg, savepath=None): 56 | """dump cfg to `json` file. 57 | 58 | Args: 59 | cfg (dict): The dict to dump. 60 | savepath (str): The filepath to save the dumped dict. 61 | 62 | Returns: TODO 63 | 64 | """ 65 | if savepath is None: 66 | savepath = osp.join(cfg.WORKSPACE, "config.json") 67 | json.dump(cfg, open(savepath, "w"), indent=2) 68 | 69 | @classmethod 70 | def get_config(cls, default_config: dict = None): 71 | """get a `Config` instance. 72 | 73 | Args: 74 | default_config (dict): The default config. `default_config` will be overrided 75 | by config file `--cfg`, `--cfg` will be overrided by commandline args. 76 | 77 | Returns: an EasyDict. 78 | """ 79 | global cfg 80 | if cfg is not None: 81 | return cfg 82 | 83 | # define arg parser. 84 | parser = argparse.ArgumentParser() 85 | # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str) 86 | parser.add_argument( 87 | "config_file", help="the configuration file to load. support: .yaml, .json, .py" 88 | ) 89 | parser.add_argument( 90 | "opts", 91 | default=None, 92 | nargs="*", 93 | help="overrided configs. List. Format: 'key1 name1 key2 name2'", 94 | ) 95 | args = parser.parse_args() 96 | 97 | cfg = EasyDict(BASE_CONFIG) 98 | if osp.isfile(args.config_file): 99 | cfg_from_file = cls.from_file(args.config_file) 100 | cfg = merge_a_into_b(cfg_from_file, cfg) 101 | cfg = cls.merge_list(cfg, args.opts) 102 | cfg = eval_dict_leaf(cfg) 103 | 104 | # update some keys to make them show at the last 105 | for k in BASE_CONFIG: 106 | cfg[k] = cfg.pop(k) 107 | return cfg 108 | 109 | @classmethod 110 | def from_file(cls, filepath: str) -> EasyDict: 111 | """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`. 112 | 113 | Args: 114 | filepath (str): The config file path. 
115 | 116 | Returns: TODO 117 | 118 | """ 119 | filepath = osp.abspath(osp.expanduser(filepath)) 120 | if not osp.isfile(filepath): 121 | raise IOError(f"File does not exist: {filepath}") 122 | if filepath.endswith(".py"): 123 | with tempfile.TemporaryDirectory() as temp_config_dir: 124 | 125 | shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config")) 126 | sys.path.insert(0, temp_config_dir) 127 | mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0]) 128 | # mod = import_module(temp_module_name) 129 | sys.path.pop(0) 130 | cfg_dict = { 131 | name: value 132 | for name, value in mod.__dict__.items() 133 | if not name.startswith("__") 134 | } 135 | for k in list(sys.modules.keys()): 136 | if "tmp_config" in k: 137 | del sys.modules[k] 138 | elif filepath.endswith((".yml", ".yaml")): 139 | cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader) 140 | elif filepath.endswith(".json"): 141 | cfg_dict = json.load(open(filepath, "r")) 142 | else: 143 | raise IOError("Only py/yml/yaml/json type are supported now!") 144 | 145 | cfg_text = filepath + "\n" 146 | with open(filepath, "r") as f: 147 | cfg_text += f.read() 148 | 149 | if BASE_KEY in cfg_dict: # load configs in `BASE_KEY` 150 | cfg_dir = osp.dirname(filepath) 151 | base_filename = cfg_dict.pop(BASE_KEY) 152 | base_filename = ( 153 | base_filename if isinstance(base_filename, list) else [base_filename] 154 | ) 155 | 156 | cfg_dict_list = list() 157 | for f in base_filename: 158 | _cfg_dict = Config.from_file(osp.join(cfg_dir, f)) 159 | cfg_dict_list.append(_cfg_dict) 160 | 161 | base_cfg_dict = dict() 162 | for c in cfg_dict_list: 163 | if len(base_cfg_dict.keys() & c.keys()) > 0: 164 | raise KeyError("Duplicate key is not allowed among bases") 165 | base_cfg_dict.update(c) 166 | 167 | cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict) 168 | 169 | return EasyDict(cfg_dict) 170 | 171 | @classmethod 172 | def merge_list(cls, cfg, opts: list): 173 | """merge commandline opts. 174 | 175 | Args: 176 | cfg: (dict): The config to be merged. 177 | opts (list): The list to merge. Format: [key1, name1, key2, name2,...]. 178 | The keys can be nested. For example, ["a.b", v] will be considered 179 | as `dict(a=dict(b=v))`. 180 | 181 | Returns: dict. 182 | 183 | """ 184 | assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}" 185 | for i in range(0, len(opts), 2): 186 | full_k, v = opts[i], opts[i + 1] 187 | keys = full_k.split(".") 188 | sub_d = cfg 189 | for i, k in enumerate(keys): 190 | if not hasattr(sub_d, k): 191 | raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}") 192 | if i != len(keys) - 1: 193 | sub_d = sub_d[k] 194 | else: 195 | sub_d[k] = v 196 | return cfg 197 | 198 | 199 | def merge_a_into_b(a, b, inplace=False): 200 | """The values in a will override values in b. 201 | 202 | Args: 203 | a (dict): source dict. 204 | b (dict): target dict. 205 | 206 | Returns: dict. recursively merge dict a into dict b. 207 | 208 | """ 209 | if not inplace: 210 | b = deepcopy(b) 211 | for key in a: 212 | if key in b: 213 | if isinstance(a[key], dict) and isinstance(b[key], dict): 214 | b[key] = merge_a_into_b(a[key], b[key], inplace=True) 215 | else: 216 | b[key] = a[key] 217 | else: 218 | b[key] = a[key] 219 | return b 220 | 221 | 222 | def eval_dict_leaf(d, orig_dict=None): 223 | """eval values of dict leaf. 224 | 225 | Args: 226 | d (dict): The dict to eval. 227 | 228 | Returns: dict. 
229 | 230 | """ 231 | if orig_dict is None: 232 | orig_dict = d 233 | for k, v in d.items(): 234 | if not isinstance(v, dict): 235 | d[k] = eval_string(v, orig_dict) 236 | else: 237 | eval_dict_leaf(v, orig_dict) 238 | return d 239 | 240 | 241 | def eval_string(string, d): 242 | """automatically evaluate string to corresponding types. 243 | 244 | For example: 245 | not a string -> return the original input 246 | '0' -> 0 247 | '0.2' -> 0.2 248 | '[0, 1, 2]' -> [0,1,2] 249 | 'eval(1+2)' -> 3 250 | 'eval(range(5))' -> [0,1,2,3,4] 251 | '${a}' -> d.a 252 | 253 | 254 | 255 | Args: 256 | string (str): The value to evaluate. 257 | d (dict): The 258 | 259 | Returns: the corresponding type 260 | 261 | """ 262 | if not isinstance(string, str): 263 | return string 264 | # if len(string) > 1 and string[0] == "[" and string[-1] == "]": 265 | # return eval(string) 266 | if string[0:5] == "eval(": 267 | return eval(string[5:-1]) 268 | 269 | s0 = string 270 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0) 271 | if s1 != s0: 272 | while s1 != s0: 273 | s0 = s1 274 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0) 275 | return eval(s1) 276 | 277 | try: 278 | v = ast.literal_eval(string) 279 | except: 280 | v = string 281 | return v 282 | -------------------------------------------------------------------------------- /GUI_Vid/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import ConcatDataset, DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms import InterpolationMode 5 | 6 | from dataset.dataloader import MetaLoader 7 | from dataset.pt_dataset import PTImgTrainDataset, PTVidTrainDataset, PTImgEvalDataset, PTVidEvalDataset 8 | from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset 9 | 10 | 11 | def get_media_type(dataset_config): 12 | if len(dataset_config) == 3 and dataset_config[2] == "video": 13 | return "video" 14 | elif dataset_config[-1] == "only_video": 15 | return "only_video" 16 | else: 17 | return "image" 18 | 19 | 20 | def create_dataset(dataset_type, config): 21 | if "clip" in config.model.get("vit_model", 'vit'): 22 | mean = (0.485, 0.456, 0.406) 23 | std = (0.229, 0.224, 0.225) 24 | else: 25 | vision_enc_name = config.model.vision_encoder.name 26 | if "swin" in vision_enc_name or "vit" in vision_enc_name: 27 | mean = (0.485, 0.456, 0.406) 28 | std = (0.229, 0.224, 0.225) 29 | elif "beit" in vision_enc_name: 30 | mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning 31 | std = (0.5, 0.5, 0.5) 32 | elif "clip" in vision_enc_name: 33 | mean = (0.48145466, 0.4578275, 0.40821073) 34 | std = (0.26862954, 0.26130258, 0.27577711) 35 | else: 36 | raise ValueError 37 | 38 | normalize = transforms.Normalize(mean, std) 39 | 40 | # loaded images and videos are torch.Tensor of torch.uint8 format, 41 | # ordered as (T, 1 or 3, H, W) where T=1 for image 42 | type_transform = transforms.Lambda(lambda x: x.float().div(255.0)) 43 | 44 | if config.inputs.video_input.random_aug: 45 | aug_transform = transforms.RandAugment() 46 | else: 47 | aug_transform = transforms.Lambda(lambda x: x) 48 | 49 | train_transform = transforms.Compose( 50 | [ 51 | aug_transform, 52 | transforms.RandomResizedCrop( 53 | config.inputs.image_res, 54 | scale=(0.5, 1.0), 55 | interpolation=InterpolationMode.BICUBIC, 56 | ), 57 | transforms.RandomHorizontalFlip(), 58 | type_transform, 59 | normalize, 60 | ] 61 | ) 62 | test_transform = transforms.Compose( 63 | [ 64 | transforms.Resize( 65 | 
(config.inputs.image_res, config.inputs.image_res), 66 | interpolation=InterpolationMode.BICUBIC, 67 | ), 68 | type_transform, 69 | normalize, 70 | ] 71 | ) 72 | 73 | video_reader_type = config.inputs.video_input.get("video_reader_type", "decord") 74 | video_only_dataset_kwargs_train = dict( 75 | video_reader_type=video_reader_type, 76 | sample_type=config.inputs.video_input.sample_type, 77 | num_frames=config.inputs.video_input.num_frames, 78 | num_tries=3, # false tolerance 79 | ) 80 | video_only_dataset_kwargs_eval = dict( 81 | video_reader_type=video_reader_type, 82 | sample_type=config.inputs.video_input.sample_type_test, 83 | num_frames=config.inputs.video_input.num_frames_test, 84 | num_tries=1, # we want to have predictions for all videos 85 | ) 86 | 87 | if dataset_type == "pt_train": 88 | # convert to list of lists 89 | train_files = ( 90 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file 91 | ) 92 | train_media_types = sorted(list({get_media_type(e) for e in train_files})) 93 | 94 | train_datasets = [] 95 | for m in train_media_types: 96 | dataset_cls = PTImgTrainDataset if m == "image" else PTVidTrainDataset 97 | # dataset of the same media_type will be mixed in a single Dataset object 98 | _train_files = [e for e in train_files if get_media_type(e) == m] 99 | 100 | datasets = [] 101 | for train_file in _train_files: 102 | dataset_kwargs = dict( 103 | ann_file=train_file, 104 | transform=train_transform, 105 | pre_text=config.get( 106 | "pre_text", True 107 | ), 108 | ) 109 | if m == "video": 110 | dataset_kwargs.update(video_only_dataset_kwargs_train) 111 | datasets.append(dataset_cls(**dataset_kwargs)) 112 | dataset = ConcatDataset(datasets) 113 | train_datasets.append(dataset) 114 | return train_datasets 115 | 116 | elif dataset_type in ["it_train"]: 117 | # convert to list of lists 118 | train_files = ( 119 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file 120 | ) 121 | train_media_types = sorted(list({get_media_type(e) for e in train_files})) 122 | 123 | train_datasets = [] 124 | for m in train_media_types: 125 | dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset 126 | # dataset of the same media_type will be mixed in a single Dataset object 127 | _train_files = [e for e in train_files if get_media_type(e) == m] 128 | 129 | datasets = [] 130 | for train_file in _train_files: 131 | dataset_kwargs = dict( 132 | ann_file=train_file, 133 | transform=train_transform, 134 | system=config.model.get("system", ""), 135 | start_token=config.model.get("img_start_token", ""), 136 | end_token=config.model.get("img_end_token", ""), 137 | ) 138 | if m == "video": 139 | video_only_dataset_kwargs_train.update({ 140 | "start_token": config.model.get("start_token", ""), 142 | }) 143 | dataset_kwargs.update(video_only_dataset_kwargs_train) 144 | if "tgif" in train_file[1]: 145 | video_only_dataset_kwargs_train.update({ 146 | "video_reader_type": "gif" 147 | }) 148 | dataset_kwargs.update(video_only_dataset_kwargs_train) 149 | else: 150 | video_only_dataset_kwargs_train.update({ 151 | "video_reader_type": "decord" 152 | }) 153 | dataset_kwargs.update(video_only_dataset_kwargs_train) 154 | datasets.append(dataset_cls(**dataset_kwargs)) 155 | dataset = ConcatDataset(datasets) 156 | train_datasets.append(dataset) 157 | return train_datasets 158 | 159 | elif dataset_type == "pt_eval": 160 | test_datasets = [] 161 | test_dataset_names = [] 162 | # multiple test datasets, all separate 163 | for name, data_cfg 
in config.test_file.items(): 164 | media_type = get_media_type(data_cfg) 165 | test_dataset_cls = ( 166 | PTImgEvalDataset if media_type == "image" else PTVidEvalDataset 167 | ) 168 | test_dataset_names.append(name) 169 | dataset_kwargs = dict( 170 | ann_file=[data_cfg], 171 | transform=test_transform, 172 | has_multi_vision_gt=config.get( 173 | "has_multi_vision_gt", False 174 | ), # true for ssv2 ret 175 | ) 176 | if media_type == "video": 177 | dataset_kwargs.update(video_only_dataset_kwargs_eval) 178 | test_datasets.append(test_dataset_cls(**dataset_kwargs)) 179 | return test_datasets, test_dataset_names 180 | 181 | 182 | 183 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 184 | samplers = [] 185 | for dataset, shuffle in zip(datasets, shuffles): 186 | sampler = torch.utils.data.DistributedSampler( 187 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle 188 | ) 189 | samplers.append(sampler) 190 | return samplers 191 | 192 | 193 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 194 | loaders = [] 195 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip( 196 | datasets, samplers, batch_size, num_workers, is_trains, collate_fns 197 | ): 198 | if is_train: 199 | shuffle = sampler is None 200 | drop_last = True 201 | else: 202 | shuffle = False 203 | drop_last = False 204 | loader = DataLoader( 205 | dataset, 206 | batch_size=bs, 207 | num_workers=n_worker, 208 | pin_memory=False, 209 | sampler=sampler, 210 | shuffle=shuffle, 211 | collate_fn=collate_fn, 212 | drop_last=drop_last, 213 | persistent_workers=True if n_worker > 0 else False, 214 | ) 215 | loaders.append(loader) 216 | return loaders 217 | 218 | 219 | def iterate_dataloaders(dataloaders): 220 | """Alternatively generate data from multiple dataloaders, 221 | since we use `zip` to concat multiple dataloaders, 222 | the loop will end when the smaller dataloader runs out. 223 | 224 | Args: 225 | dataloaders List(DataLoader): can be a single or multiple dataloaders 226 | """ 227 | for data_tuples in zip(*dataloaders): 228 | for idx, data in enumerate(data_tuples): 229 | yield dataloaders[idx].dataset.media_type, data 230 | -------------------------------------------------------------------------------- /GUI_Vid/utils/logger.py: -------------------------------------------------------------------------------- 1 | # from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import functools 5 | import logging 6 | import os 7 | import sys 8 | import time 9 | import wandb 10 | from typing import Any, Dict, Union 11 | 12 | import torch 13 | from .distributed import get_rank, is_main_process 14 | from termcolor import colored 15 | 16 | 17 | def log_dict_to_wandb(log_dict, step, prefix=""): 18 | """include a separator `/` at the end of `prefix`""" 19 | if not is_main_process(): 20 | return 21 | 22 | log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()} 23 | wandb.log(log_dict, step) 24 | 25 | 26 | def setup_wandb(config): 27 | if not (config.wandb.enable and is_main_process()): 28 | return 29 | 30 | run = wandb.init( 31 | config=config, 32 | project=config.wandb.project, 33 | entity=config.wandb.entity, 34 | name=os.path.basename(config.output_dir), 35 | reinit=True 36 | ) 37 | return run 38 | 39 | 40 | def setup_output_folder(save_dir: str, folder_only: bool = False): 41 | """Sets up and returns the output file where the logs will be placed 42 | based on the configuration passed. Usually "save_dir/logs/log_.txt". 43 | If env.log_dir is passed, logs will be directly saved in this folder. 44 | Args: 45 | folder_only (bool, optional): If folder should be returned and not the file. 46 | Defaults to False. 47 | Returns: 48 | str: folder or file path depending on folder_only flag 49 | """ 50 | log_filename = "train_" 51 | log_filename += time.strftime("%Y_%m_%dT%H_%M_%S") 52 | log_filename += ".log" 53 | 54 | log_folder = os.path.join(save_dir, "logs") 55 | 56 | if not os.path.exists(log_folder): 57 | os.path.mkdirs(log_folder) 58 | 59 | if folder_only: 60 | return log_folder 61 | 62 | log_filename = os.path.join(log_folder, log_filename) 63 | 64 | return log_filename 65 | 66 | 67 | def setup_logger( 68 | output: str = None, 69 | color: bool = True, 70 | name: str = "mmf", 71 | disable: bool = False, 72 | clear_handlers=True, 73 | *args, 74 | **kwargs, 75 | ): 76 | """ 77 | Initialize the MMF logger and set its verbosity level to "INFO". 78 | Outside libraries shouldn't call this in case they have set there 79 | own logging handlers and setup. If they do, and don't want to 80 | clear handlers, pass clear_handlers options. 81 | The initial version of this function was taken from D2 and adapted 82 | for MMF. 83 | Args: 84 | output (str): a file name or a directory to save log. 85 | If ends with ".txt" or ".log", assumed to be a file name. 86 | Default: Saved to file 87 | color (bool): If false, won't log colored logs. Default: true 88 | name (str): the root module name of this logger. Defaults to "mmf". 89 | disable: do not use 90 | clear_handlers (bool): If false, won't clear existing handlers. 
91 | Returns: 92 | logging.Logger: a logger 93 | """ 94 | if disable: 95 | return None 96 | logger = logging.getLogger(name) 97 | logger.propagate = False 98 | 99 | logging.captureWarnings(True) 100 | warnings_logger = logging.getLogger("py.warnings") 101 | 102 | plain_formatter = logging.Formatter( 103 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s", 104 | datefmt="%Y-%m-%dT%H:%M:%S", 105 | ) 106 | 107 | distributed_rank = get_rank() 108 | handlers = [] 109 | 110 | logging_level = logging.INFO 111 | # logging_level = logging.DEBUG 112 | 113 | if distributed_rank == 0: 114 | logger.setLevel(logging_level) 115 | ch = logging.StreamHandler(stream=sys.stdout) 116 | ch.setLevel(logging_level) 117 | if color: 118 | formatter = ColorfulFormatter( 119 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", 120 | datefmt="%Y-%m-%dT%H:%M:%S", 121 | ) 122 | else: 123 | formatter = plain_formatter 124 | ch.setFormatter(formatter) 125 | logger.addHandler(ch) 126 | warnings_logger.addHandler(ch) 127 | handlers.append(ch) 128 | 129 | # file logging: all workers 130 | if output is None: 131 | output = setup_output_folder() 132 | 133 | if output is not None: 134 | if output.endswith(".txt") or output.endswith(".log"): 135 | filename = output 136 | else: 137 | filename = os.path.join(output, "train.log") 138 | if distributed_rank > 0: 139 | filename = filename + f".rank{distributed_rank}" 140 | os.makedirs(os.path.dirname(filename), exist_ok=True) 141 | 142 | fh = logging.StreamHandler(_cached_log_stream(filename)) 143 | fh.setLevel(logging_level) 144 | fh.setFormatter(plain_formatter) 145 | logger.addHandler(fh) 146 | warnings_logger.addHandler(fh) 147 | handlers.append(fh) 148 | 149 | # Slurm/FB output, only log the main process 150 | # save_dir = get_mmf_env(key="save_dir") 151 | if "train.log" not in filename and distributed_rank == 0: 152 | filename = os.path.join(output, "train.log") 153 | sh = logging.StreamHandler(_cached_log_stream(filename)) 154 | sh.setLevel(logging_level) 155 | sh.setFormatter(plain_formatter) 156 | logger.addHandler(sh) 157 | warnings_logger.addHandler(sh) 158 | handlers.append(sh) 159 | 160 | logger.info(f"Logging to: {filename}") 161 | 162 | # Remove existing handlers to add MMF specific handlers 163 | if clear_handlers: 164 | for handler in logging.root.handlers[:]: 165 | logging.root.removeHandler(handler) 166 | # Now, add our handlers. 167 | logging.basicConfig(level=logging_level, handlers=handlers) 168 | 169 | return logger 170 | 171 | 172 | def setup_very_basic_config(color=True): 173 | plain_formatter = logging.Formatter( 174 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s", 175 | datefmt="%Y-%m-%dT%H:%M:%S", 176 | ) 177 | ch = logging.StreamHandler(stream=sys.stdout) 178 | ch.setLevel(logging.INFO) 179 | if color: 180 | formatter = ColorfulFormatter( 181 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", 182 | datefmt="%Y-%m-%dT%H:%M:%S", 183 | ) 184 | else: 185 | formatter = plain_formatter 186 | ch.setFormatter(formatter) 187 | # Setup a minimal configuration for logging in case something tries to 188 | # log a message even before logging is setup by MMF. 189 | logging.basicConfig(level=logging.INFO, handlers=[ch]) 190 | 191 | 192 | # cache the opened file object, so that different calls to `setup_logger` 193 | # with the same file name can safely write to the same file. 
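# The lru_cache below keys on the filename, so repeated setup_logger calls reuse a single append-mode stream for the same file instead of opening it again.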
194 | @functools.lru_cache(maxsize=None) 195 | def _cached_log_stream(filename): 196 | return open(filename, "a") 197 | 198 | 199 | # ColorfulFormatter is adopted from Detectron2 and adapted for MMF 200 | class ColorfulFormatter(logging.Formatter): 201 | def __init__(self, *args, **kwargs): 202 | super().__init__(*args, **kwargs) 203 | 204 | def formatMessage(self, record): 205 | log = super().formatMessage(record) 206 | if record.levelno == logging.WARNING: 207 | prefix = colored("WARNING", "red", attrs=["blink"]) 208 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 209 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 210 | else: 211 | return log 212 | return prefix + " " + log 213 | 214 | 215 | class TensorboardLogger: 216 | def __init__(self, log_folder="./logs", iteration=0): 217 | # This would handle warning of missing tensorboard 218 | from torch.utils.tensorboard import SummaryWriter 219 | 220 | self.summary_writer = None 221 | self._is_master = is_main_process() 222 | # self.timer = Timer() 223 | self.log_folder = log_folder 224 | 225 | if self._is_master: 226 | # current_time = self.timer.get_time_hhmmss(None, format=self.time_format) 227 | current_time = time.strftime("%Y-%m-%dT%H:%M:%S") 228 | # self.timer.get_time_hhmmss(None, format=self.time_format) 229 | tensorboard_folder = os.path.join( 230 | self.log_folder, f"tensorboard_{current_time}" 231 | ) 232 | self.summary_writer = SummaryWriter(tensorboard_folder) 233 | 234 | def __del__(self): 235 | if getattr(self, "summary_writer", None) is not None: 236 | self.summary_writer.close() 237 | 238 | def _should_log_tensorboard(self): 239 | if self.summary_writer is None or not self._is_master: 240 | return False 241 | else: 242 | return True 243 | 244 | def add_scalar(self, key, value, iteration): 245 | if not self._should_log_tensorboard(): 246 | return 247 | 248 | self.summary_writer.add_scalar(key, value, iteration) 249 | 250 | def add_scalars(self, scalar_dict, iteration): 251 | if not self._should_log_tensorboard(): 252 | return 253 | 254 | for key, val in scalar_dict.items(): 255 | self.summary_writer.add_scalar(key, val, iteration) 256 | 257 | def add_histogram_for_model(self, model, iteration): 258 | if not self._should_log_tensorboard(): 259 | return 260 | 261 | for name, param in model.named_parameters(): 262 | np_param = param.clone().cpu().data.numpy() 263 | self.summary_writer.add_histogram(name, np_param, iteration) 264 | -------------------------------------------------------------------------------- /GUI_Vid/utils/basic_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import json 5 | import logging 6 | import random 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | from pathlib import Path 11 | from typing import List, Union 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from .distributed import is_dist_avail_and_initialized 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 
24 | """ 25 | 26 | def __init__(self, window=20, fmt=None): 27 | if fmt is None: 28 | fmt = "{median:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 42 | """ 43 | if not is_dist_avail_and_initialized(): 44 | return 45 | t = torch.tensor([self.count, self.total], 46 | dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | if meter.count == 0: # skip empty meter 108 | loss_str.append( 109 | "{}: {}".format(name, "No data") 110 | ) 111 | else: 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def global_avg(self): 118 | loss_str = [] 119 | for name, meter in self.meters.items(): 120 | if meter.count == 0: 121 | loss_str.append( 122 | "{}: {}".format(name, "No data") 123 | ) 124 | else: 125 | loss_str.append( 126 | "{}: {:.4f}".format(name, meter.global_avg) 127 | ) 128 | return self.delimiter.join(loss_str) 129 | 130 | def get_global_avg_dict(self, prefix=""): 131 | """include a separator (e.g., `/`, or "_") at the end of `prefix`""" 132 | d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. 
for k, m in self.meters.items()} 133 | return d 134 | 135 | def synchronize_between_processes(self): 136 | for meter in self.meters.values(): 137 | meter.synchronize_between_processes() 138 | 139 | def add_meter(self, name, meter): 140 | self.meters[name] = meter 141 | 142 | def log_every(self, iterable, log_freq, header=None): 143 | i = 0 144 | if not header: 145 | header = '' 146 | start_time = time.time() 147 | end = time.time() 148 | iter_time = SmoothedValue(fmt='{avg:.4f}') 149 | data_time = SmoothedValue(fmt='{avg:.4f}') 150 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 151 | log_msg = [ 152 | header, 153 | '[{0' + space_fmt + '}/{1}]', 154 | 'eta: {eta}', 155 | '{meters}', 156 | 'time: {time}', 157 | 'data: {data}' 158 | ] 159 | if torch.cuda.is_available(): 160 | log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}') 161 | log_msg = self.delimiter.join(log_msg) 162 | MB = 1024.0 * 1024.0 163 | for obj in iterable: 164 | data_time.update(time.time() - end) 165 | yield obj 166 | iter_time.update(time.time() - end) 167 | if i % log_freq == 0 or i == len(iterable) - 1: 168 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 169 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 170 | if torch.cuda.is_available(): 171 | logger.info(log_msg.format( 172 | i, len(iterable), eta=eta_string, 173 | meters=str(self), 174 | time=str(iter_time), data=str(data_time), 175 | memory=torch.cuda.max_memory_allocated() / MB, 176 | res_mem=torch.cuda.max_memory_reserved() / MB, 177 | )) 178 | else: 179 | logger.info(log_msg.format( 180 | i, len(iterable), eta=eta_string, 181 | meters=str(self), 182 | time=str(iter_time), data=str(data_time))) 183 | i += 1 184 | end = time.time() 185 | total_time = time.time() - start_time 186 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 187 | logger.info('{} Total time: {} ({:.4f} s / it)'.format( 188 | header, total_time_str, total_time / len(iterable))) 189 | 190 | 191 | class AttrDict(dict): 192 | def __init__(self, *args, **kwargs): 193 | super(AttrDict, self).__init__(*args, **kwargs) 194 | self.__dict__ = self 195 | 196 | 197 | def compute_acc(logits, label, reduction='mean'): 198 | ret = (torch.argmax(logits, dim=1) == label).float() 199 | if reduction == 'none': 200 | return ret.detach() 201 | elif reduction == 'mean': 202 | return ret.mean().item() 203 | 204 | 205 | def compute_n_params(model, return_str=True): 206 | tot = 0 207 | for p in model.parameters(): 208 | w = 1 209 | for x in p.shape: 210 | w *= x 211 | tot += w 212 | if return_str: 213 | if tot >= 1e6: 214 | return '{:.1f}M'.format(tot / 1e6) 215 | else: 216 | return '{:.1f}K'.format(tot / 1e3) 217 | else: 218 | return tot 219 | 220 | 221 | def setup_seed(seed): 222 | torch.manual_seed(seed) 223 | np.random.seed(seed) 224 | random.seed(seed) 225 | 226 | 227 | def remove_files_if_exist(file_paths): 228 | for fp in file_paths: 229 | if os.path.isfile(fp): 230 | os.remove(fp) 231 | 232 | 233 | def save_json(data, filename, save_pretty=False, sort_keys=False): 234 | with open(filename, "w") as f: 235 | if save_pretty: 236 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) 237 | else: 238 | json.dump(data, f) 239 | 240 | 241 | def load_json(filename): 242 | with open(filename, "r") as f: 243 | return json.load(f) 244 | 245 | 246 | def flat_list_of_lists(l): 247 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" 248 | return [item for sublist in l for item in sublist] 249 | 250 | 251 | def find_files_by_suffix_recursively(root: 
str, suffix: Union[str, List[str]]): 252 | """ 253 | Args: 254 | root: path to the directory to start search files 255 | suffix: any str as suffix, or can match multiple such strings 256 | when input is List[str]. 257 | Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`] 258 | Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`. 259 | """ 260 | if isinstance(suffix, str): 261 | suffix = [suffix, ] 262 | filepaths = flat_list_of_lists( 263 | [list(Path(root).rglob(f"*{e}")) for e in suffix]) 264 | return filepaths 265 | 266 | 267 | def match_key_and_shape(state_dict1, state_dict2): 268 | keys1 = set(state_dict1.keys()) 269 | keys2 = set(state_dict2.keys()) 270 | print(f"keys1 - keys2: {keys1 - keys2}") 271 | print(f"keys2 - keys1: {keys2 - keys1}") 272 | 273 | mismatch = 0 274 | for k in list(keys1): 275 | if state_dict1[k].shape != state_dict2[k].shape: 276 | print( 277 | f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}") 278 | mismatch += 1 279 | print(f"mismatch {mismatch}") 280 | 281 | 282 | def merge_dicts(list_dicts): 283 | merged_dict = list_dicts[0].copy() 284 | for i in range(1, len(list_dicts)): 285 | merged_dict.update(list_dicts[i]) 286 | return merged_dict 287 | -------------------------------------------------------------------------------- /GUI_Vid/models/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy import interpolate 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def _init_transformer_weights(module, initializer_range=0.02): 13 | """Initialize the weights. Copied from transformers ViT/Bert model init""" 14 | if isinstance(module, (nn.Linear, nn.Conv2d)): 15 | # Slightly different from the TF version which uses truncated_normal for initialization 16 | # cf https://github.com/pytorch/pytorch/pull/5617 17 | module.weight.data.normal_(mean=0.0, std=initializer_range) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | elif isinstance(module, nn.Embedding): 21 | module.weight.data.normal_(mean=0.0, std=initializer_range) 22 | if module.padding_idx is not None: 23 | module.weight.data[module.padding_idx].zero_() 24 | elif isinstance(module, nn.LayerNorm): 25 | module.bias.data.zero_() 26 | module.weight.data.fill_(1.0) 27 | 28 | 29 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 30 | """ 31 | Add/Remove extra temporal_embeddings as needed. 32 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 33 | 34 | temp_embed_old: (1, num_frames_old, 1, d) 35 | temp_embed_new: (1, num_frames_new, 1, d) 36 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 37 | """ 38 | # TODO zero pad 39 | num_frms_new = temp_embed_new.shape[1] 40 | num_frms_old = temp_embed_old.shape[1] 41 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 42 | if num_frms_new > num_frms_old: 43 | if add_zero: 44 | temp_embed_new[ 45 | :, :num_frms_old 46 | ] = temp_embed_old # untrained embeddings are zeros. 
47 | else: 48 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 49 | elif num_frms_new < num_frms_old: 50 | temp_embed_new = temp_embed_old[:, :num_frms_new] 51 | else: # = 52 | temp_embed_new = temp_embed_old 53 | return temp_embed_new 54 | 55 | 56 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 57 | """ 58 | Add/Remove extra temporal_embeddings as needed. 59 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 60 | 61 | temp_embed_old: (1, num_frames_old, 1, d) 62 | temp_embed_new: (1, num_frames_new, 1, d) 63 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 64 | """ 65 | # TODO zero pad 66 | num_frms_new = temp_embed_new.shape[1] 67 | num_frms_old = temp_embed_old.shape[1] 68 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 69 | if num_frms_new > num_frms_old: 70 | if add_zero: 71 | temp_embed_new[ 72 | :, :num_frms_old 73 | ] = temp_embed_old # untrained embeddings are zeros. 74 | else: 75 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 76 | elif num_frms_new < num_frms_old: 77 | temp_embed_new = temp_embed_old[:, :num_frms_new] 78 | else: # = 79 | temp_embed_new = temp_embed_old 80 | return temp_embed_new 81 | 82 | 83 | def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): 84 | """ 85 | temp_embed_old: (1, num_frames_old, 1, d) 86 | Returns: 87 | temp_embed_new: (1, num_frames_new, 1, d) 88 | """ 89 | temp_embed_old = temp_embed_old.squeeze(2).permute( 90 | 0, 2, 1 91 | ) # (1, d, num_frames_old) 92 | temp_embed_new = F.interpolate( 93 | temp_embed_old, num_frames_new, mode="linear" 94 | ) # (1, d, num_frames_new) 95 | temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( 96 | 2 97 | ) # (1, num_frames_new, 1, d) 98 | return temp_embed_new 99 | 100 | 101 | def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new): 102 | """ 103 | Args: 104 | pos_embed_old: (1, L_old, d), pre-trained 105 | pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights 106 | num_patches_new: 107 | """ 108 | # interpolate position embedding 109 | embedding_size = pos_embed_old.shape[-1] 110 | num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new 111 | # height (== width) for the checkpoint position embedding 112 | orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5) 113 | # height (== width) for the new position embedding 114 | new_size = int(num_patches_new ** 0.5) 115 | 116 | if orig_size != new_size: 117 | # class_token and dist_token are kept unchanged 118 | # the extra tokens seems always at the beginning of the position embedding 119 | extra_tokens = pos_embed_old[:, :num_extra_tokens] 120 | # only the position tokens are interpolated 121 | pos_tokens = pos_embed_old[:, num_extra_tokens:] 122 | pos_tokens = pos_tokens.reshape( 123 | -1, orig_size, orig_size, embedding_size 124 | ).permute(0, 3, 1, 2) 125 | pos_tokens = torch.nn.functional.interpolate( 126 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False 127 | ) 128 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 129 | interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 130 | logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2") 131 | return interpolated_pos_embed 132 | else: 133 | return pos_embed_old 134 | 135 | 136 | def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new): 137 | 
""" 138 | Args: 139 | state_dict_old: loaded state dict 140 | state_dict_new: state dict for model with new image size 141 | patch_shape_new: new model patch_shape 142 | ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py 143 | """ 144 | all_keys = list(state_dict_old.keys()) 145 | for key in all_keys: 146 | if "relative_position_index" in key: 147 | state_dict_old.pop(key) 148 | 149 | if "relative_position_bias_table" in key: 150 | rel_pos_bias = state_dict_old[key] 151 | src_num_pos, num_attn_heads = rel_pos_bias.size() 152 | dst_num_pos, _ = state_dict_new[key].size() 153 | dst_patch_shape = patch_shape_new 154 | if dst_patch_shape[0] != dst_patch_shape[1]: 155 | raise NotImplementedError() 156 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( 157 | dst_patch_shape[1] * 2 - 1 158 | ) 159 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 160 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 161 | if src_size != dst_size: 162 | # logger.info("Position interpolate for %s from %dx%d to %dx%d" % ( 163 | # key, src_size, src_size, dst_size, dst_size)) 164 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 165 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 166 | 167 | def geometric_progression(a, r, n): 168 | return a * (1.0 - r ** n) / (1.0 - r) 169 | 170 | left, right = 1.01, 1.5 171 | while right - left > 1e-6: 172 | q = (left + right) / 2.0 173 | gp = geometric_progression(1, q, src_size // 2) 174 | if gp > dst_size // 2: 175 | right = q 176 | else: 177 | left = q 178 | 179 | # if q > 1.090307: 180 | # q = 1.090307 181 | 182 | dis = [] 183 | cur = 1 184 | for i in range(src_size // 2): 185 | dis.append(cur) 186 | cur += q ** (i + 1) 187 | 188 | r_ids = [-_ for _ in reversed(dis)] 189 | 190 | x = r_ids + [0] + dis 191 | y = r_ids + [0] + dis 192 | 193 | t = dst_size // 2.0 194 | dx = np.arange(-t, t + 0.1, 1.0) 195 | dy = np.arange(-t, t + 0.1, 1.0) 196 | 197 | # logger.info("Original positions = %s" % str(x)) 198 | # logger.info("Target positions = %s" % str(dx)) 199 | 200 | all_rel_pos_bias = [] 201 | 202 | for i in range(num_attn_heads): 203 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 204 | f = interpolate.interp2d(x, y, z, kind="cubic") 205 | all_rel_pos_bias.append( 206 | torch.Tensor(f(dx, dy)) 207 | .contiguous() 208 | .view(-1, 1) 209 | .to(rel_pos_bias.device) 210 | ) 211 | 212 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 213 | 214 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 215 | state_dict_old[key] = new_rel_pos_bias 216 | return state_dict_old 217 | 218 | 219 | def tile(x, dim, n_tile): 220 | init_dim = x.size(dim) 221 | repeat_idx = [1] * x.dim() 222 | repeat_idx[dim] = n_tile 223 | x = x.repeat(*repeat_idx) 224 | order_index = torch.LongTensor( 225 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 226 | ) 227 | return torch.index_select(x, dim, order_index.to(x.device)) 228 | 229 | 230 | def mask_logits(target, mask): 231 | return target * mask + (1 - mask) * (-1e10) 232 | 233 | 234 | class AllGather(torch.autograd.Function): 235 | """An autograd function that performs allgather on a tensor.""" 236 | 237 | @staticmethod 238 | def forward(ctx, tensor, args): 239 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 240 | torch.distributed.all_gather(output, tensor) 241 | ctx.rank = args.rank 242 | ctx.batch_size = tensor.shape[0] 243 | return torch.cat(output, dim=0) 244 | 245 | @staticmethod 246 | def 
backward(ctx, grad_output): 247 | return ( 248 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 249 | None, 250 | ) 251 | 252 | 253 | allgather_wgrad = AllGather.apply 254 | -------------------------------------------------------------------------------- /GUI_Vid/models/blip2/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy import interpolate 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def _init_transformer_weights(module, initializer_range=0.02): 13 | """Initialize the weights. Copied from transformers ViT/Bert model init""" 14 | if isinstance(module, (nn.Linear, nn.Conv2d)): 15 | # Slightly different from the TF version which uses truncated_normal for initialization 16 | # cf https://github.com/pytorch/pytorch/pull/5617 17 | module.weight.data.normal_(mean=0.0, std=initializer_range) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | elif isinstance(module, nn.Embedding): 21 | module.weight.data.normal_(mean=0.0, std=initializer_range) 22 | if module.padding_idx is not None: 23 | module.weight.data[module.padding_idx].zero_() 24 | elif isinstance(module, nn.LayerNorm): 25 | module.bias.data.zero_() 26 | module.weight.data.fill_(1.0) 27 | 28 | 29 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 30 | """ 31 | Add/Remove extra temporal_embeddings as needed. 32 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 33 | 34 | temp_embed_old: (1, num_frames_old, 1, d) 35 | temp_embed_new: (1, num_frames_new, 1, d) 36 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 37 | """ 38 | # TODO zero pad 39 | num_frms_new = temp_embed_new.shape[1] 40 | num_frms_old = temp_embed_old.shape[1] 41 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 42 | if num_frms_new > num_frms_old: 43 | if add_zero: 44 | temp_embed_new[ 45 | :, :num_frms_old 46 | ] = temp_embed_old # untrained embeddings are zeros. 47 | else: 48 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 49 | elif num_frms_new < num_frms_old: 50 | temp_embed_new = temp_embed_old[:, :num_frms_new] 51 | else: # = 52 | temp_embed_new = temp_embed_old 53 | return temp_embed_new 54 | 55 | 56 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 57 | """ 58 | Add/Remove extra temporal_embeddings as needed. 59 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 60 | 61 | temp_embed_old: (1, num_frames_old, 1, d) 62 | temp_embed_new: (1, num_frames_new, 1, d) 63 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 64 | """ 65 | # TODO zero pad 66 | num_frms_new = temp_embed_new.shape[1] 67 | num_frms_old = temp_embed_old.shape[1] 68 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 69 | if num_frms_new > num_frms_old: 70 | if add_zero: 71 | temp_embed_new[ 72 | :, :num_frms_old 73 | ] = temp_embed_old # untrained embeddings are zeros. 
74 | else: 75 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 76 | elif num_frms_new < num_frms_old: 77 | temp_embed_new = temp_embed_old[:, :num_frms_new] 78 | else: # = 79 | temp_embed_new = temp_embed_old 80 | return temp_embed_new 81 | 82 | 83 | def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): 84 | """ 85 | temp_embed_old: (1, num_frames_old, 1, d) 86 | Returns: 87 | temp_embed_new: (1, num_frames_new, 1, d) 88 | """ 89 | temp_embed_old = temp_embed_old.squeeze(2).permute( 90 | 0, 2, 1 91 | ) # (1, d, num_frames_old) 92 | temp_embed_new = F.interpolate( 93 | temp_embed_old, num_frames_new, mode="linear" 94 | ) # (1, d, num_frames_new) 95 | temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( 96 | 2 97 | ) # (1, num_frames_new, 1, d) 98 | return temp_embed_new 99 | 100 | 101 | def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new): 102 | """ 103 | Args: 104 | pos_embed_old: (1, L_old, d), pre-trained 105 | pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights 106 | num_patches_new: 107 | """ 108 | # interpolate position embedding 109 | embedding_size = pos_embed_old.shape[-1] 110 | num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new 111 | # height (== width) for the checkpoint position embedding 112 | orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5) 113 | # height (== width) for the new position embedding 114 | new_size = int(num_patches_new ** 0.5) 115 | 116 | if orig_size != new_size: 117 | # class_token and dist_token are kept unchanged 118 | # the extra tokens seems always at the beginning of the position embedding 119 | extra_tokens = pos_embed_old[:, :num_extra_tokens] 120 | # only the position tokens are interpolated 121 | pos_tokens = pos_embed_old[:, num_extra_tokens:] 122 | pos_tokens = pos_tokens.reshape( 123 | -1, orig_size, orig_size, embedding_size 124 | ).permute(0, 3, 1, 2) 125 | pos_tokens = torch.nn.functional.interpolate( 126 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False 127 | ) 128 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 129 | interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 130 | logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2") 131 | return interpolated_pos_embed 132 | else: 133 | return pos_embed_old 134 | 135 | 136 | def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new): 137 | """ 138 | Args: 139 | state_dict_old: loaded state dict 140 | state_dict_new: state dict for model with new image size 141 | patch_shape_new: new model patch_shape 142 | ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py 143 | """ 144 | all_keys = list(state_dict_old.keys()) 145 | for key in all_keys: 146 | if "relative_position_index" in key: 147 | state_dict_old.pop(key) 148 | 149 | if "relative_position_bias_table" in key: 150 | rel_pos_bias = state_dict_old[key] 151 | src_num_pos, num_attn_heads = rel_pos_bias.size() 152 | dst_num_pos, _ = state_dict_new[key].size() 153 | dst_patch_shape = patch_shape_new 154 | if dst_patch_shape[0] != dst_patch_shape[1]: 155 | raise NotImplementedError() 156 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( 157 | dst_patch_shape[1] * 2 - 1 158 | ) 159 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 160 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 161 | if src_size != dst_size: 162 | # logger.info("Position 
interpolate for %s from %dx%d to %dx%d" % ( 163 | # key, src_size, src_size, dst_size, dst_size)) 164 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 165 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 166 | 167 | def geometric_progression(a, r, n): 168 | return a * (1.0 - r ** n) / (1.0 - r) 169 | 170 | left, right = 1.01, 1.5 171 | while right - left > 1e-6: 172 | q = (left + right) / 2.0 173 | gp = geometric_progression(1, q, src_size // 2) 174 | if gp > dst_size // 2: 175 | right = q 176 | else: 177 | left = q 178 | 179 | # if q > 1.090307: 180 | # q = 1.090307 181 | 182 | dis = [] 183 | cur = 1 184 | for i in range(src_size // 2): 185 | dis.append(cur) 186 | cur += q ** (i + 1) 187 | 188 | r_ids = [-_ for _ in reversed(dis)] 189 | 190 | x = r_ids + [0] + dis 191 | y = r_ids + [0] + dis 192 | 193 | t = dst_size // 2.0 194 | dx = np.arange(-t, t + 0.1, 1.0) 195 | dy = np.arange(-t, t + 0.1, 1.0) 196 | 197 | # logger.info("Original positions = %s" % str(x)) 198 | # logger.info("Target positions = %s" % str(dx)) 199 | 200 | all_rel_pos_bias = [] 201 | 202 | for i in range(num_attn_heads): 203 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 204 | f = interpolate.interp2d(x, y, z, kind="cubic") 205 | all_rel_pos_bias.append( 206 | torch.Tensor(f(dx, dy)) 207 | .contiguous() 208 | .view(-1, 1) 209 | .to(rel_pos_bias.device) 210 | ) 211 | 212 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 213 | 214 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 215 | state_dict_old[key] = new_rel_pos_bias 216 | return state_dict_old 217 | 218 | 219 | def tile(x, dim, n_tile): 220 | init_dim = x.size(dim) 221 | repeat_idx = [1] * x.dim() 222 | repeat_idx[dim] = n_tile 223 | x = x.repeat(*repeat_idx) 224 | order_index = torch.LongTensor( 225 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 226 | ) 227 | return torch.index_select(x, dim, order_index.to(x.device)) 228 | 229 | 230 | def mask_logits(target, mask): 231 | return target * mask + (1 - mask) * (-1e10) 232 | 233 | 234 | class AllGather(torch.autograd.Function): 235 | """An autograd function that performs allgather on a tensor.""" 236 | 237 | @staticmethod 238 | def forward(ctx, tensor, args): 239 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 240 | torch.distributed.all_gather(output, tensor) 241 | ctx.rank = args.rank 242 | ctx.batch_size = tensor.shape[0] 243 | return torch.cat(output, dim=0) 244 | 245 | @staticmethod 246 | def backward(ctx, grad_output): 247 | return ( 248 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 249 | None, 250 | ) 251 | 252 | 253 | allgather_wgrad = AllGather.apply 254 | -------------------------------------------------------------------------------- /GUI_Vid/demo_local.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['http_proxy'] = 'http://127.0.0.1:7890' 3 | os.environ['https_proxy'] = 'http://127.0.0.1:7890' 4 | os.environ['no_proxy'] = '127.0.0.1,localhost' 5 | os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890' 6 | os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890' 7 | os.environ['NO_PROXY'] = '127.0.0.1,localhost' 8 | 9 | from utils.config import Config 10 | config_file = "configs/config.json" 11 | cfg = Config.from_file(config_file) 12 | 13 | import io 14 | from argparse import ArgumentParser 15 | from models.videochat2_it import VideoChat2_it 16 | from utils.easydict import EasyDict 17 | import torch 18 | 19 | from 
transformers import StoppingCriteria, StoppingCriteriaList 20 | 21 | from PIL import Image 22 | import numpy as np 23 | import numpy as np 24 | from decord import VideoReader, cpu 25 | import torchvision.transforms as T 26 | from dataset.video_transforms import ( 27 | GroupNormalize, GroupScale, GroupCenterCrop, 28 | Stack, ToTorchFormatTensor 29 | ) 30 | from torchvision.transforms.functional import InterpolationMode 31 | 32 | from torchvision import transforms 33 | 34 | import matplotlib.pyplot as plt 35 | 36 | from IPython.display import Video, HTML 37 | 38 | from peft import get_peft_model, LoraConfig, TaskType 39 | import copy 40 | 41 | 42 | parser = ArgumentParser() 43 | parser.add_argument("--ckpt_path", type=str, default=None) 44 | parser.add_argument("--qs", type=str, default=None) 45 | parser.add_argument("--video_path", type=str, default=None) 46 | parser.add_argument("--keyframes", type=int, default=None) 47 | args = parser.parse_args() 48 | # load stage2 model 49 | cfg.model.vision_encoder.num_frames = args.keyframes 50 | model = VideoChat2_it(config=cfg.model) 51 | 52 | # add lora to run stage3 model 53 | peft_config = LoraConfig( 54 | task_type=TaskType.CAUSAL_LM, inference_mode=False, 55 | r=16, lora_alpha=32, lora_dropout=0. 56 | ) 57 | 58 | model.llama_model = get_peft_model(model.llama_model, peft_config) 59 | state_dict = torch.load(args.ckpt_path, "cuda") 60 | 61 | if 'model' in state_dict.keys(): 62 | msg = model.load_state_dict(state_dict['model'], strict=False) 63 | else: 64 | msg = model.load_state_dict(state_dict, strict=False) 65 | print(msg) 66 | 67 | model = model.eval() 68 | 69 | def get_prompt(conv): 70 | ret = conv.system + conv.sep 71 | for role, message in conv.messages: 72 | if message: 73 | ret += role + ": " + message + conv.sep 74 | else: 75 | ret += role + ":" 76 | return ret 77 | 78 | 79 | def get_prompt2(conv): 80 | ret = conv.system + conv.sep 81 | count = 0 82 | for role, message in conv.messages: 83 | count += 1 84 | if count == len(conv.messages): 85 | ret += role + ": " + message 86 | else: 87 | if message: 88 | ret += role + ": " + message + conv.sep 89 | else: 90 | ret += role + ":" 91 | return ret 92 | 93 | 94 | def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False): 95 | if answer_prompt: 96 | prompt = get_prompt2(conv) 97 | else: 98 | prompt = get_prompt(conv) 99 | if print_res: 100 | print(prompt) 101 | if '' in prompt: 102 | prompt_segs = prompt.split('') 103 | else: 104 | prompt_segs = prompt.split('') 105 | assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 
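# Tokenize each text segment separately (BOS only on the first segment), embed the tokens with the frozen LLaMA embedding table, then interleave the segment embeddings with the visual embeddings in img_list.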
106 | with torch.no_grad(): 107 | seg_tokens = [ 108 | model.llama_tokenizer( 109 | seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids 110 | # only add bos to the first seg 111 | for i, seg in enumerate(prompt_segs) 112 | ] 113 | seg_embs = [model.llama_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens] 114 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] 115 | mixed_embs = torch.cat(mixed_embs, dim=1) 116 | return mixed_embs 117 | 118 | 119 | def ask(text, conv): 120 | conv.messages.append([conv.roles[0], text + '\n']) 121 | 122 | 123 | class StoppingCriteriaSub(StoppingCriteria): 124 | def __init__(self, stops=[], encounters=1): 125 | super().__init__() 126 | self.stops = stops 127 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 128 | for stop in self.stops: 129 | if torch.all((stop == input_ids[0][-len(stop):])).item(): 130 | return True 131 | return False 132 | 133 | 134 | def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9, 135 | repetition_penalty=1.2, length_penalty=1.2, temperature=1.0, answer_prompt=None, print_res=False): 136 | stop_words_ids = [ 137 | torch.tensor([835]).to("cuda:0"), 138 | torch.tensor([2277, 29937]).to("cuda:0")] # '###' can be encoded in two different ways. 139 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 140 | 141 | conv.messages.append([conv.roles[1], answer_prompt]) 142 | embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res) 143 | with torch.no_grad(): 144 | outputs = model.llama_model.generate( 145 | inputs_embeds=embs, 146 | max_new_tokens=max_new_tokens, 147 | stopping_criteria=stopping_criteria, 148 | num_beams=num_beams, 149 | do_sample=do_sample, 150 | min_length=min_length, 151 | top_p=top_p, 152 | repetition_penalty=repetition_penalty, 153 | length_penalty=length_penalty, 154 | temperature=temperature, 155 | ) 156 | output_token = outputs[0] 157 | if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it 158 | output_token = output_token[1:] 159 | if output_token[0] == 1: # some users find that there is a start token at the beginning. 
remove it 160 | output_token = output_token[1:] 161 | output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False) 162 | output_text = output_text.split('###')[0] # remove the stop sign '###' 163 | output_text = output_text.split('Assistant:')[-1].strip() 164 | conv.messages[-1][1] = output_text 165 | return output_text, output_token.cpu().numpy() 166 | 167 | def get_index(num_frames, num_segments): 168 | seg_size = float(num_frames - 1) / num_segments 169 | start = int(seg_size / 2) 170 | offsets = np.array([ 171 | start + int(np.round(seg_size * idx)) for idx in range(num_segments) 172 | ]) 173 | return offsets 174 | 175 | 176 | def load_video(video_path, num_segments=8, return_msg=False, resolution=224): 177 | vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) 178 | num_frames = len(vr) 179 | frame_indices = get_index(num_frames, num_segments) 180 | 181 | # transform 182 | crop_size = resolution 183 | scale_size = resolution 184 | input_mean = [0.48145466, 0.4578275, 0.40821073] 185 | input_std = [0.26862954, 0.26130258, 0.27577711] 186 | 187 | transform = T.Compose([ 188 | GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC), 189 | GroupCenterCrop(crop_size), 190 | Stack(), 191 | ToTorchFormatTensor(), 192 | GroupNormalize(input_mean, input_std) 193 | ]) 194 | 195 | images_group = list() 196 | for frame_index in frame_indices: 197 | img = Image.fromarray(vr[frame_index].numpy()) 198 | images_group.append(img) 199 | torch_imgs = transform(images_group) 200 | if return_msg: 201 | fps = float(vr.get_avg_fps()) 202 | sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) 203 | # " " should be added in the start and end 204 | msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." 
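# msg records how many frames were sampled and their timestamps in seconds; the demo prints it after the video is loaded.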
205 | return torch_imgs, msg 206 | else: 207 | return torch_imgs 208 | 209 | def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784): 210 | ''' Sinusoid position encoding table ''' 211 | # TODO: make it with torch instead of numpy 212 | def get_position_angle_vec(position): 213 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 214 | 215 | # generate checkpoint position embedding 216 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)]) 217 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 218 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 219 | sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0) 220 | 221 | print(f"n_position: {n_position}") 222 | print(f"pre_n_position: {pre_n_position}") 223 | 224 | if n_position != pre_n_position: 225 | T = ckpt_num_frame # checkpoint frame 226 | P = 14 # checkpoint size 227 | C = d_hid 228 | new_P = int((n_position // cur_frame) ** 0.5) # testing size 229 | if new_P != 14: 230 | print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}') 231 | print(f'Interpolate the position embedding') 232 | sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C) 233 | sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2) 234 | sinusoid_table = torch.nn.functional.interpolate( 235 | sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False) 236 | # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C 237 | sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C) 238 | sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C 239 | 240 | if cur_frame != ckpt_num_frame: 241 | print(f'Pretraining uses 4 frames, but current frame is {cur_frame}') 242 | print(f'Interpolate the position embedding') 243 | T = ckpt_num_frame # checkpoint frame 244 | new_T = cur_frame # testing frame 245 | # interpolate 246 | P = int((n_position // cur_frame) ** 0.5) # testing size 247 | C = d_hid 248 | sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C) 249 | sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T 250 | sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear') 251 | sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C 252 | sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C 253 | 254 | return sinusoid_table 255 | 256 | # num_frame = 8 257 | num_frame = args.keyframes 258 | # resolution = 384 259 | resolution = 224 260 | try: 261 | vid, msg = load_video(args.video_path, num_segments=num_frame, return_msg=True, resolution=resolution) 262 | new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame) 263 | model.vision_encoder.encoder.pos_embed = new_pos_emb 264 | 265 | print(msg) 266 | except Exception as e: 267 | pass 268 | 269 | # The model expects inputs of shape: T x C x H x W 270 | TC, H, W = vid.shape 271 | video = vid.reshape(1, TC//3, 3, H, W).to("cuda:0") 272 | 273 | model.cuda() 274 | video.cuda() 275 | img_list = [] 276 | with torch.no_grad(): 277 | image_emb, _ = model.encode_img(video, "Watch the video and follow the user's instruction.") 278 | # image_emb, _ = model.encode_img(video, "") 279 | 280 | img_list.append(image_emb) 281 | 282 | chat = EasyDict({ 283 | "system": "", 284 | "roles": ("Human", "Assistant"), 285 
| "messages": [], 286 | "sep": "###" 287 | }) 288 | 289 | chat.messages.append([chat.roles[0], f"\n"]) 290 | ask(args.qs, chat) 291 | 292 | llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=1024, print_res=True)[0] 293 | print(llm_message) 294 | -------------------------------------------------------------------------------- /GUI_Vid/models/videochat2_pt.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | 4 | import torch 5 | from torch.cuda.amp import autocast as autocast 6 | import torch.nn as nn 7 | 8 | from .blip2.blip2 import Blip2Base, disabled_train 9 | from transformers import LlamaTokenizer, LlamaConfig 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class VideoChat2_pt(Blip2Base): 15 | """ 16 | VideoChat2 model. 17 | """ 18 | def __init__(self, config): 19 | super().__init__() 20 | # pretrained_path 21 | vit_blip_model_path = config.get("vit_blip_model_path", None) 22 | llama_model_path = config.get("llama_model_path") 23 | freeze_vit = config.get("freeze_vit", True) 24 | freeze_qformer = config.get("freeze_qformer", True) 25 | # vit 26 | low_resource = config.get("low_resource", False) # use 8 bit and put vit in cpu 27 | # qformer 28 | num_query_token = config.get("num_query_token") 29 | qformer_hidden_dropout_prob = config.get("qformer_hidden_dropout_prob", 0.1) 30 | qformer_attention_probs_dropout_prob = config.get("qformer_attention_probs_dropout_prob", 0.1) 31 | qformer_drop_path_rate = config.get("qformer_drop_path_rate", 0.1) 32 | extra_num_query_token = config.get("extra_num_query_token", 32) 33 | # prompt 34 | prompt_path = config.get("prompt_path", "") 35 | img_prompt_path = config.get("img_prompt_path", "") 36 | prompt_template = config.get("prompt_template", "") 37 | max_txt_len = config.get("max_txt_len", 32) 38 | end_sym = config.get("end_sym", '\n') 39 | # debug 40 | debug = config.get("debug", False) 41 | use_flash_attention = config.get("use_flash_attention", False) 42 | 43 | self.tokenizer = self.init_tokenizer(truncation_side="left") 44 | self.low_resource = low_resource 45 | self.vision_encoder, self.vision_layernorm, = self.init_vision_encoder_umt(config) 46 | self.qformer, self.query_tokens = self.init_Qformer( 47 | num_query_token, config.vision_encoder.encoder_embed_dim, 48 | qformer_hidden_dropout_prob=qformer_hidden_dropout_prob, 49 | qformer_attention_probs_dropout_prob=qformer_attention_probs_dropout_prob, 50 | qformer_drop_path_rate=qformer_drop_path_rate, 51 | ) 52 | self.qformer.bert.embeddings.word_embeddings = None 53 | self.qformer.bert.embeddings.position_embeddings = None 54 | for layer in self.qformer.bert.encoder.layer: 55 | layer.output = None 56 | layer.intermediate = None 57 | self.qformer.cls = None 58 | 59 | if vit_blip_model_path: 60 | logger.info(f"Load ViT and QFormer from {vit_blip_model_path}") 61 | state_dict = torch.load(vit_blip_model_path, map_location="cpu") 62 | msg = self.load_state_dict(state_dict, strict=False) 63 | logger.info(msg) 64 | logger.info('Loading ViT and Q-Former Done') 65 | 66 | self.extra_num_query_token = extra_num_query_token 67 | if extra_num_query_token > 0: 68 | logger.info(f"Add extra {extra_num_query_token} tokens in QFormer") 69 | self.extra_query_tokens = nn.Parameter( 70 | torch.zeros(1, extra_num_query_token, self.query_tokens.shape[-1]) 71 | ) 72 | 73 | if freeze_vit: 74 | logger.info("freeze vision encoder") 75 | for _, param in 
self.vision_encoder.named_parameters(): 76 | param.requires_grad = False 77 | self.vision_encoder = self.vision_encoder.eval() 78 | self.vision_encoder.train = disabled_train 79 | for _, param in self.vision_layernorm.named_parameters(): 80 | param.requires_grad = False 81 | self.vision_layernorm = self.vision_layernorm.eval() 82 | self.vision_layernorm.train = disabled_train 83 | 84 | if freeze_qformer: 85 | logger.info("freeze qformer") 86 | for _, param in self.qformer.named_parameters(): 87 | param.requires_grad = False 88 | self.qformer = self.qformer.eval() 89 | self.qformer.train = disabled_train 90 | self.query_tokens.requires_grad = False 91 | 92 | logger.info('Loading LLAMA') 93 | # problem: do we need to set truncation_side="left"? 94 | self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) 95 | if not self.llama_tokenizer.pad_token: 96 | logger.info("Set pad_token") 97 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token 98 | 99 | if use_flash_attention: 100 | logger.info("Use flash attention") 101 | from .blip2.modeling_llama_mem import LlamaForCausalLM 102 | else: 103 | from .blip2.modeling_llama import LlamaForCausalLM 104 | if debug: 105 | logger.info("Debug mode, build small LLAMA") 106 | llama_config = LlamaConfig.from_pretrained(llama_model_path) 107 | llama_config.hidden_size = 512 108 | llama_config.intermediate_size = 2048 109 | llama_config.num_attention_heads = 8 110 | llama_config.num_hidden_layers = 12 111 | llama_config.torch_dtype = torch.float16 112 | self.llama_model = LlamaForCausalLM(llama_config) 113 | else: 114 | if self.low_resource: 115 | self.llama_model = LlamaForCausalLM.from_pretrained( 116 | llama_model_path, 117 | torch_dtype=torch.float16, 118 | load_in_8bit=True, 119 | device_map="auto" 120 | ) 121 | else: 122 | self.llama_model = LlamaForCausalLM.from_pretrained( 123 | llama_model_path, 124 | torch_dtype=torch.float16, 125 | ) 126 | 127 | logger.info("freeze LLAMA") 128 | for _, param in self.llama_model.named_parameters(): 129 | param.requires_grad = False 130 | logger.info('Loading LLAMA Done') 131 | 132 | self.llama_proj = nn.Linear( 133 | self.qformer.config.hidden_size, self.llama_model.config.hidden_size 134 | ) 135 | self.max_txt_len = max_txt_len 136 | self.end_sym = end_sym 137 | 138 | if prompt_path: 139 | self.prompt_list = self.process_prompt(prompt_path, prompt_template) 140 | else: 141 | self.prompt_list = [] 142 | if img_prompt_path: 143 | self.img_prompt_list = self.process_prompt(img_prompt_path, prompt_template) 144 | else: 145 | self.img_prompt_list = [] 146 | 147 | def process_prompt(self, prompt_path, prompt_template): 148 | with open(prompt_path, 'r') as f: 149 | raw_prompts = f.read().splitlines() 150 | filted_prompts = [raw_prompt for raw_prompt in raw_prompts] 151 | prompt_list = [prompt_template.format(p) for p in filted_prompts] 152 | logger.info(f'Load {len(prompt_list)} training prompts') 153 | logger.info(f'Prompt: {prompt_list}') 154 | return prompt_list 155 | 156 | def vit_to_cpu(self): 157 | self.vision_layernorm.to("cpu") 158 | self.vision_layernorm.float() 159 | self.vision_encoder.to("cpu") 160 | self.vision_encoder.float() 161 | 162 | def encode_img(self, image): 163 | device = image.device 164 | if self.low_resource: 165 | self.vit_to_cpu() 166 | image = image.to("cpu") 167 | 168 | with self.maybe_autocast(): 169 | T = image.shape[1] 170 | use_image = True if T == 1 else False 171 | image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W] 172 | 173 | 
image_embeds = self.vision_encoder(image, use_image) 174 | B, T, L, C = image_embeds.shape 175 | image_embeds = image_embeds.reshape(B, -1, C) 176 | image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C] 177 | 178 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) 179 | 180 | if self.extra_num_query_token > 0: 181 | query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1) 182 | else: 183 | query_tokens = self.query_tokens 184 | query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1) 185 | 186 | query_output = self.qformer.bert( 187 | query_embeds=query_tokens, 188 | encoder_hidden_states=image_embeds, 189 | encoder_attention_mask=image_atts, 190 | return_dict=True, 191 | ) 192 | 193 | inputs_llama = self.llama_proj(query_output.last_hidden_state) 194 | atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) 195 | return inputs_llama, atts_llama 196 | 197 | def prompt_wrap(self, img_embeds, atts_img, prompt, use_image=False): 198 | if prompt: 199 | batch_size = img_embeds.shape[0] 200 | if use_image: 201 | p_before, p_after = prompt.split('') 202 | else: 203 | p_before, p_after = prompt.split('') 204 | p_before_tokens = self.llama_tokenizer( 205 | p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) 206 | p_after_tokens = self.llama_tokenizer( 207 | p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) 208 | p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) 209 | p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) 210 | wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1) 211 | wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1]) 212 | return wrapped_img_embeds, wrapped_atts_img 213 | else: 214 | return img_embeds, atts_img 215 | 216 | def forward(self, image, text_input): 217 | T = image.shape[1] 218 | use_image = True if T == 1 else False 219 | if self.prompt_list: 220 | if use_image: 221 | prompt = random.choice(self.img_prompt_list) 222 | else: 223 | prompt = random.choice(self.prompt_list) 224 | 225 | img_embeds, atts_img = self.encode_img(image) 226 | 227 | if self.prompt_list: 228 | img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt, use_image) 229 | 230 | self.llama_tokenizer.padding_side = "right" 231 | text = [t + self.end_sym for t in text_input] 232 | 233 | to_regress_tokens = self.llama_tokenizer( 234 | text, 235 | return_tensors="pt", 236 | padding="longest", 237 | truncation=True, 238 | max_length=self.max_txt_len, 239 | add_special_tokens=False 240 | ).to(image.device) 241 | 242 | targets = to_regress_tokens.input_ids.masked_fill( 243 | to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 244 | ) 245 | 246 | empty_targets = ( 247 | torch.ones([atts_img.shape[0], atts_img.shape[1]+1], 248 | dtype=torch.long).to(image.device).fill_(-100) # plus one for bos 249 | ) 250 | targets = torch.cat([empty_targets, targets], dim=1) 251 | 252 | batch_size = img_embeds.shape[0] 253 | bos = torch.ones([batch_size, 1], 254 | dtype=to_regress_tokens.input_ids.dtype, 255 | device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id 256 | bos_embeds = self.llama_model.model.embed_tokens(bos) 257 | atts_bos = atts_img[:, :1] 258 | 259 | to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) 
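# Final LLaMA input layout: [BOS][visual prompt embeddings][target text embeddings]; the targets for the BOS and visual positions were filled with -100 above, so the loss only covers the text tokens.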
260 | inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) 261 | attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) 262 | 263 | with self.maybe_autocast(): 264 | outputs = self.llama_model( 265 | inputs_embeds=inputs_embeds, 266 | attention_mask=attention_mask, 267 | return_dict=True, 268 | labels=targets, 269 | ) 270 | 271 | return dict( 272 | loss=outputs.loss, 273 | ) 274 | -------------------------------------------------------------------------------- /GUI_Vid/models/videochat2_it.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | 4 | import torch 5 | from torch.cuda.amp import autocast as autocast 6 | import torch.nn as nn 7 | from peft import get_peft_model, LoraConfig, TaskType 8 | 9 | from .blip2.blip2 import Blip2Base, disabled_train 10 | from transformers import LlamaTokenizer, LlamaConfig 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class VideoChat2_it(Blip2Base): 16 | """ 17 | VideoChat2 model. 18 | """ 19 | def __init__(self, config): 20 | super().__init__() 21 | # pretrained_path 22 | vit_blip_model_path = config.get("vit_blip_model_path", None) 23 | llama_model_path = config.get("llama_model_path") 24 | videochat2_model_path = config.get("videochat2_model_path", "") 25 | freeze_vit = config.get("freeze_vit", True) 26 | freeze_qformer = config.get("freeze_qformer", True) 27 | # vit 28 | low_resource = config.get("low_resource", False) # use 8 bit and put vit in cpu 29 | # qformer 30 | num_query_token = config.get("num_query_token") 31 | qformer_hidden_dropout_prob = config.get("qformer_hidden_dropout_prob", 0.1) 32 | qformer_attention_probs_dropout_prob = config.get("qformer_attention_probs_dropout_prob", 0.1) 33 | qformer_drop_path_rate = config.get("qformer_drop_path_rate", 0.1) 34 | extra_num_query_token = config.get("extra_num_query_token", 32) 35 | self.qformer_text_input = config.get("qformer_text_input", True) 36 | # prompt 37 | max_txt_len = config.get("max_txt_len", 32) 38 | self.begin_signal = "###" 39 | self.role = ("Human", "Assistant") 40 | self.start_token = config.get("start_token", "") 42 | self.img_start_token = config.get("img_start_token", "") 43 | self.img_end_token = config.get("img_end_token", "") 44 | logger.info(f"Add instruction in qformer: {self.qformer_text_input}") 45 | # debug 46 | debug = config.get("debug", False) 47 | use_flash_attention = config.get("use_flash_attention", False) 48 | self.use_lora = config.get("use_lora", False) 49 | lora_r = config.get("lora_r", 8) 50 | lora_alpha = config.get("lora_alpha", 32) 51 | lora_dropout = config.get("lora_dropout", 0.05) 52 | 53 | self.tokenizer = self.init_tokenizer(truncation_side="left") 54 | self.low_resource = low_resource 55 | self.vision_encoder, self.vision_layernorm, = self.init_vision_encoder_umt(config) 56 | self.qformer, self.query_tokens = self.init_Qformer( 57 | num_query_token, config.vision_encoder.encoder_embed_dim, 58 | qformer_hidden_dropout_prob=qformer_hidden_dropout_prob, 59 | qformer_attention_probs_dropout_prob=qformer_attention_probs_dropout_prob, 60 | qformer_drop_path_rate=qformer_drop_path_rate, 61 | ) 62 | 63 | if not self.qformer_text_input: 64 | self.qformer.bert.embeddings.word_embeddings = None 65 | self.qformer.bert.embeddings.position_embeddings = None 66 | for layer in self.qformer.bert.encoder.layer: 67 | layer.output = None 68 | layer.intermediate = None 69 | else: 70 | 
self.qformer.resize_token_embeddings(len(self.tokenizer)) 71 | self.qformer.cls = None 72 | 73 | if vit_blip_model_path: 74 | logger.info(f"Load ViT and QFormer from {vit_blip_model_path}") 75 | state_dict = torch.load(vit_blip_model_path, map_location="cpu") 76 | msg = self.load_state_dict(state_dict, strict=False) 77 | logger.info(msg) 78 | logger.info('Loading ViT and Q-Former Done') 79 | 80 | self.extra_num_query_token = extra_num_query_token 81 | if extra_num_query_token > 0: 82 | logger.info(f"Add extra {extra_num_query_token} tokens in QFormer") 83 | self.extra_query_tokens = nn.Parameter( 84 | torch.zeros(1, extra_num_query_token, self.query_tokens.shape[-1]) 85 | ) 86 | 87 | if freeze_vit: 88 | logger.info("freeze vision encoder") 89 | for _, param in self.vision_encoder.named_parameters(): 90 | param.requires_grad = False 91 | self.vision_encoder = self.vision_encoder.eval() 92 | self.vision_encoder.train = disabled_train 93 | for _, param in self.vision_layernorm.named_parameters(): 94 | param.requires_grad = False 95 | self.vision_layernorm = self.vision_layernorm.eval() 96 | self.vision_layernorm.train = disabled_train 97 | 98 | if freeze_qformer: 99 | logger.info("freeze Qformer") 100 | for _, param in self.qformer.named_parameters(): 101 | param.requires_grad = False 102 | self.qformer = self.qformer.eval() 103 | self.qformer.train = disabled_train 104 | self.query_tokens.requires_grad = False 105 | 106 | logger.info('Loading LLAMA') 107 | # problem: do we need to set truncation_side="left"? 108 | self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) 109 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token 110 | 111 | if use_flash_attention: 112 | logger.info("Use flash attention") 113 | from .blip2.modeling_llama_mem import LlamaForCausalLM 114 | else: 115 | from .blip2.modeling_llama import LlamaForCausalLM 116 | if debug: 117 | logger.info("Debug mode, build small LLAMA") 118 | llama_config = LlamaConfig.from_pretrained(llama_model_path) 119 | llama_config.hidden_size = 512 120 | llama_config.intermediate_size = 2048 121 | llama_config.num_attention_heads = 8 122 | llama_config.num_hidden_layers = 12 123 | llama_config.torch_dtype = torch.float16 124 | self.llama_model = LlamaForCausalLM(llama_config) 125 | else: 126 | if self.low_resource: 127 | self.llama_model = LlamaForCausalLM.from_pretrained( 128 | llama_model_path, 129 | torch_dtype=torch.float16, 130 | load_in_8bit=True, 131 | device_map="auto", 132 | ) 133 | else: 134 | self.llama_model = LlamaForCausalLM.from_pretrained( 135 | llama_model_path, 136 | torch_dtype=torch.float16, 137 | ) 138 | 139 | logger.info("freeze LLAMA") 140 | for name, param in self.llama_model.named_parameters(): 141 | param.requires_grad = False 142 | logger.info('Loading LLAMA Done') 143 | 144 | if self.use_lora: 145 | logger.info("Use lora") 146 | peft_config = LoraConfig( 147 | task_type=TaskType.CAUSAL_LM, inference_mode=False, 148 | r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout 149 | ) 150 | self.llama_model = get_peft_model(self.llama_model, peft_config) 151 | self.llama_model.print_trainable_parameters() 152 | 153 | self.llama_proj = nn.Linear( 154 | self.qformer.config.hidden_size, self.llama_model.config.hidden_size 155 | ) 156 | self.max_txt_len = max_txt_len 157 | 158 | # load weights of VideoChat2 159 | if videochat2_model_path: 160 | logger.info(f"Load VideoChat2 from: {videochat2_model_path}") 161 | ckpt = torch.load(videochat2_model_path, map_location="cpu") 162 | 
if 'model' in ckpt.keys(): 163 | msg = self.load_state_dict(ckpt['model'], strict=False) 164 | else: 165 | msg = self.load_state_dict(ckpt, strict=False) 166 | logger.info(msg) 167 | 168 | def vit_to_cpu(self): 169 | self.vision_layernorm.to("cpu") 170 | self.vision_layernorm.float() 171 | self.vision_encoder.to("cpu") 172 | self.vision_encoder.float() 173 | 174 | def encode_img(self, image, instruction): 175 | device = image.device 176 | if self.low_resource: 177 | self.vit_to_cpu() 178 | image = image.to("cpu") 179 | 180 | with self.maybe_autocast(): 181 | T = image.shape[1] 182 | use_image = True if T == 1 else False 183 | image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W] 184 | 185 | image_embeds = self.vision_encoder(image, use_image).to("cuda") 186 | B, T, L, C = image_embeds.shape 187 | image_embeds = image_embeds.reshape(B, -1, C) 188 | image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C] 189 | 190 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) 191 | 192 | if self.extra_num_query_token > 0: 193 | query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1) 194 | else: 195 | query_tokens = self.query_tokens 196 | query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1) 197 | if self.qformer_text_input: 198 | text_Qformer = self.tokenizer( 199 | instruction, 200 | padding='longest', 201 | truncation=True, 202 | max_length=self.max_txt_len, 203 | return_tensors="pt", 204 | ).to(image_embeds.device) 205 | query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device) 206 | Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1) 207 | 208 | query_output = self.qformer.bert( 209 | text_Qformer.input_ids, 210 | attention_mask=Qformer_atts, 211 | query_embeds=query_tokens, 212 | encoder_hidden_states=image_embeds, 213 | encoder_attention_mask=image_atts, 214 | return_dict=True, 215 | ) 216 | else: 217 | query_output = self.qformer.bert( 218 | query_embeds=query_tokens, 219 | encoder_hidden_states=image_embeds, 220 | encoder_attention_mask=image_atts, 221 | return_dict=True, 222 | ) 223 | 224 | inputs_llama = self.llama_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :]) 225 | return inputs_llama, use_image 226 | 227 | def _get_text_len(self, text): 228 | return self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.shape[1] 229 | 230 | def forward(self, image, text_input, instruction): 231 | img_embeds, use_image = self.encode_img(image, instruction) 232 | batch_size, img_len, _ = img_embeds.shape 233 | 234 | # mark the largest length 235 | # when padding, the attention mask will be 0 236 | max_len = 0 237 | input_embed_list = [] 238 | p_before_len_list = [] 239 | target_list = [] 240 | # handle each prompt individually 241 | for idx, prompt in enumerate(text_input): 242 | tmp_img_embeds = img_embeds[idx].unsqueeze(0) 243 | # split the prompt via END_TOKEN 244 | end_token = self.img_end_token if use_image else self.end_token 245 | p_before, p_after = prompt.split(end_token) 246 | p_after = end_token + p_after 247 | p_before_tokens = self.llama_tokenizer(p_before, return_tensors="pt", add_special_tokens=False).to(tmp_img_embeds.device) 248 | p_after_tokens = self.llama_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(tmp_img_embeds.device) 249 | if self.use_lora: 250 | p_before_embeds = self.llama_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids) 251 | p_after_embeds 
= self.llama_model.base_model.model.model.embed_tokens(p_after_tokens.input_ids) 252 | else: 253 | p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids) 254 | p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids) 255 | input_embeds = torch.cat([p_before_embeds, tmp_img_embeds, p_after_embeds], dim=1) 256 | 257 | # extract the answers and mask the target 258 | # the answers are only in the p_after 259 | sep1 = self.begin_signal + self.role[0] + ": " 260 | sep2 = self.begin_signal + self.role[1] + ": " 261 | raw_text = p_after.split(sep2) 262 | for idx in range(1, len(raw_text)): 263 | raw_text[idx] = sep2 + raw_text[idx] 264 | # the first raw_text contains system and question 265 | # the last raw_text only contains answer 266 | # rstrip() for the extra " " 267 | answer_targets = p_after_tokens.input_ids.clone() 268 | # target: "###Human: ###Assistant: xxxxx. ###" 269 | system = raw_text[0].split(sep1)[0] 270 | system_len = self._get_text_len(system.rstrip()) 271 | sep_len = self._get_text_len(sep1.rstrip()) 272 | cur_len = self._get_text_len(raw_text[0].rstrip()) 273 | answer_targets[:, :system_len] = -100 274 | answer_targets[:, (system_len+sep_len):cur_len] = -100 275 | for text in raw_text[1:-1]: 276 | total_len = self._get_text_len(text.rstrip()) 277 | ans_len = self._get_text_len((text.split(sep1)[0]+sep1).rstrip()) 278 | answer_targets[:, (cur_len+ans_len):(cur_len+total_len)] = -100 279 | cur_len += total_len 280 | cur_len += self._get_text_len(raw_text[-1].rstrip()) 281 | assert cur_len == answer_targets.shape[1], f"The final length ({cur_len}) is not equal to the original prompt ({answer_targets.shape[1]}): {prompt}" 282 | 283 | max_len = max(max_len, input_embeds.shape[1]) 284 | input_embed_list.append(input_embeds) 285 | p_before_len_list.append(p_before_tokens.input_ids.shape[1]) 286 | target_list.append(answer_targets) 287 | 288 | # plus one for bos 289 | # max_txt_len plus num_query_token is the max len 290 | txt_len = min(max_len + 1, self.max_txt_len + img_len) 291 | inputs_embeds = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device) * self.llama_tokenizer.pad_token_id 292 | if self.use_lora: 293 | inputs_embeds = self.llama_model.base_model.model.model.embed_tokens(inputs_embeds) 294 | else: 295 | inputs_embeds = self.llama_model.model.embed_tokens(inputs_embeds) 296 | attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device) 297 | targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100) 298 | # set bos_token 299 | inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id 300 | for idx in range(batch_size): 301 | input_len = min(input_embed_list[idx].shape[1], txt_len - 1) 302 | # if less than txt_len, the input will be padding 303 | # if more than txt_len, the input will be truncated 304 | inputs_embeds[idx, 1:(input_len+1)] = input_embed_list[idx][:, :input_len] 305 | # the attention_mask is 0 when padding 306 | attention_mask[idx, :(input_len+1)] = 1 307 | # the target is -100 when padding 308 | p_before_len = p_before_len_list[idx] 309 | targets[idx, (p_before_len+img_len+1):(input_len+1)] = target_list[idx][0, :(input_len-p_before_len-img_len)] 310 | 311 | with self.maybe_autocast(): 312 | outputs = self.llama_model( 313 | inputs_embeds=inputs_embeds, 314 | attention_mask=attention_mask, 315 | return_dict=True, 316 | labels=targets, 317 | ) 318 | 319 | return dict( 320 | loss=outputs.loss, 321 | ) 322 | 
-------------------------------------------------------------------------------- /GUI_Vid/dataset/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class MultiGroupRandomCrop(object): 38 | def __init__(self, size, groups=1): 39 | if isinstance(size, numbers.Number): 40 | self.size = (int(size), int(size)) 41 | else: 42 | self.size = size 43 | self.groups = groups 44 | 45 | def __call__(self, img_group): 46 | 47 | w, h = img_group[0].size 48 | th, tw = self.size 49 | 50 | out_images = list() 51 | 52 | for i in range(self.groups): 53 | x1 = random.randint(0, w - tw) 54 | y1 = random.randint(0, h - th) 55 | 56 | for img in img_group: 57 | assert(img.size[0] == w and img.size[1] == h) 58 | if w == tw and h == th: 59 | out_images.append(img) 60 | else: 61 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 62 | 63 | return out_images 64 | 65 | 66 | class GroupCenterCrop(object): 67 | def __init__(self, size): 68 | self.worker = torchvision.transforms.CenterCrop(size) 69 | 70 | def __call__(self, img_group): 71 | return [self.worker(img) for img in img_group] 72 | 73 | 74 | class GroupRandomHorizontalFlip(object): 75 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 76 | """ 77 | 78 | def __init__(self, is_flow=False): 79 | self.is_flow = is_flow 80 | 81 | def __call__(self, img_group, is_flow=False): 82 | v = random.random() 83 | if v < 0.5: 84 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 85 | if self.is_flow: 86 | for i in range(0, len(ret), 2): 87 | # invert flow pixel values when flipping 88 | ret[i] = ImageOps.invert(ret[i]) 89 | return ret 90 | else: 91 | return img_group 92 | 93 | 94 | class GroupNormalize(object): 95 | def __init__(self, mean, std): 96 | self.mean = mean 97 | self.std = std 98 | 99 | def __call__(self, tensor): 100 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 101 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 102 | 103 | # TODO: make efficient 104 | for t, m, s in zip(tensor, rep_mean, rep_std): 105 | t.sub_(m).div_(s) 106 | 107 | return tensor 108 | 109 | 110 | class GroupScale(object): 111 | """ Rescales the input PIL.Image to the given 'size'. 112 | 'size' will be the size of the smaller edge. 
113 | For example, if height > width, then image will be 114 | rescaled to (size * height / width, size) 115 | size: size of the smaller edge 116 | interpolation: Default: PIL.Image.BILINEAR 117 | """ 118 | 119 | def __init__(self, size, interpolation=Image.BILINEAR): 120 | self.worker = torchvision.transforms.Resize(size, interpolation) 121 | 122 | def __call__(self, img_group): 123 | return [self.worker(img) for img in img_group] 124 | 125 | 126 | class GroupOverSample(object): 127 | def __init__(self, crop_size, scale_size=None, flip=True): 128 | self.crop_size = crop_size if not isinstance( 129 | crop_size, int) else (crop_size, crop_size) 130 | 131 | if scale_size is not None: 132 | self.scale_worker = GroupScale(scale_size) 133 | else: 134 | self.scale_worker = None 135 | self.flip = flip 136 | 137 | def __call__(self, img_group): 138 | 139 | if self.scale_worker is not None: 140 | img_group = self.scale_worker(img_group) 141 | 142 | image_w, image_h = img_group[0].size 143 | crop_w, crop_h = self.crop_size 144 | 145 | offsets = GroupMultiScaleCrop.fill_fix_offset( 146 | False, image_w, image_h, crop_w, crop_h) 147 | oversample_group = list() 148 | for o_w, o_h in offsets: 149 | normal_group = list() 150 | flip_group = list() 151 | for i, img in enumerate(img_group): 152 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 153 | normal_group.append(crop) 154 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 155 | 156 | if img.mode == 'L' and i % 2 == 0: 157 | flip_group.append(ImageOps.invert(flip_crop)) 158 | else: 159 | flip_group.append(flip_crop) 160 | 161 | oversample_group.extend(normal_group) 162 | if self.flip: 163 | oversample_group.extend(flip_group) 164 | return oversample_group 165 | 166 | 167 | class GroupFullResSample(object): 168 | def __init__(self, crop_size, scale_size=None, flip=True): 169 | self.crop_size = crop_size if not isinstance( 170 | crop_size, int) else (crop_size, crop_size) 171 | 172 | if scale_size is not None: 173 | self.scale_worker = GroupScale(scale_size) 174 | else: 175 | self.scale_worker = None 176 | self.flip = flip 177 | 178 | def __call__(self, img_group): 179 | 180 | if self.scale_worker is not None: 181 | img_group = self.scale_worker(img_group) 182 | 183 | image_w, image_h = img_group[0].size 184 | crop_w, crop_h = self.crop_size 185 | 186 | w_step = (image_w - crop_w) // 4 187 | h_step = (image_h - crop_h) // 4 188 | 189 | offsets = list() 190 | offsets.append((0 * w_step, 2 * h_step)) # left 191 | offsets.append((4 * w_step, 2 * h_step)) # right 192 | offsets.append((2 * w_step, 2 * h_step)) # center 193 | 194 | oversample_group = list() 195 | for o_w, o_h in offsets: 196 | normal_group = list() 197 | flip_group = list() 198 | for i, img in enumerate(img_group): 199 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 200 | normal_group.append(crop) 201 | if self.flip: 202 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 203 | 204 | if img.mode == 'L' and i % 2 == 0: 205 | flip_group.append(ImageOps.invert(flip_crop)) 206 | else: 207 | flip_group.append(flip_crop) 208 | 209 | oversample_group.extend(normal_group) 210 | oversample_group.extend(flip_group) 211 | return oversample_group 212 | 213 | 214 | class GroupMultiScaleCrop(object): 215 | 216 | def __init__(self, input_size, scales=None, max_distort=1, 217 | fix_crop=True, more_fix_crop=True): 218 | self.scales = scales if scales is not None else [1, .875, .75, .66] 219 | self.max_distort = max_distort 220 | self.fix_crop = fix_crop 221 | 
self.more_fix_crop = more_fix_crop 222 | self.input_size = input_size if not isinstance(input_size, int) else [ 223 | input_size, input_size] 224 | self.interpolation = Image.BILINEAR 225 | 226 | def __call__(self, img_group): 227 | 228 | im_size = img_group[0].size 229 | 230 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 231 | crop_img_group = [ 232 | img.crop( 233 | (offset_w, 234 | offset_h, 235 | offset_w + 236 | crop_w, 237 | offset_h + 238 | crop_h)) for img in img_group] 239 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 240 | for img in crop_img_group] 241 | return ret_img_group 242 | 243 | def _sample_crop_size(self, im_size): 244 | image_w, image_h = im_size[0], im_size[1] 245 | 246 | # find a crop size 247 | base_size = min(image_w, image_h) 248 | crop_sizes = [int(base_size * x) for x in self.scales] 249 | crop_h = [ 250 | self.input_size[1] if abs( 251 | x - self.input_size[1]) < 3 else x for x in crop_sizes] 252 | crop_w = [ 253 | self.input_size[0] if abs( 254 | x - self.input_size[0]) < 3 else x for x in crop_sizes] 255 | 256 | pairs = [] 257 | for i, h in enumerate(crop_h): 258 | for j, w in enumerate(crop_w): 259 | if abs(i - j) <= self.max_distort: 260 | pairs.append((w, h)) 261 | 262 | crop_pair = random.choice(pairs) 263 | if not self.fix_crop: 264 | w_offset = random.randint(0, image_w - crop_pair[0]) 265 | h_offset = random.randint(0, image_h - crop_pair[1]) 266 | else: 267 | w_offset, h_offset = self._sample_fix_offset( 268 | image_w, image_h, crop_pair[0], crop_pair[1]) 269 | 270 | return crop_pair[0], crop_pair[1], w_offset, h_offset 271 | 272 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 273 | offsets = self.fill_fix_offset( 274 | self.more_fix_crop, image_w, image_h, crop_w, crop_h) 275 | return random.choice(offsets) 276 | 277 | @staticmethod 278 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 279 | w_step = (image_w - crop_w) // 4 280 | h_step = (image_h - crop_h) // 4 281 | 282 | ret = list() 283 | ret.append((0, 0)) # upper left 284 | ret.append((4 * w_step, 0)) # upper right 285 | ret.append((0, 4 * h_step)) # lower left 286 | ret.append((4 * w_step, 4 * h_step)) # lower right 287 | ret.append((2 * w_step, 2 * h_step)) # center 288 | 289 | if more_fix_crop: 290 | ret.append((0, 2 * h_step)) # center left 291 | ret.append((4 * w_step, 2 * h_step)) # center right 292 | ret.append((2 * w_step, 4 * h_step)) # lower center 293 | ret.append((2 * w_step, 0 * h_step)) # upper center 294 | 295 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 296 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 297 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 298 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 299 | 300 | return ret 301 | 302 | 303 | class GroupRandomSizedCrop(object): 304 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 305 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 306 | This is popularly used to train the Inception networks 307 | size: size of the smaller edge 308 | interpolation: Default: PIL.Image.BILINEAR 309 | """ 310 | 311 | def __init__(self, size, interpolation=Image.BILINEAR): 312 | self.size = size 313 | self.interpolation = interpolation 314 | 315 | def __call__(self, img_group): 316 | for attempt in range(10): 317 | area = img_group[0].size[0] * img_group[0].size[1] 318 | target_area = random.uniform(0.08, 1.0) * area 
319 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 320 | 321 | w = int(round(math.sqrt(target_area * aspect_ratio))) 322 | h = int(round(math.sqrt(target_area / aspect_ratio))) 323 | 324 | if random.random() < 0.5: 325 | w, h = h, w 326 | 327 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 328 | x1 = random.randint(0, img_group[0].size[0] - w) 329 | y1 = random.randint(0, img_group[0].size[1] - h) 330 | found = True 331 | break 332 | else: 333 | found = False 334 | x1 = 0 335 | y1 = 0 336 | 337 | if found: 338 | out_group = list() 339 | for img in img_group: 340 | img = img.crop((x1, y1, x1 + w, y1 + h)) 341 | assert(img.size == (w, h)) 342 | out_group.append( 343 | img.resize( 344 | (self.size, self.size), self.interpolation)) 345 | return out_group 346 | else: 347 | # Fallback 348 | scale = GroupScale(self.size, interpolation=self.interpolation) 349 | crop = GroupRandomCrop(self.size) 350 | return crop(scale(img_group)) 351 | 352 | 353 | class ConvertDataFormat(object): 354 | def __init__(self, model_type): 355 | self.model_type = model_type 356 | 357 | def __call__(self, images): 358 | if self.model_type == '2D': 359 | return images 360 | tc, h, w = images.size() 361 | t = tc // 3 362 | images = images.view(t, 3, h, w) 363 | images = images.permute(1, 0, 2, 3) 364 | return images 365 | 366 | 367 | class Stack(object): 368 | 369 | def __init__(self, roll=False): 370 | self.roll = roll 371 | 372 | def __call__(self, img_group): 373 | if img_group[0].mode == 'L': 374 | return np.concatenate([np.expand_dims(x, 2) 375 | for x in img_group], axis=2) 376 | elif img_group[0].mode == 'RGB': 377 | if self.roll: 378 | return np.concatenate([np.array(x)[:, :, ::-1] 379 | for x in img_group], axis=2) 380 | else: 381 | #print(np.concatenate(img_group, axis=2).shape) 382 | # print(img_group[0].shape) 383 | return np.concatenate(img_group, axis=2) 384 | 385 | 386 | class ToTorchFormatTensor(object): 387 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 388 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 389 | 390 | def __init__(self, div=True): 391 | self.div = div 392 | 393 | def __call__(self, pic): 394 | if isinstance(pic, np.ndarray): 395 | # handle numpy array 396 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 397 | else: 398 | # handle PIL Image 399 | img = torch.ByteTensor( 400 | torch.ByteStorage.from_buffer( 401 | pic.tobytes())) 402 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 403 | # put it from HWC to CHW format 404 | # yikes, this transpose takes 80% of the loading time/CPU 405 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 406 | return img.float().div(255) if self.div else img.float() 407 | 408 | 409 | class IdentityTransform(object): 410 | 411 | def __call__(self, data): 412 | return data 413 | 414 | 415 | if __name__ == "__main__": 416 | trans = torchvision.transforms.Compose([ 417 | GroupScale(256), 418 | GroupRandomCrop(224), 419 | Stack(), 420 | ToTorchFormatTensor(), 421 | GroupNormalize( 422 | mean=[.485, .456, .406], 423 | std=[.229, .224, .225] 424 | )] 425 | ) 426 | 427 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 428 | 429 | color_group = [im] * 3 430 | rst = trans(color_group) 431 | 432 | gray_group = [im.convert('L')] * 9 433 | gray_rst = trans(gray_group) 434 | 435 | trans2 = torchvision.transforms.Compose([ 436 | GroupRandomSizedCrop(256), 437 | Stack(), 438 | ToTorchFormatTensor(), 439 | GroupNormalize( 440 | mean=[.485, .456, .406], 441 | 
std=[.229, .224, .225]) 442 | ]) 443 | print(trans2(color_group)) 444 | --------------------------------------------------------------------------------