├── LICENSE
├── README.md
├── inf_video_llama.png
├── infty-Video-LLaMA
│   ├── Gradio_demo
│   │   ├── GRADIO_DEMO.md
│   │   └── app_gradio.py
│   ├── InfVideoLLaMA
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── dist_utils.py
│   │   │   ├── gradcam.py
│   │   │   ├── logger.py
│   │   │   ├── optims.py
│   │   │   ├── registry.py
│   │   │   └── utils.py
│   │   ├── configs
│   │   │   ├── datasets
│   │   │   │   ├── cc_sbu
│   │   │   │   │   ├── align.yaml
│   │   │   │   │   └── defaults.yaml
│   │   │   │   ├── instruct
│   │   │   │   │   ├── llava_instruct.yaml
│   │   │   │   │   └── webvid_instruct.yaml
│   │   │   │   ├── laion
│   │   │   │   │   └── defaults.yaml
│   │   │   │   └── webvid
│   │   │   │       └── defaults.yaml
│   │   │   ├── default.yaml
│   │   │   └── models
│   │   │       ├── minigpt4.yaml
│   │   │       ├── moviechat.yaml
│   │   │       └── video_llama.yaml
│   │   ├── conversation
│   │   │   ├── __init__.py
│   │   │   └── conversation_video.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── builders
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_dataset_builder.py
│   │   │   │   ├── image_text_pair_builder.py
│   │   │   │   ├── instruct_builder.py
│   │   │   │   └── video_caption_builder.py
│   │   │   ├── data_utils.py
│   │   │   └── datasets
│   │   │       ├── __init__.py
│   │   │       ├── base_dataset.py
│   │   │       ├── caption_datasets.py
│   │   │       ├── cc_sbu_dataset.py
│   │   │       ├── dataloader_utils.py
│   │   │       ├── laion_dataset.py
│   │   │       ├── llava_instruct_dataset.py
│   │   │       ├── video_instruct_dataset.py
│   │   │       └── webvid_datasets.py
│   │   ├── models
│   │   │   ├── Qformer.py
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── basis_functions.py
│   │   │   ├── blip2.py
│   │   │   ├── blip2_outputs.py
│   │   │   ├── eva_vit.py
│   │   │   ├── eva_vit_with_tome.py
│   │   │   ├── helpers.py
│   │   │   ├── infinityqa.py
│   │   │   ├── long_term_attention.py
│   │   │   ├── long_term_attention_gibbs.py
│   │   │   ├── modeling_llama.py
│   │   │   ├── multimodal_preprocessors.py
│   │   │   └── process_video_data.py
│   │   ├── processors
│   │   │   ├── __init__.py
│   │   │   ├── base_processor.py
│   │   │   ├── blip_processors.py
│   │   │   ├── functional_video.py
│   │   │   ├── randaugment.py
│   │   │   ├── transforms_video.py
│   │   │   └── video_processor.py
│   │   ├── runners
│   │   │   ├── __init__.py
│   │   │   ├── runner_base.py
│   │   │   └── test.py
│   │   └── tasks
│   │       ├── __init__.py
│   │       ├── base_task.py
│   │       ├── image_text_pretrain.py
│   │       └── video_text_pretrain.py
│   ├── apply_delta.py
│   ├── convert_llama_to_hf.py
│   ├── eval_code
│   │   ├── eval
│   │   │   ├── extract_features.py
│   │   │   ├── run_inference_inf_video_llama_egochema.py
│   │   │   ├── run_inference_inf_video_llama_egochema_full.py
│   │   │   ├── run_inference_inf_video_llama_moviechat.py
│   │   │   ├── run_inference_inf_video_llama_nextoe.py
│   │   │   ├── run_inference_inf_video_llama_nextqa.py
│   │   │   ├── run_inference_inf_video_llama_video_mme.py
│   │   │   └── utils.py
│   │   └── validate
│   │       ├── egoschema_acc.py
│   │       ├── run_eval.py
│   │       ├── run_eval_langchain.py
│   │       ├── run_eval_qa_chatgpt.py
│   │       ├── test.py
│   │       └── utils.py
│   ├── eval_configs
│   │   └── infvideollama.yaml
│   ├── inference.py
│   └── relevant_frames.py
└── infty-VideoChat2
    ├── configs
    │   ├── config.json
    │   ├── config_bert.json
    │   ├── config_mistral.json
    │   ├── config_phi.json
    │   ├── data.py
    │   ├── instruction_data.py
    │   └── model.py
    ├── conversation.py
    ├── dataset
    │   ├── __init__.py
    │   ├── base_dataset.py
    │   ├── dataloader.py
    │   ├── hd_utils.py
    │   ├── it_dataset.py
    │   ├── it_dataset_mistral.py
    │   ├── it_dataset_phi.py
    │   ├── pt_dataset.py
    │   ├── sampler.py
    │   ├── utils.py
    │   ├── video_transforms.py
    │   └── video_utils.py
    ├── eval_code
    │   ├── run_egoschema_mistral.py
    │   ├── run_egoschema_mistral_hd.py
    │   ├── run_moviechat_mistral.py
    │   ├── run_nextqa_mistral.py
    │   └── run_videomme_mistral.py
    ├── models
    │   ├── __init__.py
    │   ├── bert
    │   │   ├── __init__.py
    │   │   ├── builder.py
    │   │   ├── tokenization_bert.py
    │   │   └── xbert.py
    │   ├── blip2
    │   │   ├── Qformer.py
    │   │   ├── Qformer_baseline.py
    │   │   ├── __init__.py
    │   │   ├── basis_functions.py
    │   │   ├── blip2.py
    │   │   ├── builder.py
    │   │   ├── long_term_attention_gibbs.py
    │   │   ├── modeling_llama.py
    │   │   ├── modeling_llama_mem.py
    │   │   ├── utils.py
    │   │   └── vit.py
    │   ├── criterions.py
    │   ├── utils.py
    │   ├── videochat2_qformer.py
    │   ├── videochat_mistra
    │   │   ├── __init__.py
    │   │   ├── videochat2_it_hd_mistral.py
    │   │   ├── videochat2_it_mistral.py
    │   │   └── videochat2_pt_mistral.py
    │   ├── videochat_phi
    │   │   ├── videochat2_it_phi.py
    │   │   └── videochat2_pt_phi.py
    │   └── videochat_vicuna
    │       ├── __init__.py
    │       ├── videochat2_it_vicuna.py
    │       └── videochat2_pt_vicuna.py
    ├── scripts
    │   ├── videochat_mistral
    │   │   ├── config_7b_hd_stage4.py
    │   │   ├── config_7b_stage2.py
    │   │   ├── config_7b_stage3.py
    │   │   ├── run_7b_hd_stage4.sh
    │   │   ├── run_7b_stage2.sh
    │   │   ├── run_7b_stage3.sh
    │   │   ├── slurm_run_7b_stage2.sh
    │   │   ├── slurm_run_7b_stage3.sh
    │   │   └── slurm_run_7b_stage4_hd.sh
    │   ├── videochat_phi
    │   │   ├── config_7b_stage2.py
    │   │   ├── config_7b_stage3.py
    │   │   ├── run_7b_stage2.sh
    │   │   └── run_7b_stage3.sh
    │   └── videochat_vicuna
    │       ├── config_7b_stage1.py
    │       ├── config_7b_stage2.py
    │       ├── config_7b_stage3.py
    │       ├── run_7b_stage1.sh
    │       ├── run_7b_stage2.sh
    │       ├── run_7b_stage3.sh
    │       ├── slurm_run_7b_stage1.sh
    │       ├── slurm_run_7b_stage2.sh
    │       └── slurm_run_7b_stage3.sh
    ├── tasks
    │   ├── retrieval_utils.py
    │   ├── shared_utils.py
    │   ├── shared_utils_ds.py
    │   ├── shared_utils_qformer.py
    │   ├── train_it.py
    │   ├── train_it_ds.py
    │   ├── train_pt.py
    │   └── train_qformer.py
    └── utils
        ├── basic_utils.py
        ├── config.py
        ├── config_utils.py
        ├── distributed.py
        ├── easydict.py
        ├── logger.py
        ├── optimizer.py
        └── scheduler.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 DeepSPIN
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Infinite-Video
2 | # $\infty$-Video: A Training-Free Approach to Long Video Understanding via Continuous-Time Memory Consolidation
3 | Official implementation of the paper **$\infty$-Video: A Training-Free Approach to Long Video Understanding via Continuous-Time Memory Consolidation**.
4 |
5 | *Saul Santos*, *António Farinhas*, *Daniel McNamee* and *André Martins*
6 |
7 |
8 |
9 |
10 | **Abstract**: *Current video-language models struggle with long-video understanding due to limited context lengths and reliance on sparse frame subsampling, often leading to information loss.
11 | This paper introduces ∞-Video, which can process arbitrarily long videos through a continuous-time long-term memory (LTM) consolidation mechanism. Our framework augments video Q-formers by allowing them to process unbounded video contexts efficiently and without requiring additional training.
12 | Through continuous attention, our approach dynamically allocates higher granularity to the most relevant video segments, forming "sticky" memories that evolve over time.
13 | Experiments with Video-LLaMA and VideoChat2 demonstrate improved performance in video question-answering tasks, showcasing the potential of continuous-time LTM mechanisms to enable scalable and training-free comprehension of long videos.*
14 |
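
For intuition, here is a minimal, illustrative sketch of the idea: a long sequence of frame features is compressed into a fixed-size continuous memory parameterized by Gaussian basis functions, and the memory is then read at arbitrary time points. This is *not* the repository implementation (see `InfVideoLLaMA/models/basis_functions.py` and `long_term_attention.py` for that); the names, shapes, and the ridge-regression fit below are assumptions made for illustration only.

```python
# Toy sketch of a continuous-time memory: fit basis-function coefficients to a
# frame sequence, then read the memory at continuous positions in [0, 1].
import torch

def gaussian_basis(t, num_basis, sigma=0.05):
    # t: (L,) positions in [0, 1]; returns an (L, num_basis) matrix of RBF values
    mu = torch.linspace(0, 1, num_basis)
    return torch.exp(-((t[:, None] - mu[None, :]) ** 2) / (2 * sigma ** 2))

def consolidate(frame_feats, num_basis=64):
    # Fit coefficients B (num_basis, D) so that G @ B approximates the frames
    # (simple ridge regression; the memory size is fixed regardless of L).
    L, _ = frame_feats.shape
    G = gaussian_basis(torch.linspace(0, 1, L), num_basis)         # (L, num_basis)
    ridge = 1e-4 * torch.eye(num_basis)
    return torch.linalg.solve(G.T @ G + ridge, G.T @ frame_feats)  # (num_basis, D)

def continuous_read(B, center, num_points=32, sigma=0.05):
    # Evaluate the continuous memory at positions sampled around `center`.
    t = torch.clamp(center + sigma * torch.randn(num_points), 0.0, 1.0)
    return gaussian_basis(t, B.shape[0], sigma) @ B                # (num_points, D)

frames = torch.randn(1000, 768)                # e.g. 1000 frames of visual features
memory = consolidate(frames, num_basis=64)     # fixed-size memory: (64, 768)
segment = continuous_read(memory, center=0.5)  # read around the middle of the video
print(memory.shape, segment.shape)
```

The key property is that the memory size (`num_basis`) does not grow with the number of frames, which is what makes unbounded video contexts tractable; the "sticky" mechanism described in the paper additionally biases where the basis functions concentrate.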
15 | ----------
16 |
17 | **If you use this code in your work, please cite our paper.**
18 |
19 | ----------
20 |
21 | ## Resources
22 |
23 | - [Paper](https://arxiv.org/abs/2501.19098) (arXiv)
24 |
25 | All material is made available under the MIT license. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made.
26 |
27 |
28 | ## Video LLaMA
29 | ### Python requirements and installation
30 |
31 | This code was tested on `Python 3.10.10`. To install, follow the installation steps in [MovieChat](https://github.com/rese1f/MovieChat).
32 |
33 | ### Reproducibility
34 | 1 - Run ```eval_code/eval/extract_features.py``` on the intended dataset with the desired number of frames.
35 |
36 | 2 - Run each script in ```eval_code/eval``` with the hyperparameters mentioned in the paper:
37 | Example:
38 | ```
39 | python3 eval_code/eval/run_inference_inf_video_llama_nextqa.py --cfg-path eval_configs/infvideollama.yaml --num-beams 1 --temperature 1 --video-folder next_qa/features --q-folder /mnt/scratch-artemis/saul/next_qa/val.csv --output-dir /MovieChat/nextqa_val --max_int 256 --num_basis 256 --tau 0.75 --alpha 1.0 --task inf_video_llama --sticky
40 | ```
41 |
42 | 3 - For open-ended questions, run ```eval_code/validate/run_eval_qa_chatgpt.py``` with the output of the MovieChat inference script.
43 |
44 | 4 - For multiple-choice questions, we predict the answers as open-ended and use LangChain to select the most similar option. Run ```eval_code/validate/run_eval_langchain.py``` with the output from the dataset inference script, e.g.:
45 |
46 | ```
47 | python eval_code/validate/run_eval_langchain.py --pred_path egoschema/nframes_8_nchunks_256_moviechatplus/preds.json --num_tasks 100
48 | ```
49 |
50 | Then compute the accuracy with ```eval_code/validate/run_eval.py```.
51 |
52 | ## VideoChat2
53 | ### Python requirements and installation
54 | Follow the instructions of [videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
55 |
56 | ### Reproducibility
57 | 1 - Run each script in ```eval_code``` with the hyperparameters mentioned in the paper:
58 | Example:
59 | ```
60 | python3 eval_code/run_nextqa_mistral.py --video-folder /NExTQA/videos --data_path /next_qa/val.csv --output-dir nextqa_val --max_int 16 --num_samples 8 --num_basis 64 --tau 0.75 --alpha 1.0
61 |
62 | ```
63 |
64 | ## Acknowledgment
65 |
66 | The experiments in this work build on the following open-source code bases:
67 | * Enxin Song, Wenhao Chai, Guanhong Wang, Yucheng Zhang, Haoyang Zhou, Feiyang Wu, Haozhe Chi, Xun Guo, Tian Ye, Yanting Zhang, Yan Lu, Jenq-Neng Hwang and Gaoang Wang. MovieChat: From Dense Token to Sparse Memory for Long Video Understanding, CVPR 2024. https://github.com/rese1f/MovieChat
68 | * Kunchang Li, Yali Wang, Yinan He, Yizhuo Li, Yi Wang, Yi Liu, Zun Wang, Jilan Xu, Guo Chen, Limin Wang and Yu Qiao. MVBench: A Comprehensive Multi-modal Video Understanding Benchmark, CVPR 2024. https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2
69 | * Pedro Henrique Martins, Zita Marinho, André F. T. Martins, ∞-former: Infinite Memory Transformer, Proc. ACL 2022. https://github.com/deep-spin/infinite-former
70 |
--------------------------------------------------------------------------------
/inf_video_llama.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/inf_video_llama.png
--------------------------------------------------------------------------------
/infty-Video-LLaMA/Gradio_demo/GRADIO_DEMO.md:
--------------------------------------------------------------------------------
1 | Run the inference demo with Gradio.
2 | Run the following command:
3 |
4 | ```shell
5 | python app_gradio.py \
6 | --cfg-path eval_configs/infvideollama.yaml \
7 | --gpu-id 0
8 | ```
9 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | import sys
10 |
11 | from omegaconf import OmegaConf
12 |
13 | from InfVideoLLaMA.common.registry import registry
14 |
15 | from InfVideoLLaMA.datasets.builders import *
16 | from InfVideoLLaMA.models import *
17 | from InfVideoLLaMA.processors import *
18 | from InfVideoLLaMA.tasks import *
19 |
20 |
21 | root_dir = os.path.dirname(os.path.abspath(__file__))
22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23 |
24 | registry.register_path("library_root", root_dir)
25 | repo_root = os.path.join(root_dir, "..")
26 | registry.register_path("repo_root", repo_root)
27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28 | registry.register_path("cache_root", cache_root)
29 |
30 | registry.register("MAX_INT", sys.maxsize)
31 | registry.register("SPLIT_NAMES", ["train", "val", "test"])
32 |
--------------------------------------------------------------------------------
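
For orientation, importing the package executes the registration code above. A minimal sketch of reading those entries back, assuming the LAVIS-style registry accessors (`get_path` / `get`) that the `register_path` / `register` calls imply:

```python
# Sketch only: assumes the registry exposes LAVIS-style get()/get_path() accessors.
import InfVideoLLaMA  # running the package __init__ registers paths and constants
from InfVideoLLaMA.common.registry import registry

print(registry.get_path("library_root"))  # .../infty-Video-LLaMA/InfVideoLLaMA
print(registry.get_path("cache_root"))    # resolved from configs/default.yaml
print(registry.get("MAX_INT"))            # sys.maxsize
```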
/infty-Video-LLaMA/InfVideoLLaMA/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-Video-LLaMA/InfVideoLLaMA/common/__init__.py
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/common/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import functools
10 | import os
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import timm.models.hub as timm_hub
15 |
16 |
17 | def setup_for_distributed(is_master):
18 | """
19 | This function disables printing when not in master process
20 | """
21 | import builtins as __builtin__
22 |
23 | builtin_print = __builtin__.print
24 |
25 | def print(*args, **kwargs):
26 | force = kwargs.pop("force", False)
27 | if is_master or force:
28 | builtin_print(*args, **kwargs)
29 |
30 | __builtin__.print = print
31 |
32 |
33 | def is_dist_avail_and_initialized():
34 | if not dist.is_available():
35 | return False
36 | if not dist.is_initialized():
37 | return False
38 | return True
39 |
40 |
41 | def get_world_size():
42 | if not is_dist_avail_and_initialized():
43 | return 1
44 | return dist.get_world_size()
45 |
46 |
47 | def get_rank():
48 | if not is_dist_avail_and_initialized():
49 | return 0
50 | return dist.get_rank()
51 |
52 |
53 | def is_main_process():
54 | return get_rank() == 0
55 |
56 |
57 | def init_distributed_mode(args):
58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59 | args.rank = int(os.environ["RANK"])
60 | args.world_size = int(os.environ["WORLD_SIZE"])
61 | args.gpu = int(os.environ["LOCAL_RANK"])
62 | elif "SLURM_PROCID" in os.environ:
63 | args.rank = int(os.environ["SLURM_PROCID"])
64 | args.gpu = args.rank % torch.cuda.device_count()
65 | else:
66 | print("Not using distributed mode")
67 | args.distributed = False
68 | return
69 |
70 | args.distributed = True
71 |
72 | torch.cuda.set_device(args.gpu)
73 | args.dist_backend = "nccl"
74 | print(
75 | "| distributed init (rank {}, world {}): {}".format(
76 | args.rank, args.world_size, args.dist_url
77 | ),
78 | flush=True,
79 | )
80 | torch.distributed.init_process_group(
81 | backend=args.dist_backend,
82 | init_method=args.dist_url,
83 | world_size=args.world_size,
84 | rank=args.rank,
85 | timeout=datetime.timedelta(
86 | days=365
87 | ), # allow auto-downloading and de-compressing
88 | )
89 | torch.distributed.barrier()
90 | setup_for_distributed(args.rank == 0)
91 |
92 |
93 | def get_dist_info():
94 | if torch.__version__ < "1.0":
95 | initialized = dist._initialized
96 | else:
97 | initialized = dist.is_initialized()
98 | if initialized:
99 | rank = dist.get_rank()
100 | world_size = dist.get_world_size()
101 | else: # non-distributed training
102 | rank = 0
103 | world_size = 1
104 | return rank, world_size
105 |
106 |
107 | def main_process(func):
108 | @functools.wraps(func)
109 | def wrapper(*args, **kwargs):
110 | rank, _ = get_dist_info()
111 | if rank == 0:
112 | return func(*args, **kwargs)
113 |
114 | return wrapper
115 |
116 |
117 | def download_cached_file(url, check_hash=True, progress=False):
118 | """
119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121 | """
122 |
123 | def get_cached_file_path():
124 | # a hack to sync the file path across processes
125 | parts = torch.hub.urlparse(url)
126 | filename = os.path.basename(parts.path)
127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128 |
129 | return cached_file
130 |
131 | if is_main_process():
132 | timm_hub.download_cached_file(url, check_hash, progress)
133 |
134 | if is_dist_avail_and_initialized():
135 | dist.barrier()
136 |
137 | return get_cached_file_path()
138 |
--------------------------------------------------------------------------------
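
A minimal sketch of how `init_distributed_mode` above is typically driven, e.g. under `torchrun`, which exports the `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` variables it reads; the argument parser here is an assumption for illustration, not the repository's CLI:

```python
# Assumed launch pattern: torchrun --nproc_per_node=N this_script.py
import argparse
from InfVideoLLaMA.common.dist_utils import init_distributed_mode, get_rank, get_world_size

parser = argparse.ArgumentParser()
parser.add_argument("--dist_url", default="env://")  # passed to init_process_group
args = parser.parse_args()

init_distributed_mode(args)  # sets args.rank, args.world_size, args.gpu, args.distributed
print(f"rank {get_rank()} of {get_world_size()} (distributed={args.distributed})")
```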
/infty-Video-LLaMA/InfVideoLLaMA/common/gradcam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from scipy.ndimage import filters
4 | from skimage import transform as skimage_transform
5 |
6 |
7 | def getAttMap(img, attMap, blur=True, overlap=True):
8 | attMap -= attMap.min()
9 | if attMap.max() > 0:
10 | attMap /= attMap.max()
11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12 | if blur:
13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14 | attMap -= attMap.min()
15 | attMap /= attMap.max()
16 | cmap = plt.get_cmap("jet")
17 | attMapV = cmap(attMap)
18 | attMapV = np.delete(attMapV, 3, 2)
19 | if overlap:
20 | attMap = (
21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23 | )
24 | return attMap
25 |
--------------------------------------------------------------------------------
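
A hypothetical usage of `getAttMap` above, overlaying a coarse attention map on an RGB image with values in `[0, 1]`; the shapes are placeholders:

```python
import numpy as np
from InfVideoLLaMA.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)   # H x W x 3 image, values in [0, 1]
att = np.random.rand(16, 16)        # coarse attention map of any 2-D size
overlay = getAttMap(img, att, blur=True, overlap=True)
print(overlay.shape)                # (224, 224, 3)
```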
/infty-Video-LLaMA/InfVideoLLaMA/common/optims.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import math
9 |
10 | from InfVideoLLaMA.common.registry import registry
11 |
12 |
13 | @registry.register_lr_scheduler("linear_warmup_step_lr")
14 | class LinearWarmupStepLRScheduler:
15 | def __init__(
16 | self,
17 | optimizer,
18 | max_epoch,
19 | min_lr,
20 | init_lr,
21 | decay_rate=1,
22 | warmup_start_lr=-1,
23 | warmup_steps=0,
24 | **kwargs
25 | ):
26 | self.optimizer = optimizer
27 |
28 | self.max_epoch = max_epoch
29 | self.min_lr = min_lr
30 |
31 | self.decay_rate = decay_rate
32 |
33 | self.init_lr = init_lr
34 | self.warmup_steps = warmup_steps
35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36 |
37 | def step(self, cur_epoch, cur_step):
38 | if cur_epoch == 0:
39 | warmup_lr_schedule(
40 | step=cur_step,
41 | optimizer=self.optimizer,
42 | max_step=self.warmup_steps,
43 | init_lr=self.warmup_start_lr,
44 | max_lr=self.init_lr,
45 | )
46 | else:
47 | step_lr_schedule(
48 | epoch=cur_epoch,
49 | optimizer=self.optimizer,
50 | init_lr=self.init_lr,
51 | min_lr=self.min_lr,
52 | decay_rate=self.decay_rate,
53 | )
54 |
55 |
56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57 | class LinearWarmupCosineLRScheduler:
58 | def __init__(
59 | self,
60 | optimizer,
61 | max_epoch,
62 | iters_per_epoch,
63 | min_lr,
64 | init_lr,
65 | warmup_steps=0,
66 | warmup_start_lr=-1,
67 | **kwargs
68 | ):
69 | self.optimizer = optimizer
70 |
71 | self.max_epoch = max_epoch
72 | self.iters_per_epoch = iters_per_epoch
73 | self.min_lr = min_lr
74 |
75 | self.init_lr = init_lr
76 | self.warmup_steps = warmup_steps
77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78 |
79 | def step(self, cur_epoch, cur_step):
80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81 | if total_cur_step < self.warmup_steps:
82 | warmup_lr_schedule(
83 | step=cur_step,
84 | optimizer=self.optimizer,
85 | max_step=self.warmup_steps,
86 | init_lr=self.warmup_start_lr,
87 | max_lr=self.init_lr,
88 | )
89 | else:
90 | cosine_lr_schedule(
91 | epoch=total_cur_step,
92 | optimizer=self.optimizer,
93 | max_epoch=self.max_epoch * self.iters_per_epoch,
94 | init_lr=self.init_lr,
95 | min_lr=self.min_lr,
96 | )
97 |
98 |
99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100 | """Decay the learning rate"""
101 | lr = (init_lr - min_lr) * 0.5 * (
102 | 1.0 + math.cos(math.pi * epoch / max_epoch)
103 | ) + min_lr
104 | for param_group in optimizer.param_groups:
105 | param_group["lr"] = lr
106 |
107 |
108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109 | """Warmup the learning rate"""
110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111 | for param_group in optimizer.param_groups:
112 | param_group["lr"] = lr
113 |
114 |
115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116 | """Decay the learning rate"""
117 | lr = max(min_lr, init_lr * (decay_rate**epoch))
118 | for param_group in optimizer.param_groups:
119 | param_group["lr"] = lr
120 |
--------------------------------------------------------------------------------
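
A minimal sketch of driving the `linear_warmup_cosine_lr` scheduler defined above by hand; the model, optimizer, and hyperparameter values are placeholders, not the paper's settings:

```python
import torch
from InfVideoLLaMA.common.optims import LinearWarmupCosineLRScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

sched = LinearWarmupCosineLRScheduler(
    optimizer,
    max_epoch=10,
    iters_per_epoch=100,
    min_lr=1e-6,
    init_lr=1e-4,
    warmup_steps=200,      # linear warmup over the first 200 iterations
    warmup_start_lr=1e-7,
)

for epoch in range(10):
    for step in range(100):
        sched.step(cur_epoch=epoch, cur_step=step)  # set the LR for this iteration
        optimizer.step()
```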
/infty-Video-LLaMA/InfVideoLLaMA/configs/datasets/cc_sbu/align.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu_align:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/cc_sbu_align_dataset
6 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/datasets/cc_sbu/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/cc_sbu_dataset/{00000..00001}.tar
6 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/datasets/instruct/llava_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | llava_instruct:
3 | data_type: image
4 | build_info:
5 | anno_dir: /path/llava_instruct_150k.json
6 | videos_dir: /path/train2014/train2014/
7 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/datasets/instruct/webvid_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | webvid_instruct:
3 | data_type: image
4 | build_info:
5 | anno_dir: /path/webvid_align/videochat_instruct_11k.json
6 | videos_dir: /path/webvid_align/videos/
7 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/datasets/laion/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | laion:
3 | data_type: images
4 | build_info:
5 | storage: path/laion/laion_dataset/{00000..00001}.tar
6 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/datasets/webvid/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | webvid:
3 | data_type: video
4 | build_info:
5 | anno_dir: path/webvid/webvid_tain_data/annotations/
6 | videos_dir: path//webvid/webvid_tain_data/videos/
7 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/default.yaml:
--------------------------------------------------------------------------------
1 | env:
2 | # For default users
3 | # cache_root: "cache"
4 | # For internal use with persistent storage
5 | cache_root: "/export/home/.cache/minigpt4"
6 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/models/minigpt4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # Vicuna
16 | llama_model: "ckpt/vicuna-13b/"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "blip2_image_train"
25 | image_size: 224
26 | eval:
27 | name: "blip2_image_eval"
28 | image_size: 224
29 | text_processor:
30 | train:
31 | name: "blip_caption"
32 | eval:
33 | name: "blip_caption"
34 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/models/moviechat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: moviechat
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # Vicuna
16 | llama_model: "/mnt/data-poseidon/saul/MovieChat/ckpt/MovieChat-vicuna"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "alpro_video_train"
25 | image_size: 224
26 | n_frms: 8
27 | eval:
28 | name: "alpro_video_eval"
29 | image_size: 224
30 | n_frms: 8
31 | text_processor:
32 | train:
33 | name: "blip_caption"
34 | eval:
35 | name: "blip_caption"
36 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/configs/models/video_llama.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: video_llama
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # Vicuna
16 | llama_model: "/mnt/data-poseidon/saul/MovieChat/ckpt/MovieChat-vicuna"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "alpro_video_train"
25 | image_size: 224
26 | n_frms: 8
27 | eval:
28 | name: "alpro_video_eval"
29 | image_size: 224
30 | n_frms: 8
31 | text_processor:
32 | train:
33 | name: "blip_caption"
34 | eval:
35 | name: "blip_caption"
36 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/conversation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-Video-LLaMA/InfVideoLLaMA/conversation/__init__.py
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-Video-LLaMA/InfVideoLLaMA/datasets/__init__.py
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/builders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from InfVideoLLaMA.datasets.builders.base_dataset_builder import load_dataset_config
9 | from InfVideoLLaMA.datasets.builders.image_text_pair_builder import (
10 | CCSBUBuilder,
11 | LaionBuilder,
12 | CCSBUAlignBuilder
13 | )
14 | from InfVideoLLaMA.datasets.builders.video_caption_builder import WebvidBuilder
15 | from InfVideoLLaMA.common.registry import registry
16 | from InfVideoLLaMA.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder
17 | __all__ = [
18 | "CCSBUBuilder",
19 | "LaionBuilder",
20 | "CCSBUAlignBuilder",
21 | "WebvidBuilder",
22 | "LlavaInstruct_Builder",
23 | "WebvidInstruct_Builder"
24 |
25 | ]
26 |
27 |
28 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
29 | """
30 | Example
31 |
32 | >>> dataset = load_dataset("coco_caption", cfg=None)
33 | >>> splits = dataset.keys()
34 | >>> print([len(dataset[split]) for split in splits])
35 |
36 | """
37 | if cfg_path is None:
38 | cfg = None
39 | else:
40 | cfg = load_dataset_config(cfg_path)
41 |
42 | try:
43 | builder = registry.get_builder_class(name)(cfg)
44 | except TypeError:
45 | print(
46 | f"Dataset {name} not found. Available datasets:\n"
47 | + ", ".join([str(k) for k in dataset_zoo.get_names()])
48 | )
49 | exit(1)
50 |
51 | if vis_path is not None:
52 | if data_type is None:
53 | # use default data type in the config
54 | data_type = builder.config.data_type
55 |
56 | assert (
57 | data_type in builder.config.build_info
58 | ), f"Invalid data_type {data_type} for {name}."
59 |
60 | builder.config.build_info.get(data_type).storage = vis_path
61 |
62 | dataset = builder.build_datasets()
63 | return dataset
64 |
65 |
66 | class DatasetZoo:
67 | def __init__(self) -> None:
68 | self.dataset_zoo = {
69 | k: list(v.DATASET_CONFIG_DICT.keys())
70 | for k, v in sorted(registry.mapping["builder_name_mapping"].items())
71 | }
72 |
73 | def get_names(self):
74 | return list(self.dataset_zoo.keys())
75 |
76 |
77 | dataset_zoo = DatasetZoo()
78 |
--------------------------------------------------------------------------------
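
A short sketch of the helpers defined above: `dataset_zoo` lists the builders registered by this package, and `load_dataset` instantiates one by name (building a dataset requires the storage paths in the corresponding YAML config to exist):

```python
from InfVideoLLaMA.datasets.builders import dataset_zoo, load_dataset

print(dataset_zoo.get_names())  # e.g. ['cc_sbu', 'cc_sbu_align', 'laion', 'webvid', ...]
# dataset = load_dataset("webvid")  # needs the paths in configs/datasets/webvid/defaults.yaml
```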
/infty-Video-LLaMA/InfVideoLLaMA/datasets/builders/image_text_pair_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from InfVideoLLaMA.common.registry import registry
6 | from InfVideoLLaMA.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from InfVideoLLaMA.datasets.datasets.laion_dataset import LaionDataset
8 | from InfVideoLLaMA.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9 |
10 |
11 | @registry.register_builder("cc_sbu")
12 | class CCSBUBuilder(BaseDatasetBuilder):
13 | train_dataset_cls = CCSBUDataset
14 |
15 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
16 |
17 | def _download_ann(self):
18 | pass
19 |
20 | def _download_vis(self):
21 | pass
22 |
23 | def build(self):
24 | self.build_processors()
25 |
26 | build_info = self.config.build_info
27 |
28 | datasets = dict()
29 | split = "train"
30 |
31 | # create datasets
32 | # [NOTE] return inner_datasets (wds.DataPipeline)
33 | dataset_cls = self.train_dataset_cls
34 | datasets[split] = dataset_cls(
35 | vis_processor=self.vis_processors[split],
36 | text_processor=self.text_processors[split],
37 | location=build_info.storage,
38 | ).inner_dataset
39 |
40 | return datasets
41 |
42 |
43 | @registry.register_builder("laion")
44 | class LaionBuilder(BaseDatasetBuilder):
45 | train_dataset_cls = LaionDataset
46 |
47 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
48 |
49 | def _download_ann(self):
50 | pass
51 |
52 | def _download_vis(self):
53 | pass
54 |
55 | def build(self):
56 | self.build_processors()
57 |
58 | build_info = self.config.build_info
59 |
60 | datasets = dict()
61 | split = "train"
62 |
63 | # create datasets
64 | # [NOTE] return inner_datasets (wds.DataPipeline)
65 | dataset_cls = self.train_dataset_cls
66 | datasets[split] = dataset_cls(
67 | vis_processor=self.vis_processors[split],
68 | text_processor=self.text_processors[split],
69 | location=build_info.storage,
70 | ).inner_dataset
71 |
72 | return datasets
73 |
74 |
75 | @registry.register_builder("cc_sbu_align")
76 | class CCSBUAlignBuilder(BaseDatasetBuilder):
77 | train_dataset_cls = CCSBUAlignDataset
78 |
79 | DATASET_CONFIG_DICT = {
80 | "default": "configs/datasets/cc_sbu/align.yaml",
81 | }
82 |
83 | def build_datasets(self):
84 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
85 | logging.info("Building datasets...")
86 | self.build_processors()
87 |
88 | build_info = self.config.build_info
89 | storage_path = build_info.storage
90 |
91 | datasets = dict()
92 |
93 | if not os.path.exists(storage_path):
94 | warnings.warn("storage path {} does not exist.".format(storage_path))
95 |
96 | # create datasets
97 | dataset_cls = self.train_dataset_cls
98 | datasets['train'] = dataset_cls(
99 | vis_processor=self.vis_processors["train"],
100 | text_processor=self.text_processors["train"],
101 | ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
102 | vis_root=os.path.join(storage_path, 'image'),
103 | )
104 |
105 | return datasets
106 |
107 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/builders/instruct_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from InfVideoLLaMA.common.registry import registry
6 | from InfVideoLLaMA.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from InfVideoLLaMA.datasets.datasets.laion_dataset import LaionDataset
8 | from InfVideoLLaMA.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
9 | from InfVideoLLaMA.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
10 |
11 | @registry.register_builder("instruct")
12 | class Instruct_Builder(BaseDatasetBuilder):
13 | train_dataset_cls = Instruct_Dataset
14 |
15 | DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
16 |
17 | def _download_ann(self):
18 | pass
19 |
20 | def _download_vis(self):
21 | pass
22 |
23 | def build(self):
24 | self.build_processors()
25 | datasets = dict()
26 | split = "train"
27 |
28 | build_info = self.config.build_info
29 | dataset_cls = self.train_dataset_cls
30 | if self.config.num_video_query_token:
31 | num_video_query_token = self.config.num_video_query_token
32 | else:
33 | num_video_query_token = 32
34 |
35 | if self.config.tokenizer_name:
36 | tokenizer_name = self.config.tokenizer_name
37 | else:
38 | tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
39 |
40 |
41 | datasets[split] = dataset_cls(
42 | vis_processor=self.vis_processors[split],
43 | text_processor=self.text_processors[split],
44 | vis_root=build_info.videos_dir,
45 | ann_root=build_info.anno_dir,
46 | num_video_query_token = num_video_query_token,
47 | tokenizer_name = tokenizer_name,
48 | data_type = self.config.data_type
49 | )
50 |
51 | return datasets
52 |
53 | @registry.register_builder("webvid_instruct")
54 | class WebvidInstruct_Builder(Instruct_Builder):
55 | train_dataset_cls = Video_Instruct_Dataset
56 |
57 | DATASET_CONFIG_DICT = {
58 | "default": "configs/datasets/instruct/webvid_instruct.yaml",
59 | }
60 |
61 | @registry.register_builder("webvid_instruct_zh")
62 | class WebvidInstruct_zh_Builder(Instruct_Builder):
63 | train_dataset_cls = Video_Instruct_Dataset
64 |
65 | DATASET_CONFIG_DICT = {
66 | "default": "configs/datasets/instruct/webvid_instruct.yaml",
67 | }
68 |
69 |
70 |
71 | @registry.register_builder("llava_instruct")
72 | class LlavaInstruct_Builder(Instruct_Builder):
73 | train_dataset_cls = Instruct_Dataset
74 |
75 | DATASET_CONFIG_DICT = {
76 | "default": "configs/datasets/instruct/llava_instruct.yaml",
77 | }
78 |
79 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/builders/video_caption_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from InfVideoLLaMA.common.registry import registry
6 | from InfVideoLLaMA.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from InfVideoLLaMA.datasets.datasets.webvid_datasets import WebvidDataset
8 |
9 | @registry.register_builder("webvid")
10 | class WebvidBuilder(BaseDatasetBuilder):
11 | train_dataset_cls = WebvidDataset
12 | DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
13 |
14 | def _download_ann(self):
15 | pass
16 |
17 | def _download_vis(self):
18 | pass
19 |
20 | def build(self):
21 | self.build_processors()
22 | datasets = dict()
23 | split = "train"
24 |
25 | build_info = self.config.build_info
26 | dataset_cls = self.train_dataset_cls
27 | datasets[split] = dataset_cls(
28 | vis_processor=self.vis_processors[split],
29 | text_processor=self.text_processors[split],
30 | vis_root=build_info.videos_dir,
31 | ann_root=build_info.anno_dir
32 | )
33 |
34 | return datasets
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/__init__.py
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import json
9 | from typing import Iterable
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from torch.utils.data.dataloader import default_collate
13 |
14 |
15 | class BaseDataset(Dataset):
16 | def __init__(
17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18 | ):
19 | """
20 | vis_root (string): Root directory of images (e.g. coco/images/)
21 | ann_root (string): directory to store the annotation file
22 | """
23 | self.vis_root = vis_root
24 |
25 | self.annotation = []
26 | for ann_path in ann_paths:
27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28 |
29 | self.vis_processor = vis_processor
30 | self.text_processor = text_processor
31 |
32 | self._add_instance_ids()
33 |
34 | def __len__(self):
35 | return len(self.annotation)
36 |
37 | def collater(self, samples):
38 | return default_collate(samples)
39 |
40 | def set_processors(self, vis_processor, text_processor):
41 | self.vis_processor = vis_processor
42 | self.text_processor = text_processor
43 |
44 | def _add_instance_ids(self, key="instance_id"):
45 | for idx, ann in enumerate(self.annotation):
46 | ann[key] = str(idx)
47 |
48 |
49 | class ConcatDataset(ConcatDataset):
50 | def __init__(self, datasets: Iterable[Dataset]) -> None:
51 | super().__init__(datasets)
52 |
53 | def collater(self, samples):
54 | # TODO For now only supports datasets with same underlying collater implementations
55 |
56 | all_keys = set()
57 | for s in samples:
58 | all_keys.update(s)
59 |
60 | shared_keys = all_keys
61 | for s in samples:
62 | shared_keys = shared_keys & set(s.keys())
63 |
64 | samples_shared_keys = []
65 | for s in samples:
66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67 |
68 | return self.datasets[0].collater(samples_shared_keys)
69 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/caption_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | from collections import OrderedDict
10 |
11 | from InfVideoLLaMA.datasets.datasets.base_dataset import BaseDataset
12 | from PIL import Image
13 |
14 |
15 | class __DisplMixin:
16 | def displ_item(self, index):
17 | sample, ann = self.__getitem__(index), self.annotation[index]
18 |
19 | return OrderedDict(
20 | {
21 | "file": ann["image"],
22 | "caption": ann["caption"],
23 | "image": sample["image"],
24 | }
25 | )
26 |
27 |
28 | class CaptionDataset(BaseDataset, __DisplMixin):
29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30 | """
31 | vis_root (string): Root directory of images (e.g. coco/images/)
32 | ann_root (string): directory to store the annotation file
33 | """
34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35 |
36 | self.img_ids = {}
37 | n = 0
38 | for ann in self.annotation:
39 | img_id = ann["image_id"]
40 | if img_id not in self.img_ids.keys():
41 | self.img_ids[img_id] = n
42 | n += 1
43 |
44 | def __getitem__(self, index):
45 |
46 | # TODO this assumes image input, not general enough
47 | ann = self.annotation[index]
48 |
49 | img_file = '{:0>12}.jpg'.format(ann["image_id"])
50 | image_path = os.path.join(self.vis_root, img_file)
51 | image = Image.open(image_path).convert("RGB")
52 |
53 | image = self.vis_processor(image)
54 | caption = self.text_processor(ann["caption"])
55 |
56 | return {
57 | "image": image,
58 | "text_input": caption,
59 | "image_id": self.img_ids[ann["image_id"]],
60 | }
61 |
62 |
63 | class CaptionEvalDataset(BaseDataset, __DisplMixin):
64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65 | """
66 | vis_root (string): Root directory of images (e.g. coco/images/)
67 | ann_root (string): directory to store the annotation file
68 | split (string): val or test
69 | """
70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71 |
72 | def __getitem__(self, index):
73 |
74 | ann = self.annotation[index]
75 |
76 | image_path = os.path.join(self.vis_root, ann["image"])
77 | image = Image.open(image_path).convert("RGB")
78 |
79 | image = self.vis_processor(image)
80 |
81 | return {
82 | "image": image,
83 | "image_id": ann["image_id"],
84 | "instance_id": ann["instance_id"],
85 | }
86 |
--------------------------------------------------------------------------------
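
One detail worth noting in `CaptionDataset.__getitem__` above: COCO-style image files are addressed by zero-padding the numeric image id to twelve digits.

```python
# The same format string used in CaptionDataset.__getitem__ (illustrative id):
print('{:0>12}.jpg'.format(391895))  # 000000391895.jpg
```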
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/cc_sbu_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import webdataset as wds
4 | from InfVideoLLaMA.datasets.datasets.base_dataset import BaseDataset
5 | from InfVideoLLaMA.datasets.datasets.caption_datasets import CaptionDataset
6 |
7 |
8 | class CCSBUDataset(BaseDataset):
9 | def __init__(self, vis_processor, text_processor, location):
10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
11 |
12 | self.inner_dataset = wds.DataPipeline(
13 | wds.ResampledShards(location),
14 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
15 | wds.shuffle(1000, handler=wds.warn_and_continue),
16 | wds.decode("pilrgb", handler=wds.warn_and_continue),
17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
19 | wds.map(self.to_dict, handler=wds.warn_and_continue),
20 | )
21 |
22 | def to_dict(self, sample):
23 | return {
24 | "image": sample[0],
25 | "text_input": self.text_processor(sample[1]["caption"]),
26 | "type":'image',
27 | }
28 |
29 |
30 | class CCSBUAlignDataset(CaptionDataset):
31 |
32 | def __getitem__(self, index):
33 |
34 | # TODO this assumes image input, not general enough
35 | ann = self.annotation[index]
36 |
37 | img_file = '{}.jpg'.format(ann["image_id"])
38 | image_path = os.path.join(self.vis_root, img_file)
39 | image = Image.open(image_path).convert("RGB")
40 |
41 | image = self.vis_processor(image)
42 | caption = ann["caption"]
43 |
44 | return {
45 | "image": image,
46 | "text_input": caption,
47 | "image_id": self.img_ids[ann["image_id"]],
48 | "type":'image',
49 | }
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/dataloader_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import time
9 | import random
10 | import torch
11 | from InfVideoLLaMA.datasets.data_utils import move_to_cuda
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | class MultiIterLoader:
16 | """
17 | A simple wrapper for iterating over multiple iterators.
18 |
19 | Args:
20 | loaders (List[Loader]): List of Iterator loaders.
21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22 | """
23 |
24 | def __init__(self, loaders, ratios=None):
25 | # assert all loaders has __next__ method
26 | for loader in loaders:
27 | assert hasattr(
28 | loader, "__next__"
29 | ), "Loader {} has no __next__ method.".format(loader)
30 |
31 | if ratios is None:
32 | ratios = [1.0] * len(loaders)
33 | else:
34 | assert len(ratios) == len(loaders)
35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36 |
37 | self.loaders = loaders
38 | self.ratios = ratios
39 |
40 | def __next__(self):
41 | # random sample from each loader by ratio
42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43 | return next(self.loaders[loader_idx])
44 |
45 |
46 | class PrefetchLoader(object):
47 | """
48 | Modified from https://github.com/ChenRocks/UNITER.
49 |
50 | overlap compute and cuda data transfer
51 | (copied and then modified from nvidia apex)
52 | """
53 |
54 | def __init__(self, loader):
55 | self.loader = loader
56 | self.stream = torch.cuda.Stream()
57 |
58 | def __iter__(self):
59 | loader_it = iter(self.loader)
60 | self.preload(loader_it)
61 | batch = self.next(loader_it)
62 | while batch is not None:
63 | is_tuple = isinstance(batch, tuple)
64 | if is_tuple:
65 | task, batch = batch
66 |
67 | if is_tuple:
68 | yield task, batch
69 | else:
70 | yield batch
71 | batch = self.next(loader_it)
72 |
73 | def __len__(self):
74 | return len(self.loader)
75 |
76 | def preload(self, it):
77 | try:
78 | self.batch = next(it)
79 | except StopIteration:
80 | self.batch = None
81 | return
82 | # if record_stream() doesn't work, another option is to make sure
83 | # device inputs are created on the main stream.
84 | # self.next_input_gpu = torch.empty_like(self.next_input,
85 | # device='cuda')
86 | # self.next_target_gpu = torch.empty_like(self.next_target,
87 | # device='cuda')
88 | # Need to make sure the memory allocated for next_* is not still in use
89 | # by the main stream at the time we start copying to next_*:
90 | # self.stream.wait_stream(torch.cuda.current_stream())
91 | with torch.cuda.stream(self.stream):
92 | self.batch = move_to_cuda(self.batch)
93 | # more code for the alternative if record_stream() doesn't work:
94 | # copy_ will record the use of the pinned source tensor in this
95 | # side stream.
96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98 | # self.next_input = self.next_input_gpu
99 | # self.next_target = self.next_target_gpu
100 |
101 | def next(self, it):
102 | torch.cuda.current_stream().wait_stream(self.stream)
103 | batch = self.batch
104 | if batch is not None:
105 | record_cuda_stream(batch)
106 | self.preload(it)
107 | return batch
108 |
109 | def __getattr__(self, name):
110 | method = self.loader.__getattribute__(name)
111 | return method
112 |
113 |
114 | def record_cuda_stream(batch):
115 | if isinstance(batch, torch.Tensor):
116 | batch.record_stream(torch.cuda.current_stream())
117 | elif isinstance(batch, list) or isinstance(batch, tuple):
118 | for t in batch:
119 | record_cuda_stream(t)
120 | elif isinstance(batch, dict):
121 | for t in batch.values():
122 | record_cuda_stream(t)
123 | else:
124 | pass
125 |
126 |
127 | class IterLoader:
128 | """
129 | A wrapper to convert DataLoader as an infinite iterator.
130 |
131 | Modified from:
132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133 | """
134 |
135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136 | self._dataloader = dataloader
137 | self.iter_loader = iter(self._dataloader)
138 | self._use_distributed = use_distributed
139 | self._epoch = 0
140 |
141 | @property
142 | def epoch(self) -> int:
143 | return self._epoch
144 |
145 | def __next__(self):
146 | try:
147 | data = next(self.iter_loader)
148 | except StopIteration:
149 | self._epoch += 1
150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151 | self._dataloader.sampler.set_epoch(self._epoch)
152 | time.sleep(2) # Prevent possible deadlock during epoch transition
153 | self.iter_loader = iter(self._dataloader)
154 | data = next(self.iter_loader)
155 |
156 | return data
157 |
158 | def __iter__(self):
159 | return self
160 |
161 | def __len__(self):
162 | return len(self._dataloader)
163 |
--------------------------------------------------------------------------------
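
A minimal sketch of `IterLoader` above, which turns a finite `DataLoader` into an endless iterator and tracks epochs (the toy dataset is an assumption; note the wrapper sleeps briefly at each epoch boundary):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from InfVideoLLaMA.datasets.datasets.dataloader_utils import IterLoader

ds = TensorDataset(torch.arange(10).float())
loader = IterLoader(DataLoader(ds, batch_size=4), use_distributed=False)

for _ in range(6):        # more steps than one epoch (3 batches) contains
    batch = next(loader)  # restarts the underlying DataLoader when exhausted
print(loader.epoch)       # 1
```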
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import webdataset as wds
9 | from InfVideoLLaMA.datasets.datasets.base_dataset import BaseDataset
10 |
11 |
12 | class LaionDataset(BaseDataset):
13 | def __init__(self, vis_processor, text_processor, location):
14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15 |
16 | self.inner_dataset = wds.DataPipeline(
17 | wds.ResampledShards(location),
18 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
19 | wds.shuffle(1000, handler=wds.warn_and_continue),
20 | wds.decode("pilrgb", handler=wds.warn_and_continue),
21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23 | wds.map(self.to_dict, handler=wds.warn_and_continue),
24 | )
25 |
26 | def to_dict(self, sample):
27 | return {
28 | "image": sample[0],
29 | "text_input": self.text_processor(sample[1]["caption"]),
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/datasets/datasets/webvid_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | from InfVideoLLaMA.datasets.datasets.base_dataset import BaseDataset
10 | from InfVideoLLaMA.datasets.datasets.caption_datasets import CaptionDataset
11 | import pandas as pd
12 | import decord
13 | from decord import VideoReader
14 | import random
15 | import torch
16 | from torch.utils.data.dataloader import default_collate
17 | class WebvidDataset(BaseDataset):
18 | def __init__(self, vis_processor, text_processor, vis_root, ann_root):
19 | """
20 | vis_root (string): Root directory of video (e.g. webvid_eval/video/)
21 | ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
22 | split (string): val or test
23 | """
24 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
25 |
26 |
27 | # read all annotation CSV files under the directory
28 |
29 | ts_df = []
30 | for file_name in os.listdir(ann_root):
31 | if file_name.endswith('.csv'):
32 | df = pd.read_csv(os.path.join(ann_root, file_name))
33 | ts_df.append(df)
34 |
35 | merged_df = pd.concat(ts_df)
36 | self.annotation = merged_df
37 | self.vis_root = vis_root
38 | self.resize_size = 224
39 | self.num_frm = 8
40 | self.frm_sampling_strategy = 'headtail'
41 |
42 | def _get_video_path(self, sample):
43 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
44 | full_video_fp = os.path.join(self.vis_root, rel_video_fp)
45 | return full_video_fp
46 |
47 | def __getitem__(self, index):
48 | num_retries = 10 # skip error videos
49 | for _ in range(num_retries):
50 | sample = self.annotation.iloc[index]
51 | sample_dict = sample.to_dict()
52 | video_id = sample_dict['videoid']
53 |
54 | if 'name' in sample_dict.keys():
55 | text = sample_dict['name'].strip()
56 | else:
57 | raise NotImplementedError("Un-supported text annotation format.")
58 |
59 | # fetch video
60 | video_path = self._get_video_path(sample_dict)
61 | # if os.path.exists(video_path):
62 | try:
63 | video = self.vis_processor(video_path)
64 | except:
65 | print(f"Failed to load examples with video: {video_path}. "
66 | f"Will randomly sample an example as a replacement.")
67 | index = random.randint(0, len(self) - 1)
68 | continue
69 | caption = self.text_processor(text)
70 |
71 | # print(video.size())
72 | if video is None or caption is None \
73 | or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]):
74 | print(f"Failed to load examples with video: {video_path}. "
75 | f"Will randomly sample an example as a replacement.")
76 | index = random.randint(0, len(self) - 1)
77 | continue
78 | else:
79 | break
80 | else:
81 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
82 | # "image_id" is kept to stay compatible with the COCO evaluation format
83 | return {
84 | "image": video,
85 | "text_input": caption,
86 | "type":'video',
87 | }
88 |
89 | def __len__(self):
90 | return len(self.annotation)
91 |
92 | # def collater(self, samples):
93 | # new_result = {}
94 | # new_result['image'] = default_collate( [sample["image"] for sample in samples])
95 | # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples])
96 | # return new_result
97 |
98 | class WebvidDatasetEvalDataset(BaseDataset):
99 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
100 | """
101 | vis_root (string): Root directory of images (e.g. coco/images/)
102 | ann_root (string): directory to store the annotation file
103 | split (string): val or test
104 | """
105 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
106 |
107 | def __getitem__(self, index):
108 |
109 | ann = self.annotation[index]
110 |
111 | vname = ann["video"]
112 | video_path = os.path.join(self.vis_root, vname)
113 |
114 | video = self.vis_processor(video_path)
115 |
116 | return {
117 | "video": video,
118 | "image_id": ann["image_id"],
119 | "instance_id": ann["instance_id"],
120 | }
121 |
122 |
123 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright:
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | import logging
10 | import torch
11 | from omegaconf import OmegaConf
12 |
13 | from InfVideoLLaMA.common.registry import registry
14 | from InfVideoLLaMA.models.base_model import BaseModel
15 | from InfVideoLLaMA.models.blip2 import Blip2Base
16 | from InfVideoLLaMA.models.infinityqa import InfinityQA
17 | from InfVideoLLaMA.processors.base_processor import BaseProcessor
18 |
19 |
20 | __all__ = [
21 | "load_model",
22 | "BaseModel",
23 | "Blip2Base",
24 | ]
25 |
26 |
27 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
28 | """
29 | Load supported models.
30 |
31 | """
32 |
33 | model = registry.get_model_class(name).from_pretrained(model_type=model_type)
34 |
35 | if checkpoint is not None:
36 | model.load_checkpoint(checkpoint)
37 |
38 | if is_eval:
39 | model.eval()
40 |
41 | if device == "cpu":
42 | model = model.float()
43 |
44 | return model.to(device)
45 |
46 |
47 | def load_preprocess(config):
48 | """
49 | Load preprocessor configs and construct preprocessors.
50 |
51 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
52 |
53 | """
54 |
55 | def _build_proc_from_cfg(cfg):
56 | return (
57 | registry.get_processor_class(cfg.name).from_config(cfg)
58 | if cfg is not None
59 | else BaseProcessor()
60 | )
61 |
62 | vis_processors = dict()
63 | txt_processors = dict()
64 |
65 | vis_proc_cfg = config.get("vis_processor")
66 | txt_proc_cfg = config.get("text_processor")
67 |
68 | if vis_proc_cfg is not None:
69 | vis_train_cfg = vis_proc_cfg.get("train")
70 | vis_eval_cfg = vis_proc_cfg.get("eval")
71 | else:
72 | vis_train_cfg = None
73 | vis_eval_cfg = None
74 |
75 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
76 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
77 |
78 | if txt_proc_cfg is not None:
79 | txt_train_cfg = txt_proc_cfg.get("train")
80 | txt_eval_cfg = txt_proc_cfg.get("eval")
81 | else:
82 | txt_train_cfg = None
83 | txt_eval_cfg = None
84 |
85 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
86 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
87 |
88 | return vis_processors, txt_processors
89 |
90 |
91 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
92 | """
93 | Load model and its related preprocessors.
94 |
95 | """
96 | model_cls = registry.get_model_class(name)
97 |
98 | # load model
99 | model = model_cls.from_pretrained(model_type=model_type)
100 |
101 | if is_eval:
102 | model.eval()
103 |
104 | # load preprocess
105 | cfg = OmegaConf.load(model_cls.default_config_path(model_type))
106 | if cfg is not None:
107 | preprocess_cfg = cfg.preprocess
108 |
109 | vis_processors, txt_processors = load_preprocess(preprocess_cfg)
110 | else:
111 | vis_processors, txt_processors = None, None
112 | logging.info(
113 | f"""No default preprocess for model {name} ({model_type}).
114 | This can happen if the model is not finetuned on downstream datasets,
115 | or it is not intended for direct use without finetuning.
116 | """
117 | )
118 |
119 | if device == "cpu" or device == torch.device("cpu"):
120 | model = model.float()
121 |
122 | return model.to(device), vis_processors, txt_processors
123 |
124 |
125 | class ModelZoo:
126 | """
127 | A utility class to create string representation of available model architectures and types.
128 |
129 | """
130 |
131 | def __init__(self) -> None:
132 | self.model_zoo = {
133 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
134 | for k, v in registry.mapping["model_name_mapping"].items()
135 | }
136 |
137 | def __str__(self) -> str:
138 | return (
139 | "=" * 50
140 | + "\n"
141 | + f"{'Architectures':<30} {'Types'}\n"
142 | + "=" * 50
143 | + "\n"
144 | + "\n".join(
145 | [
146 | f"{name:<30} {', '.join(types)}"
147 | for name, types in self.model_zoo.items()
148 | ]
149 | )
150 | )
151 |
152 | def __iter__(self):
153 | return iter(self.model_zoo.items())
154 |
155 | def __len__(self):
156 | return sum([len(v) for v in self.model_zoo.values()])
157 |
158 |
159 | model_zoo = ModelZoo()
160 |
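A usage sketch for the loaders above. The architecture and model type follow eval_configs/infvideollama.yaml (arch: infvideollama, model_type: pretrain_vicuna); adjust them for other checkpoints.

import torch
from InfVideoLLaMA.models import load_model_and_preprocess, model_zoo

print(model_zoo)  # table of registered architectures and their model types

model, vis_processors, txt_processors = load_model_and_preprocess(
    name="infvideollama",
    model_type="pretrain_vicuna",
    is_eval=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
vis_eval = vis_processors["eval"]  # processor used at inference time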
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/models/blip2_outputs.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS. Below is the original copyright:
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | from dataclasses import dataclass
10 | from typing import Optional
11 |
12 | import torch
13 | from transformers.modeling_outputs import (
14 | ModelOutput,
15 | BaseModelOutputWithPoolingAndCrossAttentions,
16 | CausalLMOutputWithCrossAttentions,
17 | )
18 |
19 |
20 | @dataclass
21 | class BlipSimilarity(ModelOutput):
22 | sim_i2t: torch.FloatTensor = None
23 | sim_t2i: torch.FloatTensor = None
24 |
25 | sim_i2t_m: Optional[torch.FloatTensor] = None
26 | sim_t2i_m: Optional[torch.FloatTensor] = None
27 |
28 | sim_i2t_targets: Optional[torch.FloatTensor] = None
29 | sim_t2i_targets: Optional[torch.FloatTensor] = None
30 |
31 |
32 | @dataclass
33 | class BlipIntermediateOutput(ModelOutput):
34 | """
35 | Data class for intermediate outputs of BLIP models.
36 |
37 | """
38 |
39 | image_embeds: torch.FloatTensor = None
40 | text_embeds: Optional[torch.FloatTensor] = None
41 |
42 | image_embeds_m: Optional[torch.FloatTensor] = None
43 | text_embeds_m: Optional[torch.FloatTensor] = None
44 |
45 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
46 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
47 |
48 | itm_logits: Optional[torch.FloatTensor] = None
49 | itm_labels: Optional[torch.LongTensor] = None
50 |
51 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
52 | decoder_labels: Optional[torch.LongTensor] = None
53 |
54 |
55 | @dataclass
56 | class BlipOutput(ModelOutput):
57 | sims: Optional[BlipSimilarity] = None
58 |
59 | intermediate_output: BlipIntermediateOutput = None
60 |
61 | loss: Optional[torch.FloatTensor] = None
62 |
63 | loss_itc: Optional[torch.FloatTensor] = None
64 |
65 | loss_itm: Optional[torch.FloatTensor] = None
66 |
67 | loss_lm: Optional[torch.FloatTensor] = None
68 |
69 |
70 | @dataclass
71 | class BlipOutputFeatures(ModelOutput):
72 | """
73 | Data class of features from BlipFeatureExtractor.
74 |
75 | """
76 |
77 | image_embeds: Optional[torch.FloatTensor] = None
78 | image_embeds_proj: Optional[torch.FloatTensor] = None
79 |
80 | text_embeds: Optional[torch.FloatTensor] = None
81 | text_embeds_proj: Optional[torch.FloatTensor] = None
82 |
83 | multimodal_embeds: Optional[torch.FloatTensor] = None
84 |
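These are plain ModelOutput dataclasses; a small sketch of how a training step might populate them (all tensor values and shapes are illustrative):

import torch
from InfVideoLLaMA.models.blip2_outputs import BlipIntermediateOutput, BlipOutput

intermediate = BlipIntermediateOutput(
    image_embeds=torch.randn(2, 32, 768),  # illustrative (batch, queries, dim)
    itm_logits=torch.randn(2, 2),
)

loss_itc, loss_itm, loss_lm = torch.tensor(0.31), torch.tensor(0.22), torch.tensor(1.05)
out = BlipOutput(
    intermediate_output=intermediate,
    loss=loss_itc + loss_itm + loss_lm,
    loss_itc=loss_itc,
    loss_itm=loss_itm,
    loss_lm=loss_lm,
)
print(out.loss)  # tensor(1.5800)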
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/models/helpers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import einops
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class Normalize(nn.Module):
16 | def __init__(self, dim: int) -> None:
17 | super().__init__()
18 | self.dim = dim
19 |
20 | def forward(self, x):
21 | return torch.nn.functional.normalize(x, dim=self.dim, p=2)
22 |
23 |
24 | class LearnableLogitScaling(nn.Module):
25 | def __init__(
26 | self,
27 | logit_scale_init: float = 1 / 0.07,
28 | learnable: bool = True,
29 | max_logit_scale: float = 100,
30 | ) -> None:
31 | super().__init__()
32 | self.max_logit_scale = max_logit_scale
33 | self.logit_scale_init = logit_scale_init
34 | self.learnable = learnable
35 | log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
36 | if learnable:
37 | self.log_logit_scale = nn.Parameter(log_logit_scale)
38 | else:
39 | self.register_buffer("log_logit_scale", log_logit_scale)
40 |
41 | def forward(self, x):
42 | return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
43 |
44 | def extra_repr(self):
45 | st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," \
46 | f" max_logit_scale={self.max_logit_scale}"
47 | return st
48 |
49 |
50 | class EinOpsRearrange(nn.Module):
51 | def __init__(self, rearrange_expr: str, **kwargs) -> None:
52 | super().__init__()
53 | self.rearrange_expr = rearrange_expr
54 | self.kwargs = kwargs
55 |
56 | def forward(self, x):
57 | assert isinstance(x, torch.Tensor)
58 | return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
59 |
60 |
61 | class VerboseNNModule(nn.Module):
62 | """
63 | Wrapper around nn.Module that prints registered buffers and parameter names.
64 | """
65 |
66 | @staticmethod
67 | def get_readable_tensor_repr(name: str, tensor) -> str:  # `tensor` is a (name, Tensor) tuple from named_parameters()/named_buffers()
68 | st = (
69 | "("
70 | + name
71 | + "): "
72 | + "tensor("
73 | + str(tuple(tensor[1].shape))
74 | + ", requires_grad="
75 | + str(tensor[1].requires_grad)
76 | + ")\n"
77 | )
78 | return st
79 |
80 | def extra_repr(self) -> str:
81 | named_modules = set()
82 | for p in self.named_modules():
83 | named_modules.update([p[0]])
84 | named_modules = list(named_modules)
85 |
86 | string_repr = ""
87 | for p in self.named_parameters():
88 | name = p[0].split(".")[0]
89 | if name not in named_modules:
90 | string_repr += self.get_readable_tensor_repr(name, p)
91 |
92 | for p in self.named_buffers():
93 | name = p[0].split(".")[0]
94 | string_repr += self.get_readable_tensor_repr(name, p)
95 |
96 | return string_repr
97 |
98 |
99 | def cast_if_src_dtype(
100 | tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
101 | ):
102 | updated = False
103 | if tensor.dtype == src_dtype:
104 | tensor = tensor.to(dtype=tgt_dtype)
105 | updated = True
106 | return tensor, updated
107 |
108 |
109 | class QuickGELU(nn.Module):
110 | # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
111 | def forward(self, x: torch.Tensor):
112 | return x * torch.sigmoid(1.702 * x)
113 |
114 |
115 | class SelectElement(nn.Module):
116 | def __init__(self, index) -> None:
117 | super().__init__()
118 | self.index = index
119 |
120 | def forward(self, x):
121 | assert x.ndim >= 3
122 | return x[:, self.index, ...]
123 |
124 |
125 | class SelectEOSAndProject(nn.Module):
126 | """
127 | Text Pooling used in OpenCLIP
128 | """
129 |
130 | def __init__(self, proj: nn.Module) -> None:
131 | super().__init__()
132 | self.proj = proj
133 |
134 | def forward(self, x, seq_len):
135 | assert x.ndim == 3
136 | # x is of shape B x L x D
137 | # take features from the eot embedding (eot_token is the highest number in each sequence)
138 | x = x[torch.arange(x.shape[0]), seq_len]
139 | x = self.proj(x)
140 | return x
141 |
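A small sketch combining the helpers above into an ImageBind-style post-projection head (shapes are illustrative):

import torch
import torch.nn as nn
from InfVideoLLaMA.models.helpers import (
    LearnableLogitScaling,
    Normalize,
    QuickGELU,
    SelectElement,
)

# L2-normalize embeddings, then apply a learnable temperature.
head = nn.Sequential(
    Normalize(dim=-1),
    LearnableLogitScaling(logit_scale_init=1 / 0.07, learnable=True),
)
x = torch.randn(4, 512)
scaled = head(x)            # unit-norm rows scaled by exp(log_logit_scale)

act = QuickGELU()           # drop-in GELU approximation
cls_tok = SelectElement(index=0)(torch.randn(2, 16, 512))  # picks token 0 -> (2, 512)
print(scaled.shape, act(x).shape, cls_tok.shape)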
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from InfVideoLLaMA.processors.base_processor import BaseProcessor
9 | from InfVideoLLaMA.processors.blip_processors import (
10 | Blip2ImageTrainProcessor,
11 | Blip2ImageEvalProcessor,
12 | BlipCaptionProcessor,
13 | )
14 | from InfVideoLLaMA.processors.video_processor import (
15 | AlproVideoTrainProcessor,
16 | AlproVideoEvalProcessor
17 | )
18 | from InfVideoLLaMA.common.registry import registry
19 |
20 | __all__ = [
21 | "BaseProcessor",
22 | "Blip2ImageTrainProcessor",
23 | "Blip2ImageEvalProcessor",
24 | "BlipCaptionProcessor",
25 | "AlproVideoTrainProcessor",
26 | "AlproVideoEvalProcessor",
27 | ]
28 |
29 |
30 | def load_processor(name, cfg=None):
31 | """
32 | Example
33 |
34 | >>> processor = load_processor("alpro_video_train", cfg=None)
35 | """
36 | processor = registry.get_processor_class(name).from_config(cfg)
37 |
38 | return processor
39 |
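A usage sketch for load_processor. The config keys mirror the alpro_video_eval block in eval_configs/infvideollama.yaml (n_frms, image_size); treat the exact keys and the call on a video path as assumptions about the Alpro processor's interface.

from omegaconf import OmegaConf
from InfVideoLLaMA.processors import load_processor

cfg = OmegaConf.create({"name": "alpro_video_eval", "n_frms": 8, "image_size": 224})
vis_proc = load_processor("alpro_video_eval", cfg=cfg)

clip = vis_proc("/path/to/video.mp4")  # placeholder path
print(clip.shape)                      # e.g. torch.Size([3, 8, 224, 224])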
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from omegaconf import OmegaConf
9 |
10 |
11 | class BaseProcessor:
12 | def __init__(self):
13 | self.transform = lambda x: x
14 | return
15 |
16 | def __call__(self, item):
17 | return self.transform(item)
18 |
19 | @classmethod
20 | def from_config(cls, cfg=None):
21 | return cls()
22 |
23 | def build(self, **kwargs):
24 | cfg = OmegaConf.create(kwargs)
25 |
26 | return self.from_config(cfg)
27 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/processors/blip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import re
9 |
10 | from InfVideoLLaMA.common.registry import registry
11 | from InfVideoLLaMA.processors.base_processor import BaseProcessor
12 | from InfVideoLLaMA.processors.randaugment import RandomAugment
13 | from omegaconf import OmegaConf
14 | from torchvision import transforms
15 | from torchvision.transforms.functional import InterpolationMode
16 |
17 |
18 | class BlipImageBaseProcessor(BaseProcessor):
19 | def __init__(self, mean=None, std=None):
20 | if mean is None:
21 | mean = (0.48145466, 0.4578275, 0.40821073)
22 | if std is None:
23 | std = (0.26862954, 0.26130258, 0.27577711)
24 |
25 | self.normalize = transforms.Normalize(mean, std)
26 |
27 |
28 | @registry.register_processor("blip_caption")
29 | class BlipCaptionProcessor(BaseProcessor):
30 | def __init__(self, prompt="", max_words=50):
31 | self.prompt = prompt
32 | self.max_words = max_words
33 |
34 | def __call__(self, caption):
35 | caption = self.prompt + self.pre_caption(caption)
36 |
37 | return caption
38 |
39 | @classmethod
40 | def from_config(cls, cfg=None):
41 | if cfg is None:
42 | cfg = OmegaConf.create()
43 |
44 | prompt = cfg.get("prompt", "")
45 | max_words = cfg.get("max_words", 50)
46 |
47 | return cls(prompt=prompt, max_words=max_words)
48 |
49 | def pre_caption(self, caption):
50 | caption = re.sub(
51 | r"([.!\"()*#:;~])",
52 | " ",
53 | caption.lower(),
54 | )
55 | caption = re.sub(
56 | r"\s{2,}",
57 | " ",
58 | caption,
59 | )
60 | caption = caption.rstrip("\n")
61 | caption = caption.strip(" ")
62 |
63 | # truncate caption
64 | caption_words = caption.split(" ")
65 | if len(caption_words) > self.max_words:
66 | caption = " ".join(caption_words[: self.max_words])
67 |
68 | return caption
69 |
70 |
71 | @registry.register_processor("blip2_image_train")
72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
74 | super().__init__(mean=mean, std=std)
75 |
76 | self.transform = transforms.Compose(
77 | [
78 | transforms.RandomResizedCrop(
79 | image_size,
80 | scale=(min_scale, max_scale),
81 | interpolation=InterpolationMode.BICUBIC,
82 | ),
83 | transforms.ToTensor(),
84 | self.normalize,
85 | ]
86 | )
87 |
88 | def __call__(self, item):
89 | return self.transform(item)
90 |
91 | @classmethod
92 | def from_config(cls, cfg=None):
93 | if cfg is None:
94 | cfg = OmegaConf.create()
95 |
96 | image_size = cfg.get("image_size", 224)
97 |
98 | mean = cfg.get("mean", None)
99 | std = cfg.get("std", None)
100 |
101 | min_scale = cfg.get("min_scale", 0.5)
102 | max_scale = cfg.get("max_scale", 1.0)
103 |
104 | return cls(
105 | image_size=image_size,
106 | mean=mean,
107 | std=std,
108 | min_scale=min_scale,
109 | max_scale=max_scale,
110 | )
111 |
112 |
113 | @registry.register_processor("blip2_image_eval")
114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
115 | def __init__(self, image_size=224, mean=None, std=None):
116 | super().__init__(mean=mean, std=std)
117 |
118 | self.transform = transforms.Compose(
119 | [
120 | transforms.Resize(
121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
122 | ),
123 | transforms.ToTensor(),
124 | self.normalize,
125 | ]
126 | )
127 |
128 | def __call__(self, item):
129 | return self.transform(item)
130 |
131 | @classmethod
132 | def from_config(cls, cfg=None):
133 | if cfg is None:
134 | cfg = OmegaConf.create()
135 |
136 | image_size = cfg.get("image_size", 224)
137 |
138 | mean = cfg.get("mean", None)
139 | std = cfg.get("std", None)
140 |
141 | return cls(image_size=image_size, mean=mean, std=std)
142 |
143 |
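A quick sketch of the two processors in isolation; the image is synthetic, and the caption shows the lower-casing and punctuation stripping done by pre_caption:

from PIL import Image
from InfVideoLLaMA.processors.blip_processors import (
    Blip2ImageEvalProcessor,
    BlipCaptionProcessor,
)

img_proc = Blip2ImageEvalProcessor(image_size=224)
txt_proc = BlipCaptionProcessor(prompt="", max_words=50)

image = Image.new("RGB", (640, 480), color=(128, 128, 128))  # synthetic image
pixels = img_proc(image)                # tensor of shape (3, 224, 224), normalized
caption = txt_proc("A DOG chases a ball!!")
print(pixels.shape, caption)            # torch.Size([3, 224, 224]) a dog chases a ball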
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/processors/functional_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import warnings
9 |
10 | import torch
11 |
12 |
13 | def _is_tensor_video_clip(clip):
14 | if not torch.is_tensor(clip):
15 | raise TypeError("clip should be Tensor. Got %s" % type(clip))
16 |
17 | if not clip.ndimension() == 4:
18 | raise ValueError("clip should be 4D. Got %dD" % clip.dim())
19 |
20 | return True
21 |
22 |
23 | def crop(clip, i, j, h, w):
24 | """
25 | Args:
26 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
27 | """
28 | if len(clip.size()) != 4:
29 | raise ValueError("clip should be a 4D tensor")
30 | return clip[..., i : i + h, j : j + w]
31 |
32 |
33 | def resize(clip, target_size, interpolation_mode):
34 | if len(target_size) != 2:
35 | raise ValueError(
36 | f"target size should be tuple (height, width), instead got {target_size}"
37 | )
38 | return torch.nn.functional.interpolate(
39 | clip, size=target_size, mode=interpolation_mode, align_corners=False
40 | )
41 |
42 |
43 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
44 | """
45 | Do spatial cropping and resizing to the video clip
46 | Args:
47 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
48 | i (int): i in (i,j) i.e coordinates of the upper left corner.
49 | j (int): j in (i,j) i.e coordinates of the upper left corner.
50 | h (int): Height of the cropped region.
51 | w (int): Width of the cropped region.
52 | size (tuple(int, int)): height and width of resized clip
53 | Returns:
54 | clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
55 | """
56 | if not _is_tensor_video_clip(clip):
57 | raise ValueError("clip should be a 4D torch.tensor")
58 | clip = crop(clip, i, j, h, w)
59 | clip = resize(clip, size, interpolation_mode)
60 | return clip
61 |
62 |
63 | def center_crop(clip, crop_size):
64 | if not _is_tensor_video_clip(clip):
65 | raise ValueError("clip should be a 4D torch.tensor")
66 | h, w = clip.size(-2), clip.size(-1)
67 | th, tw = crop_size
68 | if h < th or w < tw:
69 | raise ValueError("height and width must be no smaller than crop_size")
70 |
71 | i = int(round((h - th) / 2.0))
72 | j = int(round((w - tw) / 2.0))
73 | return crop(clip, i, j, th, tw)
74 |
75 |
76 | def to_tensor(clip):
77 | """
78 | Convert tensor data type from uint8 to float, divide value by 255.0 and
79 | permute the dimensions of clip tensor
80 | Args:
81 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
82 | Return:
83 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
84 | """
85 | _is_tensor_video_clip(clip)
86 | if not clip.dtype == torch.uint8:
87 | raise TypeError(
88 | "clip tensor should have data type uint8. Got %s" % str(clip.dtype)
89 | )
90 | return clip.float().permute(3, 0, 1, 2) / 255.0
91 |
92 |
93 | def normalize(clip, mean, std, inplace=False):
94 | """
95 | Args:
96 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
97 | mean (tuple): pixel RGB mean. Size is (3)
98 | std (tuple): pixel standard deviation. Size is (3)
99 | Returns:
100 | normalized clip (torch.tensor): Size is (C, T, H, W)
101 | """
102 | if not _is_tensor_video_clip(clip):
103 | raise ValueError("clip should be a 4D torch.tensor")
104 | if not inplace:
105 | clip = clip.clone()
106 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
107 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
108 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
109 | return clip
110 |
111 |
112 | def hflip(clip):
113 | """
114 | Args:
115 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
116 | Returns:
117 | flipped clip (torch.tensor): Size is (C, T, H, W)
118 | """
119 | if not _is_tensor_video_clip(clip):
120 | raise ValueError("clip should be a 4D torch.tensor")
121 | return clip.flip(-1)
122 |
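The functions compose into a simple clip pipeline; a sketch with a synthetic uint8 clip (the normalization constants match the BLIP defaults used in blip_processors.py above):

import torch
import InfVideoLLaMA.processors.functional_video as F

raw = torch.randint(0, 256, (16, 256, 320, 3), dtype=torch.uint8)  # (T, H, W, C)

clip = F.to_tensor(raw)                                  # (C, T, H, W), float in [0, 1]
clip = F.resized_crop(clip, 0, 0, 256, 256, (224, 224))  # crop then resize
clip = F.hflip(clip)
clip = F.normalize(
    clip,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)
print(clip.shape)  # torch.Size([3, 16, 224, 224])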
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/processors/transforms_video.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 |
10 | import numbers
11 | import random
12 |
13 | from torchvision.transforms import (
14 | RandomCrop,
15 | RandomResizedCrop,
16 | )
17 |
18 | import InfVideoLLaMA.processors.functional_video as F
19 |
20 |
21 | __all__ = [
22 | "RandomCropVideo",
23 | "RandomResizedCropVideo",
24 | "CenterCropVideo",
25 | "NormalizeVideo",
26 | "ToTensorVideo",
27 | "RandomHorizontalFlipVideo",
28 | ]
29 |
30 |
31 | class RandomCropVideo(RandomCrop):
32 | def __init__(self, size):
33 | if isinstance(size, numbers.Number):
34 | self.size = (int(size), int(size))
35 | else:
36 | self.size = size
37 |
38 | def __call__(self, clip):
39 | """
40 | Args:
41 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
42 | Returns:
43 | torch.tensor: randomly cropped/resized video clip.
44 | size is (C, T, OH, OW)
45 | """
46 | i, j, h, w = self.get_params(clip, self.size)
47 | return F.crop(clip, i, j, h, w)
48 |
49 | def __repr__(self) -> str:
50 | return f"{self.__class__.__name__}(size={self.size})"
51 |
52 |
53 | class RandomResizedCropVideo(RandomResizedCrop):
54 | def __init__(
55 | self,
56 | size,
57 | scale=(0.08, 1.0),
58 | ratio=(3.0 / 4.0, 4.0 / 3.0),
59 | interpolation_mode="bilinear",
60 | ):
61 | if isinstance(size, tuple):
62 | if len(size) != 2:
63 | raise ValueError(
64 | f"size should be tuple (height, width), instead got {size}"
65 | )
66 | self.size = size
67 | else:
68 | self.size = (size, size)
69 |
70 | self.interpolation_mode = interpolation_mode
71 | self.scale = scale
72 | self.ratio = ratio
73 |
74 | def __call__(self, clip):
75 | """
76 | Args:
77 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
78 | Returns:
79 | torch.tensor: randomly cropped/resized video clip.
80 | size is (C, T, H, W)
81 | """
82 | i, j, h, w = self.get_params(clip, self.scale, self.ratio)
83 | return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
84 |
85 | def __repr__(self) -> str:
86 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
87 |
88 |
89 | class CenterCropVideo:
90 | def __init__(self, crop_size):
91 | if isinstance(crop_size, numbers.Number):
92 | self.crop_size = (int(crop_size), int(crop_size))
93 | else:
94 | self.crop_size = crop_size
95 |
96 | def __call__(self, clip):
97 | """
98 | Args:
99 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
100 | Returns:
101 | torch.tensor: central cropping of video clip. Size is
102 | (C, T, crop_size, crop_size)
103 | """
104 | return F.center_crop(clip, self.crop_size)
105 |
106 | def __repr__(self) -> str:
107 | return f"{self.__class__.__name__}(crop_size={self.crop_size})"
108 |
109 |
110 | class NormalizeVideo:
111 | """
112 | Normalize the video clip by mean subtraction and division by standard deviation
113 | Args:
114 | mean (3-tuple): pixel RGB mean
115 | std (3-tuple): pixel RGB standard deviation
116 | inplace (boolean): whether do in-place normalization
117 | """
118 |
119 | def __init__(self, mean, std, inplace=False):
120 | self.mean = mean
121 | self.std = std
122 | self.inplace = inplace
123 |
124 | def __call__(self, clip):
125 | """
126 | Args:
127 | clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
128 | """
129 | return F.normalize(clip, self.mean, self.std, self.inplace)
130 |
131 | def __repr__(self) -> str:
132 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
133 |
134 |
135 | class ToTensorVideo:
136 | """
137 | Convert tensor data type from uint8 to float, divide value by 255.0 and
138 | permute the dimensions of clip tensor
139 | """
140 |
141 | def __init__(self):
142 | pass
143 |
144 | def __call__(self, clip):
145 | """
146 | Args:
147 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
148 | Return:
149 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
150 | """
151 | return F.to_tensor(clip)
152 |
153 | def __repr__(self) -> str:
154 | return self.__class__.__name__
155 |
156 |
157 | class RandomHorizontalFlipVideo:
158 | """
159 | Flip the video clip along the horizonal direction with a given probability
160 | Args:
161 | p (float): probability of the clip being flipped. Default value is 0.5
162 | """
163 |
164 | def __init__(self, p=0.5):
165 | self.p = p
166 |
167 | def __call__(self, clip):
168 | """
169 | Args:
170 | clip (torch.tensor): Size is (C, T, H, W)
171 | Return:
172 | clip (torch.tensor): Size is (C, T, H, W)
173 | """
174 | if random.random() < self.p:
175 | clip = F.hflip(clip)
176 | return clip
177 |
178 | def __repr__(self) -> str:
179 | return f"{self.__class__.__name__}(p={self.p})"
180 |
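The classes above compose with torchvision.transforms.Compose; a sketch with a synthetic clip. This assumes torchvision's RandomResizedCrop.get_params accepts the (C, T, H, W) tensor (it only reads the trailing height/width dims), which holds for recent torchvision versions.

import torch
from torchvision import transforms
from InfVideoLLaMA.processors import transforms_video as T

video_transform = transforms.Compose([
    T.ToTensorVideo(),                                   # (T, H, W, C) uint8 -> (C, T, H, W) float
    T.RandomResizedCropVideo(224, scale=(0.5, 1.0)),
    T.RandomHorizontalFlipVideo(p=0.5),
    T.NormalizeVideo(mean=(0.48145466, 0.4578275, 0.40821073),
                     std=(0.26862954, 0.26130258, 0.27577711)),
])

frames = torch.randint(0, 256, (8, 256, 256, 3), dtype=torch.uint8)
clip = video_transform(frames)
print(clip.shape)  # torch.Size([3, 8, 224, 224])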
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from InfVideoLLaMA.runners.runner_base import RunnerBase
9 |
10 | __all__ = ["RunnerBase"]
11 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/runners/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-Video-LLaMA/InfVideoLLaMA/runners/test.py
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from InfVideoLLaMA.common.registry import registry
9 | from InfVideoLLaMA.tasks.base_task import BaseTask
10 | from InfVideoLLaMA.tasks.image_text_pretrain import ImageTextPretrainTask
11 | from InfVideoLLaMA.tasks.video_text_pretrain import VideoTextPretrainTask
12 |
13 |
14 | def setup_task(cfg):
15 | assert "task" in cfg.run_cfg, "Task name must be provided."
16 |
17 | task_name = cfg.run_cfg.task
18 | task = registry.get_task_class(task_name).setup_task(cfg=cfg)
19 | assert task is not None, "Task {} not properly registered.".format(task_name)
20 |
21 | return task
22 |
23 |
24 | __all__ = [
25 | "BaseTask",
26 | "ImageTextPretrainTask",
27 | "VideoTextPretrainTask"
28 | ]
29 |
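A sketch of setup_task with a minimal config. It assumes BaseTask.setup_task follows the LAVIS convention of returning cls(), and uses the task name set in eval_configs/infvideollama.yaml.

from omegaconf import OmegaConf
from InfVideoLLaMA.tasks import setup_task

cfg = OmegaConf.create({"run_cfg": {"task": "video_text_pretrain"}})
task = setup_task(cfg)
print(type(task).__name__)  # VideoTextPretrainTask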
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/tasks/image_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from InfVideoLLaMA.common.registry import registry
9 | from InfVideoLLaMA.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("image_text_pretrain")
13 | class ImageTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/InfVideoLLaMA/tasks/video_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from InfVideoLLaMA.common.registry import registry
9 | from InfVideoLLaMA.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("video_text_pretrain")
13 | class VideoTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/infty-Video-LLaMA/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Apply the delta weights on top of a base model.
3 | Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/model/apply_delta.py.
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 |
11 |
12 | def apply_delta(base_model_path, target_model_path, delta_path):
13 | print(f"Loading the base model from {base_model_path}")
14 | base = AutoModelForCausalLM.from_pretrained(
15 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
16 |
17 | print(f"Loading the delta from {delta_path}")
18 | delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
19 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
20 |
21 | DEFAULT_PAD_TOKEN = "[PAD]"
22 | base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
23 | num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
24 |
25 | base.resize_token_embeddings(len(base_tokenizer))
26 | input_embeddings = base.get_input_embeddings().weight.data
27 | output_embeddings = base.get_output_embeddings().weight.data
28 | input_embeddings[-num_new_tokens:] = 0
29 | output_embeddings[-num_new_tokens:] = 0
30 |
31 | print("Applying the delta")
32 | for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
33 | assert name in delta.state_dict()
34 | param.data += delta.state_dict()[name]
35 |
36 | print(f"Saving the target model to {target_model_path}")
37 | base.save_pretrained(target_model_path)
38 | delta_tokenizer.save_pretrained(target_model_path)
39 |
40 |
41 | if __name__ == "__main__":
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--base-model-path", type=str, required=True)
44 | parser.add_argument("--target-model-path", type=str, required=True)
45 | parser.add_argument("--delta-path", type=str, required=True)
46 |
47 | args = parser.parse_args()
48 |
49 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
50 |
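Typical invocation; all paths below are placeholders.

# Command-line form:
#   python apply_delta.py \
#       --base-model-path /path/to/llama-7b \
#       --target-model-path /path/to/merged-model \
#       --delta-path /path/to/delta-weights
#
# Programmatic equivalent (run from the infty-Video-LLaMA root):
from apply_delta import apply_delta

apply_delta(
    base_model_path="/path/to/llama-7b",
    target_model_path="/path/to/merged-model",
    delta_path="/path/to/delta-weights",
)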
--------------------------------------------------------------------------------
/infty-Video-LLaMA/eval_code/eval/extract_features.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import cv2
4 | from tqdm import tqdm
5 | import argparse
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser(description="uniformly sample frames from videos and save them as images")
9 | parser.add_argument("--input_path", required=True, help="videos path")
10 | parser.add_argument("--output_path", required=True, type=str, help="path to save features.")
11 | parser.add_argument("--num_frames", required=True, type=int, help="number of frames to sample from each video.")
12 |
13 | args = parser.parse_args()
14 | return args
15 |
16 | def extract_and_save_features(input_path, output_path, num_frames):
17 | input_base_path = Path(input_path)
18 | output_base_path = Path(output_path)
19 |
20 | # Placeholder for answers if needed later
21 | answers = {}
22 |
23 | pbar = tqdm(total=len(list(input_base_path.iterdir())))
24 | for video_fp in list(input_base_path.iterdir()):
25 | if video_fp.stem not in [p.stem for p in output_base_path.iterdir()]:
26 | output_path = output_base_path / video_fp.stem
27 | output_path.mkdir(parents=True, exist_ok=True)
28 |
29 | # Use OpenCV to read video frames efficiently
30 | video_capture = cv2.VideoCapture(str(video_fp))
31 | total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
32 |
33 | # Uniformly sample frame indices
34 | frame_indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
35 |
36 | video_frames = []
37 | for frame_idx in frame_indices:
38 | video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
39 | success, frame = video_capture.read()
40 | if success:
41 | frame = cv2.resize(frame, (224, 224))
42 | video_frames.append(frame)
43 | video_capture.release()
44 |
45 | # Save features as images
46 | save_image_frames(video_frames, video_fp.stem, output_path)
47 | pbar.update(1)
48 |
49 | pbar.close()
50 |
51 | def save_image_frames(video_frames, name_ids, save_folder):
52 | """
53 | Save video frames as image files in a specified folder.
54 |
55 | Args:
56 | - video_frames (list): List containing video frames
57 | - name_ids (str): Identifier to include in the filename
58 | - save_folder (str): Path to the folder where the images should be saved
59 |
60 | Returns:
61 | - None
62 | """
63 | for idx, frame in enumerate(video_frames):
64 | filename = f"{name_ids}_frame_{idx:04d}.jpg" # Construct filename with frame index
65 | filepath = os.path.join(save_folder, filename)
66 | cv2.imwrite(filepath, frame) # Save frame as image
67 |
68 | if __name__ == '__main__':
69 | args = parse_args()
70 | extract_and_save_features(args.input_path, args.output_path, args.num_frames)
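Typical invocation; the script uniformly samples num_frames frames per video, resizes them to 224x224, and writes them as JPEGs into one sub-folder per video. Paths are placeholders.

# Command-line form:
#   python extract_features.py --input_path /data/videos \
#       --output_path /data/frames --num_frames 32
#
# Programmatic equivalent:
from extract_features import extract_and_save_features

extract_and_save_features(
    input_path="/data/videos",
    output_path="/data/frames",
    num_frames=32,
)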
--------------------------------------------------------------------------------
/infty-Video-LLaMA/eval_code/eval/utils.py:
--------------------------------------------------------------------------------
1 | from moviepy.editor import *
2 | from decord import VideoReader
3 | import decord
4 | import torch
5 | import random as rnd
6 | import numpy as np
7 |
8 | def video_duration(filename):
9 | with VideoFileClip(filename) as video:
10 | fps = video.fps # frames per second
11 |
12 | # Calculate the total number of frames
13 | total_frames = int(video.duration * fps)
14 | return video.duration, total_frames
15 |
16 | def capture_video(video_path, fragment_video_path, per_video_length, n_stage):
17 | start_time = n_stage * per_video_length
18 | end_time = (n_stage+1) * per_video_length
19 | video = CompositeVideoClip([VideoFileClip(video_path).subclip(start_time, end_time)])
20 |
21 | video.write_videofile(fragment_video_path)
22 |
23 |
24 | def load_video(video_path, n_frms=16, height=-1, width=-1, sampling="uniform", return_msg = False):
25 | decord.bridge.set_bridge("torch")
26 | vr = VideoReader(uri=video_path, height=height, width=width)
27 |
28 | vlen = len(vr)
29 | start, end = 0, vlen
30 | n_frms = min(n_frms, vlen)
31 | if sampling == "uniform":
32 | indices = np.linspace(start, end - 1, n_frms).astype(int).tolist()
33 | elif sampling == "headtail":
34 | indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
35 | indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))
36 | indices = indices_h + indices_t
37 | else:
38 | raise NotImplementedError
39 |
40 | # get_batch -> T, H, W, C
41 | temp_frms = vr.get_batch(indices)
42 | tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
43 | frms = tensor_frms.permute(3, 0, 1, 2).float() # (C, T, H, W)
44 |
45 | if not return_msg:
46 | return frms
47 |
48 | fps = float(vr.get_avg_fps())
49 | sec = ", ".join([str(round(f / fps, 1)) for f in indices])
50 | # " " should be added in the start and end
51 | msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. "
52 | return frms, msg
53 |
54 |
55 | def parse_video_fragment(video_path, video_length, fragment_video_path, n_stage = 0, n_samples = 1):
56 | decord.bridge.set_bridge("torch")
57 | per_video_length = video_length / n_samples
58 | # cut video from per_video_length(n_stage-1, n_stage)
59 | capture_video(video_path, fragment_video_path, per_video_length, n_stage)
60 | return fragment_video_path
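A usage sketch for load_video; the path is a placeholder, and the import assumes you run from eval_code/eval/, where the eval scripts pull these helpers in.

from utils import load_video  # eval_code/eval/utils.py

frms, msg = load_video(
    "/path/to/video.mp4",   # placeholder path
    n_frms=16,
    height=224,
    width=224,
    sampling="uniform",
    return_msg=True,
)
print(frms.shape)  # torch.Size([3, 16, 224, 224]) for a video with at least 16 frames
print(msg)         # "The video contains 16 frames sampled at ... seconds."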
--------------------------------------------------------------------------------
/infty-Video-LLaMA/eval_code/validate/egoschema_acc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import requests
3 | import json
4 |
5 | def send_post_request(json_file):
6 | """
7 | Sends a POST request to the specified URL with the given JSON file.
8 |
9 | Parameters:
10 | - json_file (str): Path to the JSON file to be used in the request body.
11 |
12 | Returns:
13 | - Response object containing server's response.
14 | """
15 |
16 | url = "https://validation-server.onrender.com/api/upload/"
17 | headers = {
18 | "Content-Type": "application/json"
19 | }
20 |
21 | with open(json_file, 'r') as f:
22 | data = json.load(f)
23 |
24 | response = requests.post(url, headers=headers, json=data)
25 |
26 | return response
27 |
28 | def main():
29 | """
30 | Main function that parses command-line arguments and sends a POST request.
31 | """
32 |
33 | parser = argparse.ArgumentParser(description="Send a POST request with a JSON file.")
34 | parser.add_argument("--f", required=True, help="Path to the JSON file to be sent with the request.")
35 |
36 | args = parser.parse_args()
37 |
38 | response = send_post_request(args.f)
39 | print(f"Response Status Code: {response.status_code}")
40 | print(f"Response Content:\n{response.text}")
41 |
42 | if __name__ == "__main__":
43 | main()
--------------------------------------------------------------------------------
/infty-Video-LLaMA/eval_code/validate/run_eval_langchain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import ast
5 | from multiprocessing.pool import Pool
6 | import sys
7 |
8 | from openai import OpenAI
9 | import shutil
10 | import random
11 | from utils import *
12 | from multiprocessing import Manager
13 | from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
14 | from langchain.vectorstores import Chroma
15 | from langchain.embeddings import OpenAIEmbeddings
16 | embedding = OpenAIEmbeddings(openai_api_key="")
17 |
18 |
19 | manager = Manager()
20 | results = manager.dict()
21 |
22 | option_str = {"0": "A",
23 | "1": "B",
24 | "2": "C",
25 | "3": "D",
26 | "4": "E",
27 | }
28 |
29 | current_dir = os.getcwd()
30 | sys.path.append(current_dir)
31 |
32 | def parse_args():
33 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
34 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
35 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
36 | args = parser.parse_args()
37 | return args
38 |
39 | def annotate(prediction_set, keys_chunk):
40 | """
41 | Matches each free-form prediction to the closest answer option via embedding
42 | similarity and stores the selected option letter in the shared results dict.
43 | """
44 | for i, key in enumerate(keys_chunk):
45 | id = key
46 | key = prediction_set[key]
47 | # Define the five prediction options
48 | predictions = [
49 | {"prediction": key["options"][0]},
50 | {"prediction": key["options"][1]},
51 | {"prediction": key["options"][2]},
52 | {"prediction": key["options"][3]},
53 | #{"prediction": key["options"][4]},
54 | ]
55 |
56 | example_selector = SemanticSimilarityExampleSelector.from_examples(
57 | examples=predictions,
58 | embeddings=embedding, # Embedding model for similarity
59 | vectorstore_cls=Chroma, # VectorStore for indexing and searching
60 | k=1, # Return the single most similar prediction,
61 | collection_name = "collection" + str(random.randint(1, 1000000000))
62 | )
63 |
64 | # The model's free-form prediction, to be matched against the answer options
65 | ground_truth = {"prediction": key["prediction"]}
66 |
67 | # Select the answer option most similar to the model's prediction
68 | pred = example_selector.select_examples(ground_truth)
69 | index = key["options"].index(pred[0]["prediction"])
70 | answer = option_str[str(index)]
71 | if "duration" in key:
72 | results[id] = {"prediction": answer,
73 | "answer": key["answer"],
74 | "duration": key["duration"]}
75 | else:
76 | results[id] = {"prediction": answer,
77 | "answer": key["answer"]}
78 |
79 | def main(args):
80 | """
81 | Main function to control the flow of the program.
82 | """
83 |
84 | prediction_set = load_json(args.pred_path)
85 | num_tasks = args.num_tasks
86 | # import pdb; pdb.set_trace()
87 | ids = list(prediction_set.keys())
88 |
89 | # Split tasks into parts.
90 | part_len = len(ids) // num_tasks
91 | all_parts = [ids[i:i + part_len] for i in range(0, len(ids), part_len)]
92 | task_args = [(prediction_set, part) for part in all_parts]
93 |
94 | # Use a pool of workers to process the files in parallel.
95 | with Pool() as pool:
96 | pool.starmap(annotate, task_args)
97 |
98 | if __name__ == "__main__":
99 | args = parse_args()
100 | # Get the directory name
101 | directory = os.path.dirname(args.pred_path)
102 |
103 | # Create the new directory path with 'results_' prefix
104 | output_dir = os.path.join(os.path.dirname(directory), "results_" + os.path.basename(directory))
105 | print(output_dir)
106 | args.output_dir= output_dir
107 | # Generate output directory if not exists.
108 | if not os.path.exists(output_dir):
109 | os.makedirs(args.output_dir)
110 | else:
111 | shutil.rmtree(args.output_dir)
112 | # Create the directory again (empty)
113 | os.makedirs(args.output_dir)
114 |
115 | main(args)
116 | save_json(dict(results), output_dir + "/preds.json")
117 |
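For reference, one entry of the --pred_path JSON as annotate above consumes it; field values are illustrative, and duration is only copied through when present.

prediction_set = {
    "video_0001": {                          # illustrative key
        "options": [
            "The person is cooking dinner",
            "The person is fixing a bike",
            "The person is playing chess",
            "The person is painting a wall",
        ],
        "prediction": "the person prepares food in the kitchen",  # free-form model output
        "answer": "A",            # ground-truth label, stored as-is in the results
        # "duration": "short",    # optional field
    }
}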
--------------------------------------------------------------------------------
/infty-Video-LLaMA/eval_code/validate/test.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | # Input data (assuming this is the loaded JSON data)
4 | data = [
5 | {"video_id": "001", "duration": "short", "domain": "Knowledge", "sub_category": "Humanity & History",
6 | "url": "https://www.youtube.com/watch?v=fFjv93ACGo8", "videoID": "fFjv93ACGo8", "question_id": "001-1",
7 | "task_type": "Counting Problem", "question": "When demonstrating the Germany modern Christmas tree is initially decorated with apples, candles and berries, which kind of the decoration has the largest number?",
8 | "options": ["A. Apples.", "B. Candles.", "C. Berries.", "D. The three kinds are of the same number."],
9 | "answer": "C"},
10 | {"video_id": "001", "duration": "short", "domain": "Knowledge", "sub_category": "Humanity & History",
11 | "url": "https://www.youtube.com/watch?v=fFjv93ACGo8", "videoID": "fFjv93ACGo8", "question_id": "001-2",
12 | "task_type": "Information Synopsis", "question": "What is the genre of this video?",
13 | "options": ["A. It is a news report that introduces the history behind Christmas decorations.",
14 | "B. It is a documentary on the evolution of Christmas holiday recipes.",
15 | "C. It is a travel vlog exploring Christmas markets around the world.",
16 | "D. It is a tutorial on DIY Christmas ornament crafting."],
17 | "answer": "A"}
18 | ]
19 |
20 | # Transforming the data
21 | result = {}
22 | for entry in data:
23 | video_id = entry["video_id"]
24 |
25 | # Initialize the video entry in the result dictionary if not already
26 | if video_id not in result:
27 | result[video_id] = {
28 | "video_id": video_id,
29 | "duration": entry["duration"],
30 | "domain": entry["domain"],
31 | "sub_category": entry["sub_category"],
32 | "questions": []
33 | }
34 |
35 | # Append the question to the corresponding video entry
36 | question_data = {
37 | "question_id": entry["question_id"],
38 | "task_type": entry["task_type"],
39 | "question": entry["question"],
40 | "options": entry["options"],
41 | "answer": entry["answer"]
42 | }
43 |
44 | result[video_id]["questions"].append(question_data)
45 |
46 | # Converting the result dictionary to the required list format
47 | formatted_result = list(result.values())
48 |
49 | # Output the formatted result
50 | print(json.dumps(formatted_result, indent=4))
--------------------------------------------------------------------------------
/infty-Video-LLaMA/eval_configs/infvideollama.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: infvideollama
3 | #model_type: pretrain_vicuna
4 | model_type: pretrain_vicuna
5 | freeze_vit: True
6 | freeze_qformer: True
7 | max_txt_len: 160
8 | end_sym: "###"
9 | low_resource: False
10 |
11 | frozen_llama_proj: False
12 |
13 | #llama_model: "ckpt/llama2/llama-2-7b-chat-hf"
14 | #llama_model: "media/scratch/shared/models/Llama-2-7b-hf"
15 | #llama_model: "/mnt/data-poseidon/saul/MovieChat/ckpt/Llama-2-7b-chat-hf"
16 | llama_model: "/ckpt/MovieChat-vicuna"
17 |
18 | llama_proj_model: '/ckpt/pretrained_minigpt4.pth'
19 |
20 | fusion_head_layers: 2
21 | max_frame_pos: 32
22 | fusion_header_type: "seqTransf"
23 |
24 | # ckpt: "ckpt/VL_LLaMA_2_7B_Finetuned.pth"
25 | #ckpt: "/mnt/data-poseidon/saul/MovieChat/ckpt/Video-LLaMA-2-7B-Finetuned/VL_LLaMA_2_7B_Finetuned.pth"
26 | ckpt: "/ckpt/finetune-vicuna7b-v2.pth"
27 |
28 | datasets:
29 | webvid:
30 | vis_processor:
31 | train:
32 | name: "alpro_video_eval"
33 | n_frms: 8
34 | image_size: 224
35 | text_processor:
36 | train:
37 | name: "blip_caption"
38 |
39 | run:
40 | task: video_text_pretrain
41 |
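A quick sanity-check sketch for the config above (run from the infty-Video-LLaMA root; the printed values are the ones set in this file):

from omegaconf import OmegaConf

cfg = OmegaConf.load("eval_configs/infvideollama.yaml")
print(cfg.model.arch)          # infvideollama
print(cfg.model.max_txt_len)   # 160
print(cfg.model.llama_model)   # /ckpt/MovieChat-vicuna
print(cfg.run.task)            # video_text_pretrain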
--------------------------------------------------------------------------------
/infty-Video-LLaMA/relevant_frames.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | import os
6 | import torch
7 | import shutil
8 | import numpy as np
9 | from PIL import Image
10 |
11 | with open('./alphas_uniform', 'rb') as f:
12 | density_tensor = pickle.load(f)
13 |
14 | # Convert the loaded density tensor (per-frame alphas) to a NumPy array,
15 | # moving it to the CPU first
16 | density_tensor_sticky = np.array(density_tensor.cpu())
17 |
18 | density_sticky = np.mean(density_tensor_sticky, axis=(0, 1, 2))
19 |
20 | density_sticky = density_sticky/np.sum(density_sticky)
21 | # Define chunk sizes
22 | chunk_size = 256
23 | chunks = [range(i, i + chunk_size) for i in range(0, 768, chunk_size)] # Chunk indices
24 |
25 | # Plotting: one row of three heatmaps, one per 256-frame chunk
26 | fig, axs = plt.subplots(1, 3, figsize=(12, 1.5), constrained_layout=True)
27 |
28 | # Loop over each chunk and plot its density as a heatmap
29 | for i, chunk in enumerate(chunks):
30 | # Extract the corresponding range from the density array
31 | sticky_chunk = density_sticky[chunk]
32 | # Heatmap of the per-frame density for this chunk
33 | sns.heatmap(sticky_chunk.reshape(1, -1), cmap="viridis", cbar=True,
34 | ax=axs[i], square=False, yticklabels=False, cbar_kws={'orientation': 'vertical'})
35 |
36 | # Set xticks to match the chunk range
37 | xtick_positions = np.linspace(0, 256, 6) # Adjust this depending on the number of ticks you want
38 | xtick_labels = np.round(np.linspace(chunk.start, chunk.stop, 6), 0).astype(int) # Labels for the chunk range
39 | axs[i].set_xticks(xtick_positions)
40 | axs[i].set_xticklabels(xtick_labels, fontsize=10, rotation=0)
41 | axs[i].set_xlabel("# Frames", fontsize=10)
42 |
43 |
44 | # Save and display the figure
45 | output_path = "chunks.pdf"
46 | plt.savefig(output_path, dpi=300, bbox_inches='tight')
47 | plt.show()
48 |
49 | import torch
50 | video_list = torch.load("your_video")
51 | k = 10
52 | frames_dir = "frames_uniform"
53 | for i, chunk in enumerate(chunks):
54 | # Extract the corresponding range from the density array
55 | sticky_chunk = density_sticky[chunk]
56 |
57 | # Get the top-k indices for the sticky density (descending order)
58 | top_k_sticky_indices = np.argsort(sticky_chunk)[-k:][::-1]
59 |
60 | # Print the indices for each chunk
61 | print(f"Chunk {i + 1}: {chunk.start} to {chunk.stop - 1}")
62 | print(f"Top {k} sticky density indices: {top_k_sticky_indices}")
63 | print("-" * 50)
64 |
65 | # Save the frames with the top-k density values in this chunk
66 | for idx in top_k_sticky_indices:
67 | try:
68 | # Retrieve the video tensor for the corresponding index (assuming video_list is available)
69 | video_tensor = video_list[:, idx + 256*i]  # frame tensor at the global frame index
70 |
71 | # Convert the tensor to a numpy array (HWC format)
72 | image_np = video_tensor.permute(1,2, 0).cpu().numpy() # Change shape to HWC (Height, Width, Channels)
73 |
74 | # Normalize values if necessary (assuming the tensor values are between 0 and 1)
75 | image = Image.fromarray(image_np.astype(np.uint8))
76 |
77 |
78 | # Construct the filename for saving
79 | filename = os.path.join(frames_dir, f"frame_{i + 1}_{idx + 256*i}.png")
80 |
81 | # Save the image as PNG
82 | image.save(filename)
83 | except Exception as e:
84 | print(f"Failed to load or save uniform image for index {idx}: {e}")
--------------------------------------------------------------------------------
/infty-VideoChat2/configs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "model_cls": "VideoChat2_it",
4 | "vit_blip_model_path": "your_model_path/umt_l16_qformer.pth",
5 | "llama_model_path": "your_model_path/vicuna-7b-v0",
6 | "videochat2_model_path": "your_model_path/videochat2_7b_stage2.pth",
7 | "freeze_vit": false,
8 | "freeze_qformer": false,
9 | "max_txt_len": 512,
10 | "low_resource": false,
11 | "vision_encoder": {
12 | "name": "vit_l14",
13 | "img_size": 224,
14 | "patch_size": 16,
15 | "d_model": 1024,
16 | "encoder_embed_dim": 1024,
17 | "encoder_depth": 24,
18 | "encoder_num_heads": 16,
19 | "drop_path_rate": 0.0,
20 | "num_frames": 32,
21 | "tubelet_size": 1,
22 | "use_checkpoint": false,
23 | "checkpoint_num": 0,
24 | "pretrained": "",
25 | "return_index": -2,
26 | "vit_add_ln": true,
27 | "ckpt_num_frame": 4
28 | },
29 | "num_query_token": 32,
30 | "qformer_hidden_dropout_prob": 0.1,
31 | "qformer_attention_probs_dropout_prob": 0.1,
32 | "qformer_drop_path_rate": 0.2,
33 | "extra_num_query_token": 64,
34 | "qformer_text_input": true,
35 | "system": "",
36 | "start_token": "",
38 | "img_start_token": "",
39 | "img_end_token": "",
40 | "random_shuffle": true,
41 | "use_lora": false,
42 | "lora_r": 16,
43 | "lora_alpha": 32,
44 | "lora_dropout": 0.1
45 | },
46 | "device": "cuda"
47 | }
48 |
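The config is plain JSON; a minimal loading sketch (run from the infty-VideoChat2 root):

import json

with open("configs/config.json") as f:
    model_cfg = json.load(f)["model"]

print(model_cfg["model_cls"])                      # VideoChat2_it
print(model_cfg["vision_encoder"]["num_frames"])   # 32
print(model_cfg["num_query_token"], model_cfg["extra_num_query_token"])  # 32 64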
--------------------------------------------------------------------------------
/infty-VideoChat2/configs/config_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "fusion_layer": 9,
20 | "encoder_width": 768,
21 | "cross_module": "ca"
22 | }
23 |
--------------------------------------------------------------------------------
/infty-VideoChat2/configs/config_mistral.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "model_cls": "VideoChat2_it_mistral",
4 | "vit_blip_model_path": "/video_chat2/umt_l16_qformer.pth",
5 | "mistral_model_path": "/video_chat2/Mistral-7B-Instruct-v0.2/",
6 | "videochat2_model_path": "/video_chat2/VideoChat2_stage3_Mistral_7B/videochat2_mistral_7b_stage3.pth",
7 | "freeze_vit": false,
8 | "freeze_qformer": false,
9 | "max_txt_len": 512,
10 | "low_resource": false,
11 | "vision_encoder": {
12 | "name": "vit_l14",
13 | "img_size": 224,
14 | "patch_size": 16,
15 | "d_model": 1024,
16 | "encoder_embed_dim": 1024,
17 | "encoder_depth": 24,
18 | "encoder_num_heads": 16,
19 | "drop_path_rate": 0.0,
20 | "num_frames": 8,
21 | "tubelet_size": 1,
22 | "use_checkpoint": true,
23 | "checkpoint_num": 18,
24 | "pretrained": "",
25 | "return_index": -2,
26 | "vit_add_ln": true,
27 | "ckpt_num_frame": 4
28 | },
29 | "num_query_token": 32,
30 | "qformer_hidden_dropout_prob": 0.1,
31 | "qformer_attention_probs_dropout_prob": 0.1,
32 | "qformer_drop_path_rate": 0.2,
33 | "extra_num_query_token": 64,
34 | "qformer_text_input": true,
35 | "system": "",
36 | "start_token": "",
38 | "add_second_msg": true,
39 | "img_start_token": "",
40 | "img_end_token": "",
41 | "random_shuffle": true,
42 | "return_question_instruction": false,
43 | "use_flash_attention": true,
44 | "use_lora": false,
45 | "lora_r": 16,
46 | "lora_alpha": 32,
47 | "lora_dropout": 0.1
48 | },
49 | "device": "cuda"
50 | }
51 |
--------------------------------------------------------------------------------
/infty-VideoChat2/configs/config_phi.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "model_cls": "VideoChat2_it_phi",
4 | "vit_blip_model_path": "your_model_path/videochat2/umt_l16_qformer.pth",
5 | "phi_model_path": "your_model_path/llm/Phi-3-mini-128k-instruct",
6 | "videochat2_model_path": "your_model_path/videochat2/videochat2_phi3_stage2.pth",
7 | "freeze_vit": false,
8 | "freeze_qformer": false,
9 | "max_txt_len": 512,
10 | "low_resource": false,
11 | "vision_encoder": {
12 | "name": "vit_l14",
13 | "img_size": 224,
14 | "patch_size": 16,
15 | "d_model": 1024,
16 | "encoder_embed_dim": 1024,
17 | "encoder_depth": 24,
18 | "encoder_num_heads": 16,
19 | "drop_path_rate": 0.0,
20 | "num_frames": 8,
21 | "tubelet_size": 1,
22 | "use_checkpoint": true,
23 | "checkpoint_num": 18,
24 | "pretrained": "",
25 | "return_index": -2,
26 | "vit_add_ln": true,
27 | "ckpt_num_frame": 4
28 | },
29 | "num_query_token": 32,
30 | "qformer_hidden_dropout_prob": 0.1,
31 | "qformer_attention_probs_dropout_prob": 0.1,
32 | "qformer_drop_path_rate": 0.2,
33 | "extra_num_query_token": 64,
34 | "qformer_text_input": true,
35 | "system": "",
36 | "start_token": "",
38 | "add_second_msg": true,
39 | "img_start_token": "",
40 | "img_end_token": "",
41 | "random_shuffle": true,
42 | "return_question_instruction": false,
43 | "use_flash_attention": true,
44 | "use_lora": false,
45 | "lora_r": 16,
46 | "lora_alpha": 32,
47 | "lora_dropout": 0.1
48 | },
49 | "device": "cuda"
50 | }
51 |
--------------------------------------------------------------------------------
/infty-VideoChat2/configs/data.py:
--------------------------------------------------------------------------------
1 | import os as __os  # the "__" prefix keeps the alias from being exported
2 | from copy import deepcopy as __deepcopy
3 |
4 | data_dir = 'your_annotation_path'
5 | if data_dir is None:
6 | raise ValueError("please set environment `VL_DATA_DIR` before continue")
7 |
8 | data_root = __os.path.join(data_dir, "videos_images")
9 | anno_root_pt = __os.path.join(data_dir, "anno_pretrain")
10 |
11 | # ============== pretraining datasets=================
12 | available_corpus = dict(
13 | # pretraining datasets
14 | cc3m=[
15 | f"{anno_root_pt}/cc3m_train.json",
16 | f"{data_root}/cc3m",
17 | ],
18 | cc12m=[
19 | f"{anno_root_pt}/cc12m_train.json",
20 | f"{data_root}/cc12m",
21 | ],
22 | sbu=[
23 | f"{anno_root_pt}/sbu.json",
24 | f"{data_root}/sbu",
25 | ],
26 | vg=[
27 | f"{anno_root_pt}/vg.json",
28 | f"{data_root}/vg",
29 | ],
30 | coco=[
31 | f"{anno_root_pt}/coco.json",
32 | f"{data_root}/coco",
33 | ],
34 | webvid=[
35 | f"{anno_root_pt}/webvid_train.json",
36 | f"{data_root}/webvid",
37 | "video"
38 | ],
39 | webvid_10m=[
40 | f"{anno_root_pt}/webvid_10m_train.json",
41 | f"{data_root}/webvid_10m",
42 | "video",
43 | ],
44 | internvid_10m=[
45 | f"{anno_root_pt}/internvid_10m_train.json",
46 | f"{data_root}/internvid_10m",
47 | "video"
48 | ],
49 | )
50 |
51 | # test set and composed datasets
52 | available_corpus["msrvtt_1k_test"] = [
53 | f"{anno_root_pt}/msrvtt_test1k.json",
54 | f"{data_root}/MSRVTT_Videos",
55 | "video",
56 | ]
57 |
58 | available_corpus["webvid10m_cc3m"] = [
59 | available_corpus["webvid_10m"],
60 | available_corpus["cc3m"],
61 | ]
62 |
63 | available_corpus["webvid10m_cc14m"] = [
64 | available_corpus["webvid_10m"],
65 | available_corpus["cc3m"],
66 | available_corpus["cc12m"],
67 | ]
68 | available_corpus["webvid10m_cc14m_plus"] = [
69 | available_corpus["webvid_10m"],
70 | available_corpus["cc3m"],
71 | available_corpus["coco"],
72 | available_corpus["vg"],
73 | available_corpus["sbu"],
74 | available_corpus["cc12m"],
75 | available_corpus["internvid_10m"],
76 | ]
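
Each available_corpus entry is [annotation_json, media_root] for image data or [annotation_json, media_root, "video"] for video data, and the composed corpora are lists of such entries. A small sketch of that convention, mirroring the len(ann_file) == 3 check used by the dataset classes; the helper name is an illustration, not repository code.

def describe_corpus_entry(entry):
    """Return (annotation_file, media_root, media_type) for one available_corpus entry."""
    anno_file, media_root = entry[:2]
    media_type = "video" if len(entry) == 3 and entry[2] == "video" else "image"
    return anno_file, media_root, media_type

# e.g. describe_corpus_entry(available_corpus["webvid"])
#   -> (".../anno_pretrain/webvid_train.json", ".../videos_images/webvid", "video")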
--------------------------------------------------------------------------------
/infty-VideoChat2/configs/model.py:
--------------------------------------------------------------------------------
1 | TextEncoders = dict()
2 | TextEncoders["bert"] = dict(
3 | name="bert_base",
4 | pretrained="bert-base-uncased",
5 | config="configs/config_bert.json",
6 | d_model=768,
7 | fusion_layer=9,
8 | )
--------------------------------------------------------------------------------
/infty-VideoChat2/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | from torch.utils.data import Dataset
5 |
6 | from dataset.utils import load_image_from_path
7 | from dataset.hd_utils import HD_transform_padding, HD_transform_no_padding
8 |
9 | try:
10 | from petrel_client.client import Client
11 | has_client = True
12 | except ImportError:
13 | has_client = False
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class ImageVideoBaseDataset(Dataset):
19 | """Base class that implements the image and video loading methods"""
20 |
21 | media_type = "video"
22 |
23 | def __init__(self):
24 | assert self.media_type in ["image", "video", "text"]
25 | self.data_root = None
26 | self.anno_list = (
27 | None # list(dict), each dict contains {"image": str, # image or video path}
28 | )
29 | self.transform = None
30 | self.video_reader = None
31 | self.num_tries = None
32 |
33 | self.client = None
34 | if has_client:
35 | self.client = Client('~/petreloss.conf')
36 |
37 | def __getitem__(self, index):
38 | raise NotImplementedError
39 |
40 | def __len__(self):
41 | raise NotImplementedError
42 |
43 | def get_anno(self, index):
44 | """obtain the annotation for one media (video or image)
45 |
46 | Args:
47 | index (int): The media index.
48 |
49 | Returns: dict.
50 | - "image": the filename, video also use "image".
51 | - "caption": The caption for this file.
52 |
53 | """
54 | anno = self.anno_list[index]
55 | if self.data_root is not None:
56 | anno["image"] = os.path.join(self.data_root, anno["image"])
57 | return anno
58 |
59 | def load_and_transform_media_data(self, index, data_path):
60 | if self.media_type == "image":
61 | return self.load_and_transform_media_data_image(index, data_path)
62 | else:
63 | return self.load_and_transform_media_data_video(index, data_path)
64 |
65 | def load_and_transform_media_data_image(self, index, data_path, dynamic_config=None):
66 | image = load_image_from_path(data_path, client=self.client)
67 |
68 | if dynamic_config:
69 | local_size = dynamic_config["local_size"]
70 | hd_num = dynamic_config["hd_num"]
71 | padding = dynamic_config["padding"]
72 | if padding:
73 | image = HD_transform_padding(image.float(), image_size=local_size, hd_num=hd_num)
74 | else:
75 | image = HD_transform_no_padding(image.float(), image_size=local_size, hd_num=hd_num)
76 |
77 | image = self.transform(image)
78 | return image, index
79 |
80 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None, dynamic_config=None):
81 | for _ in range(self.num_tries):
82 | try:
83 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
84 | frames, frame_indices, fps = self.video_reader(
85 | data_path, self.num_frames, self.sample_type,
86 | max_num_frames=max_num_frames, client=self.client, clip=clip
87 | )
88 | except Exception as e:
89 | logger.warning(
90 | f"Caught exception {e} when loading video {data_path}, "
91 | f"randomly sample a new video as replacement"
92 | )
93 | index = random.randint(0, len(self) - 1)
94 | ann = self.get_anno(index)
95 | data_path = ann["image"]
96 | continue
97 |
98 | if dynamic_config:
99 | local_size = dynamic_config["local_size"]
100 | hd_num = dynamic_config["hd_num"]
101 | padding = dynamic_config["padding"]
102 | if padding:
103 | frames = HD_transform_padding(frames.float(), image_size=local_size, hd_num=hd_num)
104 | else:
105 | frames = HD_transform_no_padding(frames.float(), image_size=local_size, hd_num=hd_num)
106 |
107 | # shared aug for video frames
108 | frames = self.transform(frames)
109 | if return_fps:
110 | if fps is None:
111 | sec = None
112 | else:
113 | sec = [str(round(f / fps, 1)) for f in frame_indices]
114 | return frames, index, sec
115 | else:
116 | return frames, index
117 | else:
118 | raise RuntimeError(
119 | f"Failed to fetch video after {self.num_tries} tries. "
120 | f"This might indicate that you have many corrupted videos."
121 | )
122 |
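Concrete subclasses are expected to fill in the attributes that the loaders above rely on (anno_list, transform, and, for video, video_reader, num_frames, sample_type, num_tries). A minimal image-only sketch under those assumptions; ToyImageDataset is not a dataset shipped with the repository.

from torchvision import transforms
from dataset.base_dataset import ImageVideoBaseDataset

class ToyImageDataset(ImageVideoBaseDataset):
    media_type = "image"

    def __init__(self, image_paths, image_size=224):
        super().__init__()
        # absolute paths are stored directly; data_root stays None so get_anno() does not re-join them
        self.anno_list = [{"image": p} for p in image_paths]
        self.transform = transforms.Compose([transforms.Resize((image_size, image_size))])
        self.num_tries = 1  # only consulted by the video loading path

    def __len__(self):
        return len(self.anno_list)

    def __getitem__(self, index):
        ann = self.get_anno(index)
        return self.load_and_transform_media_data(index, ann["image"])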
--------------------------------------------------------------------------------
/infty-VideoChat2/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process
4 | import random
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class MetaLoader(object):
11 | """ wraps multiple data loader """
12 | def __init__(self, name2loader):
13 | """Iterates over multiple dataloaders, it ensures all processes
14 | work on data from the same dataloader. This loader will end when
15 | the shorter dataloader raises StopIteration exception.
16 |
17 | loaders: Dict, {name: dataloader}
18 | """
19 | self.name2loader = name2loader
20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()}
21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
22 | index2name = {v: k for k, v in name2index.items()}
23 |
24 | iter_order = []
25 | for n, l in name2loader.items():
26 | iter_order.extend([name2index[n]]*len(l))
27 |
28 | random.shuffle(iter_order)
29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
30 |
31 | # sync
32 | if is_dist_avail_and_initialized():
33 | # make sure all processes have the same order so that
34 | # each step they will have data from the same loader
35 | dist.broadcast(iter_order, src=0)
36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
37 |
38 | logger.info(str(self))
39 |
40 | def __str__(self):
41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
42 | for idx, (name, loader) in enumerate(self.name2loader.items()):
43 | output.append(
44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
45 | )
46 | return "\n".join(output)
47 |
48 | def __len__(self):
49 | return len(self.iter_order)
50 |
51 | def __iter__(self):
52 | """ this iterator will run indefinitely """
53 | for name in self.iter_order:
54 | _iter = self.name2iter[name]
55 | batch = next(_iter)
56 | yield name, batch
57 |
58 |
59 | class MetaLoader_rs(object):
60 | """ wraps multiple data loader """
61 | def __init__(self, name2loader, skip_num=0):
62 | """Iterates over multiple dataloaders, it ensures all processes
63 | work on data from the same dataloader. This loader will end when
64 | the shorter dataloader raises StopIteration exception.
65 |
66 | loaders: Dict, {name: dataloader}
67 | """
68 | self.name2loader = name2loader
69 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
70 | index2name = {v: k for k, v in name2index.items()}
71 |
72 | iter_order = []
73 | for n, l in name2loader.items():
74 | iter_order.extend([name2index[n]]*len(l))
75 |
76 | random.shuffle(iter_order)
77 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
78 |
79 | # sync
80 | if is_dist_avail_and_initialized():
81 | # make sure all processes have the same order so that
82 | # each step they will have data from the same loader
83 | dist.broadcast(iter_order, src=0)
84 |
85 | if skip_num > 0:
86 | iter_order_skip = iter_order[:skip_num]
87 | for k, v in index2name.items():
88 | media_step = (iter_order_skip == k).sum().item()
89 | name2loader[v].sampler.set_start_iter(media_step)
90 | logger.info(f"{v} dataloader skip steps: {media_step}")
91 | iter_order = iter_order[skip_num:]
92 | self.name2loader = name2loader
93 | else:
94 | logger.info("Do not skip steps for any dataloader!")
95 | for k, v in index2name.items():
96 | name2loader[v].sampler.set_start_iter(0)
97 |
98 | self.name2iter = {name: iter(l) for name, l in name2loader.items()}
99 | self.iter_idx = iter_order
100 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
101 |
102 | logger.info(str(self))
103 |
104 | def __str__(self):
105 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
106 | for idx, (name, loader) in enumerate(self.name2loader.items()):
107 | length = (self.iter_idx == idx).sum()
108 | output.append(
109 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={length} "
110 | )
111 | return "\n".join(output)
112 |
113 | def __len__(self):
114 | return len(self.iter_order)
115 |
116 | def __iter__(self):
117 | """ this iterator will run indefinitely """
118 | for name in self.iter_order:
119 | _iter = self.name2iter[name]
120 | batch = next(_iter)
121 | yield name, batch
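
Both wrappers take a dict of already-built DataLoaders and yield (name, batch) pairs in a shuffled but globally consistent order; since iter_order is moved to CUDA for the broadcast, a GPU is assumed even in single-process runs. A toy single-process sketch (the datasets and names are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset
from dataset.dataloader import MetaLoader

# two toy "corpora" of different lengths
image_loader = DataLoader(TensorDataset(torch.randn(32, 3)), batch_size=8)
video_loader = DataLoader(TensorDataset(torch.randn(16, 3)), batch_size=8)

meta = MetaLoader({"image": image_loader, "video": video_loader})
for name, (batch,) in meta:  # 4 + 2 = 6 batches, interleaved in a fixed random order
    print(name, batch.shape)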
--------------------------------------------------------------------------------
/infty-VideoChat2/dataset/hd_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import numpy as np
3 | from torch.utils.data._utils.collate import default_collate
4 |
5 |
6 | def HD_transform_padding(frames, image_size=224, hd_num=6):
7 | def _padding_224(frames):
8 | _, _, H, W = frames.shape
9 | tar = int(np.ceil(H / 224) * 224)
10 | top_padding = (tar - H) // 2
11 | bottom_padding = tar - H - top_padding
12 | left_padding = 0
13 | right_padding = 0
14 |
15 | padded_frames = F.pad(
16 | frames,
17 | pad=[left_padding, right_padding, top_padding, bottom_padding],
18 | mode='constant', value=255
19 | )
20 | return padded_frames
21 |
22 | _, _, H, W = frames.shape
23 | trans = False
24 | if W < H:
25 | frames = frames.flip(-2, -1)
26 | trans = True
27 | width, height = H, W
28 | else:
29 | width, height = W, H
30 |
31 | ratio = width / height
32 | scale = 1
33 | while scale * np.ceil(scale / ratio) <= hd_num:
34 | scale += 1
35 | scale -= 1
36 | new_w = int(scale * image_size)
37 | new_h = int(new_w / ratio)
38 |
39 | resized_frames = F.interpolate(
40 | frames, size=(new_h, new_w),
41 | mode='bicubic',
42 | align_corners=False
43 | )
44 | padded_frames = _padding_224(resized_frames)
45 |
46 | if trans:
47 | padded_frames = padded_frames.flip(-2, -1)
48 |
49 | return padded_frames
50 |
51 |
52 | ###############################################
53 | # The above is used in InternLM-XComposer2-HD
54 | # The following is used in InternVL-v1.5
55 | ###############################################
56 |
57 |
58 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
59 | best_ratio_diff = float('inf')
60 | best_ratio = (1, 1)
61 | area = width * height
62 | for ratio in target_ratios:
63 | target_aspect_ratio = ratio[0] / ratio[1]
64 | ratio_diff = abs(aspect_ratio - target_aspect_ratio)
65 | if ratio_diff < best_ratio_diff:
66 | best_ratio_diff = ratio_diff
67 | best_ratio = ratio
68 | elif ratio_diff == best_ratio_diff:
69 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
70 | best_ratio = ratio
71 | return best_ratio
72 |
73 |
74 | def HD_transform_no_padding(frames, image_size=224, hd_num=6):
75 | min_num = 1
76 | max_num = hd_num
77 | _, _, orig_height, orig_width = frames.shape
78 | aspect_ratio = orig_width / orig_height
79 |
80 | # calculate the existing video aspect ratio
81 | target_ratios = set(
82 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
83 | i * j <= max_num and i * j >= min_num)
84 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
85 |
86 | # find the closest aspect ratio to the target
87 | target_aspect_ratio = find_closest_aspect_ratio(
88 | aspect_ratio, target_ratios, orig_width, orig_height, image_size)
89 |
90 | # calculate the target width and height
91 | target_width = image_size * target_aspect_ratio[0]
92 | target_height = image_size * target_aspect_ratio[1]
93 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
94 |
95 | # resize the frames
96 | resized_frame = F.interpolate(
97 | frames, size=(target_height, target_width),
98 | mode='bicubic', align_corners=False
99 | )
100 | return resized_frame
101 |
102 |
103 | def hd_collate_fn(batch):
104 | videos, conversations, instructions, indices = zip(*batch)
105 | videos = [v for v in videos]
106 | conversations = default_collate(conversations)
107 | instructions = default_collate(instructions)
108 | indices = default_collate(indices)
109 | return videos, conversations, instructions, indices
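
A quick shape check for the two transforms on a dummy (T, C, H, W) clip; the concrete sizes below only illustrate the grid selection and are not values taken from the training configs.

import torch
from dataset.hd_utils import HD_transform_padding, HD_transform_no_padding

frames = torch.randn(8, 3, 360, 640)  # (T, C, H, W), a 16:9 clip

no_pad = HD_transform_no_padding(frames, image_size=224, hd_num=6)
padded = HD_transform_padding(frames, image_size=224, hd_num=6)

# with hd_num=6 the closest allowed tile grid for 640x360 is 2x1, so the
# no-padding variant resizes to (8, 3, 224, 448); the padding variant keeps
# the aspect ratio and pads the height up to a multiple of 224
print(no_pad.shape, padded.shape)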
--------------------------------------------------------------------------------
/infty-VideoChat2/dataset/it_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import sqlite3
5 | import random
6 | from os.path import basename
7 |
8 | import numpy as np
9 | import datetime
10 |
11 | from dataset.base_dataset import ImageVideoBaseDataset
12 | from dataset.utils import load_anno
13 | from dataset.video_utils import VIDEO_READER_FUNCS
14 | from utils.distributed import is_main_process
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | class ITImgTrainDataset(ImageVideoBaseDataset):
20 | media_type = "image"
21 |
22 | def __init__(
23 | self, ann_file, transform,
24 | system="", role=("Human", "Assistant"),
25 | start_token="<Image>", end_token="</Image>",
26 | random_shuffle=True, # if True, shuffle the QA list
27 | ):
28 | super().__init__()
29 |
30 | if len(ann_file) == 3 and ann_file[2] == "video":
31 | self.media_type = "video"
32 | else:
33 | self.media_type = "image"
34 | self.label_file, self.data_root = ann_file[:2]
35 |
36 | logger.info('Load json file')
37 | with open(self.label_file, 'r') as f:
38 | self.anno = json.load(f)
39 | self.num_examples = len(self.anno)
40 | self.transform = transform
41 |
42 | # prompt parameters
43 | if system:
44 | assert system[-1] == " ", "' ' should be added at the end of system so that '###' is tokenized into one token."
45 | # adding start_token and end_token to the system prompt is currently not supported, since the msg has to be inserted in the right place
46 | self.begin_signal = "###"
47 | self.end_signal = " "
48 | self.start_token = start_token
49 | self.end_token = end_token
50 | self.system = system
51 | self.role = role
52 | self.random_shuffle = random_shuffle
53 | # instruction location and number
54 | logger.info(f"Random shuffle: {self.random_shuffle}")
55 |
56 | def get_anno(self, index):
57 | filename = self.anno[index][self.media_type]
58 | qa = self.anno[index]["QA"]
59 | if "start" in self.anno[index] and "end" in self.anno[index]:
60 | anno = {
61 | "image": os.path.join(self.data_root, filename), "qa": qa,
62 | "start": self.anno[index]["start"], "end": self.anno[index]["end"],
63 | }
64 | else:
65 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
66 | return anno
67 |
68 | def __len__(self):
69 | return self.num_examples
70 |
71 | def process_qa(self, qa, msg=""):
72 | cur_instruction = ""
73 | # randomly shuffle qa for conversation
74 | if self.random_shuffle and len(qa) > 1:
75 | random.shuffle(qa)
76 | if "i" in qa[0].keys() and qa[0]["i"] != "":
77 | cur_instruction = qa[0]["i"] + self.end_signal
78 |
79 | conversation = self.system
80 | # add instruction as system message
81 | if cur_instruction:
82 | conversation += cur_instruction
83 |
84 | # rstrip() for the extra " " in msg
85 | conversation += (
86 | self.begin_signal + self.role[0] + ": " +
87 | self.start_token + self.end_token + msg.rstrip() + self.end_signal
88 | )
89 |
90 | for sentence in qa:
91 | q = sentence["q"]
92 | a = sentence["a"]
93 | if q != "":
94 | conversation += (self.begin_signal + self.role[0] + ": " + q + self.end_signal)
95 | else:
96 | # no question, often in caption dataset
97 | pass
98 | conversation += (self.begin_signal + self.role[1] + ": " + a + self.end_signal)
99 | conversation += self.begin_signal
100 |
101 | if cur_instruction:
102 | cur_instruction += qa[0]["q"]
103 | return conversation, cur_instruction.strip()
104 |
105 | def __getitem__(self, index):
106 | try:
107 | ann = self.get_anno(index)
108 | image, index = self.load_and_transform_media_data_image(index, ann["image"])
109 | conversation, instruction = self.process_qa(ann["qa"])
110 | return image, conversation, instruction, index
111 | except Exception as e:
112 | logger.warning(f"Caught exception {e} when loading image {ann['image']}")
113 | index = np.random.randint(0, len(self))
114 | return self.__getitem__(index)
115 |
116 |
117 | class ITVidTrainDataset(ITImgTrainDataset):
118 | media_type = "video"
119 |
120 | def __init__(
121 | self, ann_file, transform,
122 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
123 | system="", role=("Human", "Assistant"),
124 | start_token="<Video>", end_token="</Video>",
125 | add_second_msg=True,
126 | random_shuffle=True,
127 | ):
128 | super().__init__(
129 | ann_file, transform,
130 | system=system, role=role,
131 | start_token=start_token, end_token=end_token,
132 | random_shuffle=random_shuffle,
133 | )
134 | self.num_frames = num_frames
135 | self.video_reader_type = video_reader_type
136 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
137 | self.sample_type = sample_type
138 | self.num_tries = num_tries
139 | self.add_second_msg = add_second_msg
140 |
141 | logger.info(f"Use {video_reader_type} for data in {ann_file}")
142 | if add_second_msg:
143 | logger.info(f"Add second message: The video contains X frames sampled at T seconds.")
144 |
145 | def __getitem__(self, index):
146 | try:
147 | ann = self.get_anno(index)
148 | msg = ""
149 | clip = None
150 | if "start" in ann and "end" in ann:
151 | clip = [ann["start"], ann["end"]]
152 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip)
153 | if self.add_second_msg:
154 | # " " should be added in the start and end
155 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
156 | conversation, instruction = self.process_qa(ann["qa"], msg)
157 | return video, conversation, instruction, index
158 | except Exception as e:
159 | logger.warning(f"Caught exception {e} when loading video {ann['image']}")
160 | index = np.random.randint(0, len(self))
161 | return self.__getitem__(index)
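
For reference, with the default ### separator, the <Video>/</Video> wrapper tokens and add_second_msg=True, process_qa assembles conversations of roughly the following shape (the question, answer and timestamps are made-up placeholders):

example_conversation = (
    "###Human: <Video></Video> The video contains 3 frames sampled at 0.0, 1.1, 2.3 seconds. "
    "###Human: What is the person doing? "
    "###Assistant: They are cooking dinner. "
    "###"
)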
--------------------------------------------------------------------------------
/infty-VideoChat2/dataset/pt_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import sqlite3
5 | import random
6 | from os.path import basename
7 |
8 | import numpy as np
9 |
10 | from dataset.base_dataset import ImageVideoBaseDataset
11 | from dataset.utils import load_anno, pre_text
12 | from dataset.video_utils import VIDEO_READER_FUNCS
13 | from utils.distributed import is_main_process
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def get_anno_by_id(cur: sqlite3.Cursor, id: int):
19 | """TODO: Docstring for get_anno_by_id.
20 |
21 | Args:
22 | cur (sqlite3.Cursor): The dataset cursor.
23 | id (int): The annotation id.
24 |
25 | Returns:
26 |
27 | """
28 | pass
29 |
30 |
31 | class PTImgTrainDataset(ImageVideoBaseDataset):
32 | media_type = "image"
33 |
34 | def __init__(self, ann_file, transform, pre_text=True):
35 | super().__init__()
36 |
37 | if len(ann_file) == 3 and ann_file[2] == "video":
38 | self.media_type = "video"
39 | else:
40 | self.media_type = "image"
41 | self.label_file, self.data_root = ann_file[:2]
42 |
43 | logger.info('Load json file')
44 | with open(self.label_file, 'r') as f:
45 | self.anno = json.load(f)
46 | self.num_examples = len(self.anno)
47 |
48 | self.transform = transform
49 | self.pre_text = pre_text
50 | logger.info(f"Pre-process text: {pre_text}")
51 |
52 | def get_anno(self, index):
53 | filename = self.anno[index][self.media_type]
54 | caption = self.anno[index]["caption"]
55 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption}
56 | return anno
57 |
58 | def __len__(self):
59 | return self.num_examples
60 |
61 | def __getitem__(self, index):
62 | try:
63 | ann = self.get_anno(index)
64 | image, index = self.load_and_transform_media_data(index, ann["image"])
65 | caption = pre_text(ann["caption"], pre_text=self.pre_text)
66 | return image, caption, index
67 | except Exception as e:
68 | logger.warning(f"Caught exception {e} when loading image {ann['image']}")
69 | index = np.random.randint(0, len(self))
70 | return self.__getitem__(index)
71 |
72 |
73 | class PTVidTrainDataset(PTImgTrainDataset):
74 | media_type = "video"
75 |
76 | def __init__(
77 | self,
78 | ann_file,
79 | transform,
80 | num_frames=4,
81 | video_reader_type="decord",
82 | sample_type="rand",
83 | num_tries=3,
84 | pre_text=True
85 | ):
86 | super().__init__(ann_file, transform, pre_text=pre_text)
87 | self.num_frames = num_frames
88 | self.video_reader_type = video_reader_type
89 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
90 | self.sample_type = sample_type
91 | self.num_tries = num_tries
92 |
93 |
94 | class PTImgEvalDataset(ImageVideoBaseDataset):
95 | media_type = "image"
96 |
97 | def __init__(self, ann_file, transform, has_multi_vision_gt=False):
98 | super(PTImgEvalDataset, self).__init__()
99 | self.raw_anno_list = load_anno(ann_file)
100 | self.transform = transform
101 | self.has_multi_vision_gt = has_multi_vision_gt # each caption has multiple images as ground truth
102 |
103 | self.text = None
104 | self.image = None
105 | self.txt2img = None
106 | self.img2txt = None
107 | self.build_data()
108 |
109 | def build_data(self):
110 | self.text = []
111 | self.image = []
112 | self.txt2img = {}
113 | self.img2txt = {}
114 | if self.has_multi_vision_gt:
115 | self.build_data_multi_img_gt()
116 | else:
117 | self.build_data_multi_txt_gt()
118 | self.anno_list = [dict(image=e) for e in self.image]
119 |
120 | def build_data_multi_img_gt(self):
121 | """each text may have multiple ground_truth image, e.g., ssv2"""
122 | img_id = 0
123 | for txt_id, ann in enumerate(self.raw_anno_list):
124 | self.text.append(pre_text(ann["caption"]))
125 | self.txt2img[txt_id] = []
126 | _images = ann["image"] \
127 | if isinstance(ann["image"], list) else [ann["image"], ]
128 | for i, image in enumerate(_images):
129 | self.image.append(image)
130 | self.txt2img[txt_id].append(img_id)
131 | self.img2txt[img_id] = txt_id
132 | img_id += 1
133 |
134 | def build_data_multi_txt_gt(self):
135 | """each image may have multiple ground_truth text, e.g., COCO and Flickr30K"""
136 | txt_id = 0
137 | for img_id, ann in enumerate(self.raw_anno_list):
138 | self.image.append(ann["image"])
139 | self.img2txt[img_id] = []
140 | _captions = ann["caption"] \
141 | if isinstance(ann["caption"], list) else [ann["caption"], ]
142 | for i, caption in enumerate(_captions):
143 | self.text.append(pre_text(caption))
144 | self.img2txt[img_id].append(txt_id)
145 | self.txt2img[txt_id] = img_id
146 | txt_id += 1
147 |
148 | def __len__(self):
149 | return len(self.anno_list)
150 |
151 | def __getitem__(self, index):
152 | ann = self.anno_list[index]
153 | image, index = self.load_and_transform_media_data(index, ann["image"])
154 | return image, index
155 |
156 |
157 | def preprocess_para_retrieval_data(anno_list):
158 | processed_anno_list = []
159 | for d in anno_list:
160 | d["caption"] = " ".join(d.pop("caption"))
161 | processed_anno_list.append(d)
162 | return processed_anno_list
163 |
164 |
165 | class PTVidEvalDataset(PTImgEvalDataset):
166 | media_type = "video"
167 |
168 | def __init__(
169 | self, ann_file, transform, num_frames=4,
170 | video_reader_type="decord", sample_type="rand", num_tries=1,
171 | is_paragraph_retrieval=False, has_multi_vision_gt=False,
172 | ):
173 | super(PTVidEvalDataset, self).__init__(ann_file, transform, has_multi_vision_gt)
174 | self.num_frames = num_frames
175 | self.video_reader_type = video_reader_type
176 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
177 | self.sample_type = sample_type
178 | self.num_tries = num_tries
179 | self.is_paragraph_retrieval = is_paragraph_retrieval
180 |
181 | if is_paragraph_retrieval:
182 | self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list)
183 | self.build_data()
184 |
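Tying this back to configs/data.py: a pretraining video dataset is built straight from one available_corpus entry. A minimal sketch; the transform is a stand-in and the annotation JSON / video folders must exist at the configured paths.

from torchvision import transforms
from configs.data import available_corpus
from dataset.pt_dataset import PTVidTrainDataset

transform = transforms.Compose([transforms.Resize((224, 224))])

# entry layout: [annotation_json, video_root, "video"]
webvid = PTVidTrainDataset(
    ann_file=available_corpus["webvid"],
    transform=transform,
    num_frames=4,
    video_reader_type="decord",
    sample_type="rand",
)
print(len(webvid))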
--------------------------------------------------------------------------------
/infty-VideoChat2/dataset/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import logging
4 | from torch.utils.data.distributed import DistributedSampler
5 |
6 |
7 | # stolen from https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93
8 | class StatefulDistributedSampler(DistributedSampler):
9 | """
10 | More fine-grained state DataSampler that uses training iteration and epoch
11 | both for shuffling data. PyTorch DistributedSampler only uses epoch
12 | for the shuffling and starts sampling data from the start. In case of training
13 | on very large data, we train for one epoch only and when we resume training,
14 | we want to resume the data sampler from the training iteration.
15 | """
16 |
17 | def __init__(self, dataset, batch_size=None, seed: int = 0):
18 | """
19 | Initializes the StatefulDistributedSampler. A random seed derived from
20 | the current epoch is used to shuffle the data. To resume sampling, set
21 | start_iter (0 by default, or restored from a checkpoint) so that only
22 | the remaining samples of the epoch are drawn.
23 |
24 | Args:
25 | dataset (Dataset): Pytorch dataset that sampler will shuffle
26 | batch_size (int): batch size we want the sampler to sample
27 | seed (int): Seed for the torch generator.
28 | """
29 | super().__init__(dataset, shuffle=False, seed=seed)
30 |
31 | self.start_iter = 0
32 | self.batch_size = batch_size
33 | self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
34 | self.num_samples = self.total_size // self.num_replicas
35 | print(f"rank: {self.rank}: Sampler created...")
36 |
37 | def __iter__(self):
38 | # partition data into num_replicas and optionally shuffle within a rank
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch + self.seed)
41 | shuffling = torch.randperm(self.num_samples, generator=g).tolist()
42 | indices = np.array(
43 | list(
44 | range(
45 | (self.rank * self.num_samples), (self.rank + 1) * self.num_samples
46 | )
47 | )
48 | )[shuffling].tolist()
49 |
50 | # make sure we have correct number of samples per replica
51 | assert len(indices) == self.num_samples
52 | assert self.batch_size > 0, "batch_size not set for the sampler"
53 |
54 | # resume the sampler
55 | start_index = self.start_iter * self.batch_size
56 | indices = indices[start_index:]
57 | return iter(indices)
58 |
59 | def set_start_iter(self, start_iter):
60 | """
61 | Set the iteration number from which the sampling should start. This is
62 | used to find the marker in the data permutation order from where the
63 | sampler should start sampling.
64 | """
65 | self.start_iter = start_iter
66 |
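Resuming mid-epoch amounts to telling the sampler how many iterations were already consumed. A minimal sketch, assuming torch.distributed has already been initialized (the sampler inherits rank/world-size handling from DistributedSampler); the 25-step checkpoint value is made up.

import torch
from torch.utils.data import DataLoader, TensorDataset
from dataset.sampler import StatefulDistributedSampler

dataset = TensorDataset(torch.arange(1000).float())
sampler = StatefulDistributedSampler(dataset, batch_size=8, seed=0)

# pretend a checkpoint says 25 steps of this epoch were already trained
sampler.set_start_iter(25)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
# the first batch drawn from `loader` now corresponds to iteration 25 of the epoch's permutation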
--------------------------------------------------------------------------------
/infty-VideoChat2/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .videochat2_qformer import VideoChat2_qformer
2 |
3 | from .videochat_vicuna.videochat2_pt_vicuna import VideoChat2_pt_vicuna
4 | from .videochat_vicuna.videochat2_it_vicuna import VideoChat2_it_vicuna
5 |
6 | from .videochat_mistra.videochat2_pt_mistral import VideoChat2_pt_mistral
7 | from .videochat_mistra.videochat2_it_mistral import VideoChat2_it_mistral
8 | from .videochat_mistra.videochat2_it_hd_mistral import VideoChat2_it_hd_mistral
9 |
10 | from .videochat_phi.videochat2_pt_phi import VideoChat2_pt_phi
11 | from .videochat_phi.videochat2_it_phi import VideoChat2_it_phi
--------------------------------------------------------------------------------
/infty-VideoChat2/models/bert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-VideoChat2/models/bert/__init__.py
--------------------------------------------------------------------------------
/infty-VideoChat2/models/bert/builder.py:
--------------------------------------------------------------------------------
1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
2 |
3 | import logging
4 | logger = logging.getLogger(__name__)
5 |
6 | def build_bert(model_config, pretrain, checkpoint):
7 | """build text encoder.
8 |
9 | Args:
10 | model_config (dict): model config.
11 | pretrain (bool): Whether to do pretrain or finetuning.
12 | checkpoint (bool): whether to do gradient_checkpointing.
13 |
14 | Returns: TODO
15 |
16 | """
17 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
18 | bert_config.encoder_width = model_config.vision_encoder.d_model
19 | bert_config.gradient_checkpointing = checkpoint
20 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer
21 |
22 | if not model_config.multimodal.enable:
23 | bert_config.fusion_layer = bert_config.num_hidden_layers
24 |
25 | if pretrain:
26 | text_encoder, loading_info = BertForMaskedLM.from_pretrained(
27 | model_config.text_encoder.pretrained,
28 | config=bert_config,
29 | output_loading_info=True,
30 | )
31 | else:
32 | text_encoder, loading_info = BertModel.from_pretrained(
33 | model_config.text_encoder.pretrained,
34 | config=bert_config,
35 | add_pooling_layer=False,
36 | output_loading_info=True,
37 | )
38 |
39 | return text_encoder
40 |
41 |
42 | def build_bert_decoder(model_config, checkpoint):
43 | """build text decoder the same as the multimodal encoder.
44 |
45 | Args:
46 | model_config (dict): model config.
47 | pretrain (bool): Whether to do pretrain or finetuning.
48 | checkpoint (bool): whether to do gradient_checkpointing.
49 |
50 | Returns: TODO
51 |
52 | """
53 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
54 | bert_config.encoder_width = model_config.vision_encoder.d_model
55 | bert_config.gradient_checkpointing = checkpoint
56 |
57 | bert_config.fusion_layer = 0
58 | bert_config.num_hidden_layers = (
59 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
60 | )
61 |
62 | text_decoder, loading_info = BertLMHeadModel.from_pretrained(
63 | model_config.text_encoder.pretrained,
64 | config=bert_config,
65 | output_loading_info=True,
66 | )
67 |
68 | return text_decoder
69 |
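build_bert only touches a few nested fields of the model config. A sketch of the minimal structure it expects, using easydict.EasyDict for attribute access (whether the repository itself uses EasyDict here is an assumption; the field values mirror configs/model.py and the 1024-dim ViT used elsewhere in the repo):

from easydict import EasyDict
from models.bert.builder import build_bert

model_config = EasyDict(
    text_encoder=dict(
        config="configs/config_bert.json",
        pretrained="bert-base-uncased",
        fusion_layer=9,
    ),
    vision_encoder=dict(d_model=1024),
    multimodal=dict(enable=True),
)

# pretrain=True returns a BertForMaskedLM; pretrain=False returns a plain BertModel
text_encoder = build_bert(model_config, pretrain=True, checkpoint=False)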
--------------------------------------------------------------------------------
/infty-VideoChat2/models/blip2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-VideoChat2/models/blip2/__init__.py
--------------------------------------------------------------------------------
/infty-VideoChat2/models/blip2/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import contextlib
8 | import os
9 | import logging
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | from .Qformer_baseline import BertLMHeadModel_baseline
15 | from .Qformer import BertConfig, BertLMHeadModel
16 | from .vit import build_vit
17 | from transformers import BertTokenizer
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class Blip2Base(nn.Module):
23 | def __init__(self):
24 | super().__init__()
25 |
26 | @classmethod
27 | def init_tokenizer(cls, truncation_side="right"):
28 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side, local_files_only=True)
29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
30 | return tokenizer
31 |
32 | @property
33 | def device(self):
34 | return list(self.parameters())[0].device
35 |
36 | def maybe_autocast(self, dtype=torch.float16):
37 | # if on cpu, don't use autocast
38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
39 | enable_autocast = self.device != torch.device("cpu")
40 |
41 | if enable_autocast:
42 | return torch.cuda.amp.autocast(dtype=dtype)
43 | else:
44 | return contextlib.nullcontext()
45 |
46 | @classmethod
47 | def init_Qformer(
48 | cls,
49 | num_query_token, vision_width, tau, alpha, sticky, num_basis, baseline,
50 | qformer_hidden_dropout_prob=0.1,
51 | qformer_attention_probs_dropout_prob=0.1,
52 | qformer_drop_path_rate=0.,
53 | ):
54 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
55 | encoder_config.encoder_width = vision_width
56 | encoder_config.sticky = sticky
57 | encoder_config.num_basis = num_basis
58 | encoder_config.tau = tau
59 | encoder_config.alpha = alpha
60 | # insert cross-attention layer every other block
61 | encoder_config.add_cross_attention = True
62 | encoder_config.cross_attention_freq = 2
63 | encoder_config.query_length = num_query_token
64 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
65 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
66 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
67 | logger.info(f"Drop_path:{encoder_config.drop_path_list}")
68 | logger.info(encoder_config)
69 | if baseline:
70 | Qformer = BertLMHeadModel_baseline(config=encoder_config)
71 | else:
72 | Qformer = BertLMHeadModel(config=encoder_config)
73 | query_tokens = nn.Parameter(
74 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
75 | )
76 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
77 | return Qformer, query_tokens
78 |
79 | @classmethod
80 | def init_vision_encoder_umt(self, config):
81 | """build vision encoder
82 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
83 |
84 | """
85 | vision_encoder = build_vit(config)
86 |
87 | if config.vision_encoder.vit_add_ln:
88 | vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12)
89 | else:
90 | vision_layernorm = nn.Identity()
91 |
92 | return vision_encoder, vision_layernorm
93 |
94 |
95 | def disabled_train(self, mode=True):
96 | """Overwrite model.train with this function to make sure train/eval mode
97 | does not change anymore."""
98 | return self
99 |
100 |
101 | class LayerNorm(nn.LayerNorm):
102 | """Subclass torch's LayerNorm to handle fp16."""
103 |
104 | def forward(self, x: torch.Tensor):
105 | orig_type = x.dtype
106 | ret = super().forward(x.type(torch.float32))
107 | return ret.type(orig_type)
108 |
--------------------------------------------------------------------------------
/infty-VideoChat2/models/blip2/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 |
5 |
6 | from .Qformer import BertConfig, BertLMHeadModel
7 | from models.utils import load_temp_embed_with_mismatch
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def build_qformer(num_query_token, vision_width,
13 | qformer_hidden_dropout_prob=0.1,
14 | qformer_attention_probs_dropout_prob=0.1,
15 | drop_path_rate=0.,
16 | ):
17 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
18 | encoder_config.encoder_width = vision_width
19 | # insert cross-attention layer every other block
20 | encoder_config.add_cross_attention = True
21 | encoder_config.cross_attention_freq = 2
22 | encoder_config.query_length = num_query_token
23 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
24 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
25 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)]
26 | logger.info(f"Drop_path:{encoder_config.drop_path_list}")
27 | logger.info(encoder_config)
28 | Qformer = BertLMHeadModel.from_pretrained(
29 | "bert-base-uncased", config=encoder_config, local_files_only=True
30 | )
31 | query_tokens = nn.Parameter(
32 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
33 | )
34 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
35 | return Qformer, query_tokens
36 |
37 | def interpolate_pos_embed_blip(state_dict, new_model):
38 | if "vision_temp_embed" in state_dict:
39 | vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"]
40 | state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch(
41 | state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False
42 | )
43 | return state_dict
44 |
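build_qformer needs only the query count and the width of the vision features it will cross-attend to; everything else has defaults. A minimal call sketch (it loads bert-base-uncased with local_files_only=True, so the weights must already be in the local HF cache):

from models.blip2.builder import build_qformer

# 32 learnable queries over 1024-dim vision features (the width of the UMT-L/16 encoder)
qformer, query_tokens = build_qformer(
    num_query_token=32,
    vision_width=1024,
    qformer_hidden_dropout_prob=0.1,
    qformer_attention_probs_dropout_prob=0.1,
    drop_path_rate=0.2,
)
print(query_tokens.shape)  # (1, 32, hidden_size)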
--------------------------------------------------------------------------------
/infty-VideoChat2/models/videochat_mistra/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-VideoChat2/models/videochat_mistra/__init__.py
--------------------------------------------------------------------------------
/infty-VideoChat2/models/videochat_vicuna/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-spin/Infinite-Video/908be519dc63c1b7961795bd46264e71d1736331/infty-VideoChat2/models/videochat_vicuna/__init__.py
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/config_7b_hd_stage4.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "videochat2_instruction_hd"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 8
14 | num_frames_test = 8
15 | batch_size = 3
16 | max_txt_l = 512
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}", text="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}", text="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}", text="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_it_hd_mistral",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | mistral_model_path="your_model_path/llm/Mistral-7B-Instruct-v0.2",
39 | videochat2_model_path="your_model_path/videochat2/videochat2_mistral_7b_stage3.pth",
40 | freeze_vit=False,
41 | freeze_qformer=False,
42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
43 | # vit
44 | low_resource=False,
45 | vision_encoder=dict(
46 | name="vit_l14",
47 | img_size=224,
48 | patch_size=16,
49 | d_model=1024,
50 | encoder_embed_dim=1024,
51 | encoder_depth=24,
52 | encoder_num_heads=16,
53 | drop_path_rate=0.,
54 | num_frames="${num_frames}",
55 | tubelet_size=1,
56 | use_checkpoint=True,
57 | checkpoint_num=24,
58 | pretrained="",
59 | return_index=-2,
60 | vit_add_ln=True,
61 | ckpt_num_frame=4,
62 | ),
63 | # qformer
64 | num_query_token=32,
65 | qformer_hidden_dropout_prob=0.1,
66 | qformer_attention_probs_dropout_prob=0.1,
67 | qformer_drop_path_rate=0.2,
68 | extra_num_query_token=64,
69 | qformer_text_input=True,
70 | # prompt
71 | system="",
72 | start_token="<Video>",
73 | end_token="</Video>",
74 | add_second_msg=True,
75 | img_start_token="<Image>",
76 | img_end_token="</Image>",
77 | random_shuffle=True,
78 | return_question_instruction=False,
79 | use_flash_attention=True,
80 | use_lora=True,
81 | lora_r=16,
82 | lora_alpha=32,
83 | lora_dropout=0.1,
84 | # dynamic resolution
85 | dynamic_config=dict(
86 | local_size=224,
87 | hd_num=6,
88 | padding=False,
89 | add_global=True,
90 | ),
91 | # debug=True,
92 | )
93 |
94 | optimizer = dict(
95 | opt="adamW",
96 | lr=1e-5,
97 | opt_betas=[0.9, 0.999], # default
98 | weight_decay=0.02,
99 | max_grad_norm=-1, # requires a positive float, use -1 to disable
100 | # use a different lr for some modules, e.g., larger lr for new modules
101 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
102 | )
103 |
104 | scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.)
105 |
106 | evaluate = False
107 | deep_fusion = False
108 | evaluation = dict(
109 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
110 | eval_x_only=False,
111 | k_test=128,
112 | eval_offload=True, # offload gpu tensors to cpu to save memory.
113 | )
114 |
115 | fp16 = True
116 | gradient_checkpointing = True
117 |
118 | # ========================= wandb ==========================
119 | wandb = dict(
120 | enable=False,
121 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
122 | project="videochat2", # setup in your command line
123 | )
124 | dist_url = "env://"
125 | device = "cuda"
126 | mode = "it_mistral"
127 |
128 | # ========================= others ==========================
129 | output_dir = None # output dir
130 | resume = False # if True, load optimizer and scheduler states as well
131 | debug = False
132 | log_freq = 10
133 | seed = 42
134 |
135 | save_iter = 0
136 | save_latest = True
137 | auto_resume = True
138 | pretrained_path = "" # path to pretrained model weights, for resume only?
139 |
140 | deepspeed = dict(
141 | enable=True,
142 | stage=1,
143 | )
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/config_7b_stage2.py:
--------------------------------------------------------------------------------
1 | from configs.data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "webvid10m_cc3m"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 4
14 | num_frames_test = 4
15 | batch_size = 16
16 | max_txt_l = 32
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_pt_mistral",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | mistral_model_path="your_model_path/llm/Mistral-7B-Instruct-v0.2",
39 | gpt_model_path="",
40 | freeze_vit=False,
41 | freeze_qformer=False,
42 | # vit
43 | low_resource=False,
44 | vision_encoder=dict(
45 | name="vit_l14",
46 | img_size=224,
47 | patch_size=16,
48 | d_model=1024,
49 | encoder_embed_dim=1024,
50 | encoder_depth=24,
51 | encoder_num_heads=16,
52 | drop_path_rate=0.,
53 | num_frames="${num_frames}",
54 | tubelet_size=1,
55 | use_checkpoint=False,
56 | checkpoint_num=0,
57 | pretrained="",
58 | return_index=-2,
59 | vit_add_ln=True,
60 | ),
61 | # prompt
62 | prompt_path="prompts/concise_description.txt",
63 | img_prompt_path="prompts/concise_image_description.txt",
64 | prompt_template="[INST] {} [/INST]",
65 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage2
66 | end_sym="</s>",
67 | # qformers
68 | num_query_token=32,
69 | qformer_hidden_dropout_prob=0.1,
70 | qformer_attention_probs_dropout_prob=0.1,
71 | qformer_drop_path_rate=0.2,
72 | extra_num_query_token=64,
73 | # debug=True,
74 | )
75 |
76 | optimizer = dict(
77 | opt="adamW",
78 | lr=1e-4,
79 | opt_betas=[0.9, 0.999], # default
80 | weight_decay=0.02,
81 | max_grad_norm=-1, # requires a positive float, use -1 to disable
82 | # use a different lr for some modules, e.g., larger lr for new modules
83 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
84 | )
85 |
86 | scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2)
87 |
88 | evaluate = False
89 | deep_fusion = False
90 | evaluation = dict(
91 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
92 | eval_x_only=False,
93 | k_test=128,
94 | eval_offload=True, # offload gpu tensors to cpu to save memory.
95 | )
96 |
97 | fp16 = True
98 | gradient_checkpointing = True
99 |
100 | # ========================= wandb ==========================
101 | wandb = dict(
102 | enable=False,
103 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
104 | project="videochat2", # setup in your command line
105 | )
106 | dist_url = "env://"
107 | device = "cuda"
108 | mode = "pt"
109 |
110 | # ========================= others ==========================
111 | output_dir = None # output dir
112 | resume = False # if True, load optimizer and scheduler states as well
113 | debug = False
114 | log_freq = 100
115 | seed = 42
116 |
117 | save_latest = True
118 | auto_resume = True
119 | pretrained_path = "" # path to pretrained model weights, for resume only?
120 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/config_7b_stage3.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "videochat2_instruction_new"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 8
14 | num_frames_test = 8
15 | batch_size = 4
16 | max_txt_l = 512
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_it_mistral",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | mistral_model_path="your_model_path/llm/Mistral-7B-Instruct-v0.2",
39 | videochat2_model_path="your_model_path/videochat2/videochat2_mistral_7b_stage2.pth",
40 | freeze_vit=False,
41 | freeze_qformer=False,
42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
43 | # vit
44 | low_resource=False,
45 | vision_encoder=dict(
46 | name="vit_l14",
47 | img_size=224,
48 | patch_size=16,
49 | d_model=1024,
50 | encoder_embed_dim=1024,
51 | encoder_depth=24,
52 | encoder_num_heads=16,
53 | drop_path_rate=0.,
54 | num_frames="${num_frames}",
55 | tubelet_size=1,
56 | use_checkpoint=True,
57 | checkpoint_num=18,
58 | pretrained="",
59 | return_index=-2,
60 | vit_add_ln=True,
61 | ckpt_num_frame=4,
62 | ),
63 | # qformer
64 | num_query_token=32,
65 | qformer_hidden_dropout_prob=0.1,
66 | qformer_attention_probs_dropout_prob=0.1,
67 | qformer_drop_path_rate=0.2,
68 | extra_num_query_token=64,
69 | qformer_text_input=True,
70 | # prompt
71 | system="",
72 | start_token="<Video>",
73 | end_token="</Video>",
74 | add_second_msg=True,
75 | img_start_token="<Image>",
76 | img_end_token="</Image>",
77 | random_shuffle=True,
78 | use_flash_attention=True,
79 | use_lora=True,
80 | lora_r=16,
81 | lora_alpha=32,
82 | lora_dropout=0.1,
83 | # debug=True,
84 | )
85 |
86 | optimizer = dict(
87 | opt="adamW",
88 | lr=2e-5,
89 | opt_betas=[0.9, 0.999], # default
90 | weight_decay=0.02,
91 | max_grad_norm=-1, # requires a positive float, use -1 to disable
92 | # use a different lr for some modules, e.g., larger lr for new modules
93 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
94 | )
95 |
96 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
97 |
98 | evaluate = False
99 | deep_fusion = False
100 | evaluation = dict(
101 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
102 | eval_x_only=False,
103 | k_test=128,
104 | eval_offload=True, # offload gpu tensors to cpu to save memory.
105 | )
106 |
107 | fp16 = True
108 | gradient_checkpointing = True
109 |
110 | # ========================= wandb ==========================
111 | wandb = dict(
112 | enable=False,
113 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
114 | project="videochat2", # setup in your command line
115 | )
116 | dist_url = "env://"
117 | device = "cuda"
118 | mode = "it_mistral"
119 |
120 | # ========================= others ==========================
121 | output_dir = None # output dir
122 | resume = False # if True, load optimizer and scheduler states as well
123 | debug = False
124 | log_freq = 10
125 | seed = 42
126 |
127 | save_latest = True
128 | auto_resume = True
129 | pretrained_path = "" # path to pretrained model weights, for resume only?
130 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/run_7b_hd_stage4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Please modify the ${MASTER_NODE}:${MASTER_PORT}
3 | MASTER_NODE=127.0.0.1
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 | NNODE=1
6 | NUM_GPUS=8
7 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
8 |
9 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} \
10 | --nnodes=${NNODE} \
11 | --nproc_per_node=${NUM_GPUS} \
12 | --rdzv_backend=c10d \
13 | tasks/train_it_ds.py \
14 | $(dirname $0)/config_7b_hd_stage4.py \
15 | output_dir ${OUTPUT_DIR}
16 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/run_7b_stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Please modify the ${MASTER_NODE}:${MASTER_PORT}
3 | MASTER_NODE=127.0.0.1
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 | NNODE=1
6 | NUM_GPUS=8
7 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
8 |
9 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} \
10 | --nnodes=${NNODE} \
11 | --nproc_per_node=${NUM_GPUS} \
12 | --rdzv_backend=c10d \
13 | tasks/train_pt.py \
14 | $(dirname $0)/config_7b_stage2.py \
15 | output_dir ${OUTPUT_DIR}
16 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/run_7b_stage3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Please modify the ${MASTER_NODE}:${MASTER_PORT}
3 | MASTER_NODE=127.0.0.1
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 | NNODE=1
6 | NUM_GPUS=8
7 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
8 |
9 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} \
10 | --nnodes=${NNODE} \
11 | --nproc_per_node=${NUM_GPUS} \
12 | --rdzv_backend=c10d \
13 | tasks/train_it.py \
14 | $(dirname $0)/config_7b_stage3.py \
15 | output_dir ${OUTPUT_DIR}
16 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/slurm_run_7b_stage2.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage2'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_pt.py \
27 | $(dirname $0)/config_7b_stage2.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/slurm_run_7b_stage3.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage3'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_it.py \
27 | $(dirname $0)/config_7b_stage3.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_mistral/slurm_run_7b_stage4_hd.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='hd_stage4'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_it_ds.py \
27 | $(dirname $0)/config_7b_hd_stage4.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_phi/config_7b_stage2.py:
--------------------------------------------------------------------------------
1 | from configs.data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "webvid10m_cc3m"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 4
14 | num_frames_test = 4
15 | batch_size = 20
16 | max_txt_l = 32
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_it_phi",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | mistral_model_path="your_model_path/llm/Phi-3-mini-128k-instruct",
39 | gpt_model_path="",
40 | freeze_vit=False,
41 | freeze_qformer=False,
42 | # vit
43 | low_resource=False,
44 | vision_encoder=dict(
45 | name="vit_l14",
46 | img_size=224,
47 | patch_size=16,
48 | d_model=1024,
49 | encoder_embed_dim=1024,
50 | encoder_depth=24,
51 | encoder_num_heads=16,
52 | drop_path_rate=0.,
53 | num_frames="${num_frames}",
54 | tubelet_size=1,
55 | use_checkpoint=False,
56 | checkpoint_num=0,
57 | pretrained="",
58 | return_index=-2,
59 | vit_add_ln=True,
60 | ),
61 | # prompt
62 | prompt_path="prompts/concise_description.txt",
63 | img_prompt_path="prompts/concise_image_description.txt",
64 | prompt_template="<|user|>\n{}<|end|>\n<|assistant|>\n",
65 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage2
66 | end_sym="<|end|>",
67 | # qformers
68 | num_query_token=32,
69 | qformer_hidden_dropout_prob=0.1,
70 | qformer_attention_probs_dropout_prob=0.1,
71 | qformer_drop_path_rate=0.2,
72 | extra_num_query_token=64,
73 | # debug=True,
74 | )
75 |
76 | optimizer = dict(
77 | opt="adamW",
78 | lr=1e-4,
79 | opt_betas=[0.9, 0.999], # default
80 | weight_decay=0.02,
81 | max_grad_norm=-1, # requires a positive float, use -1 to disable
82 | # use a different lr for some modules, e.g., larger lr for new modules
83 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
84 | )
85 |
86 | scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2)
87 |
88 | evaluate = False
89 | deep_fusion = False
90 | evaluation = dict(
91 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
92 | eval_x_only=False,
93 | k_test=128,
94 | eval_offload=True, # offload gpu tensors to cpu to save memory.
95 | )
96 |
97 | fp16 = True
98 | gradient_checkpointing = True
99 |
100 | # ========================= wandb ==========================
101 | wandb = dict(
102 | enable=False,
103 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
104 | project="videochat2", # setup in your command line
105 | )
106 | dist_url = "env://"
107 | device = "cuda"
108 | mode = "pt"
109 |
110 | # ========================= others ==========================
111 | output_dir = None # output dir
112 | resume = False # if True, load optimizer and scheduler states as well
113 | debug = False
114 | log_freq = 100
115 | seed = 42
116 |
117 | save_latest = True
118 | auto_resume = True
119 | pretrained_path = "" # path to pretrained model weights, for resume only?
120 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_phi/config_7b_stage3.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "videochat2_instruction_new"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 8
14 | num_frames_test = 8
15 | batch_size = 8
16 | max_txt_l = 512
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_it_phi",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | mistral_model_path="your_model_path/llm/Phi-3-mini-128k-instruct",
39 | videochat2_model_path="your_model_path/videochat2/videochat2_phi3_stage2.pth",
40 | freeze_vit=False,
41 | freeze_qformer=False,
42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
43 | # vit
44 | low_resource=False,
45 | vision_encoder=dict(
46 | name="vit_l14",
47 | img_size=224,
48 | patch_size=16,
49 | d_model=1024,
50 | encoder_embed_dim=1024,
51 | encoder_depth=24,
52 | encoder_num_heads=16,
53 | drop_path_rate=0.,
54 | num_frames="${num_frames}",
55 | tubelet_size=1,
56 | use_checkpoint=True,
57 | checkpoint_num=18,
58 | pretrained="",
59 | return_index=-2,
60 | vit_add_ln=True,
61 | ckpt_num_frame=4,
62 | ),
63 | # qformer
64 | num_query_token=32,
65 | qformer_hidden_dropout_prob=0.1,
66 | qformer_attention_probs_dropout_prob=0.1,
67 | qformer_drop_path_rate=0.2,
68 | extra_num_query_token=64,
69 | qformer_text_input=True,
70 | # prompt
71 | system="",
72 | start_token="<Video>",
73 | end_token="</Video>",
74 | add_second_msg=True,
75 | img_start_token="<Image>",
76 | img_end_token="</Image>",
77 | random_shuffle=True,
78 | use_flash_attention=True,
79 | use_lora=True,
80 | lora_r=16,
81 | lora_alpha=32,
82 | lora_dropout=0.1,
83 | # debug=True,
84 | )
85 |
86 | optimizer = dict(
87 | opt="adamW",
88 | lr=2e-5,
89 | opt_betas=[0.9, 0.999], # default
90 | weight_decay=0.02,
91 | max_grad_norm=-1, # requires a positive float, use -1 to disable
92 | # use a different lr for some modules, e.g., larger lr for new modules
93 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
94 | )
95 |
96 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
97 |
98 | evaluate = False
99 | deep_fusion = False
100 | evaluation = dict(
101 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
102 | eval_x_only=False,
103 | k_test=128,
104 | eval_offload=True, # offload gpu tensors to cpu to save memory.
105 | )
106 |
107 | fp16 = True
108 | gradient_checkpointing = True
109 |
110 | # ========================= wandb ==========================
111 | wandb = dict(
112 | enable=False,
113 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
114 | project="videochat2", # setup in your command line
115 | )
116 | dist_url = "env://"
117 | device = "cuda"
118 | mode = "it_mistral"
119 |
120 | # ========================= others ==========================
121 | output_dir = None # output dir
122 | resume = False # if True, load optimizer and scheduler states as well
123 | debug = False
124 | log_freq = 10
125 | seed = 42
126 |
127 | save_latest = True
128 | auto_resume = True
129 | pretrained_path = "" # path to pretrained model weights, for resume only?
130 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_phi/run_7b_stage2.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage2'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_pt.py \
27 | $(dirname $0)/config_7b_stage2.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_phi/run_7b_stage3.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage3'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_it.py \
27 | $(dirname $0)/config_7b_stage3.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/config_7b_stage1.py:
--------------------------------------------------------------------------------
1 | from configs.data import *
2 | from configs.model import *
3 |
4 | # ========================= data ==========================
5 | train_corpus = "webvid10m_cc14m"
6 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
7 | test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"])
8 | test_types = ["msrvtt_1k_test"]
9 |
10 | num_workers = 6
11 |
12 | stop_key = None
13 |
14 | # ========================= input ==========================
15 | num_frames = 4
16 | num_frames_test = 4
17 | batch_size = 128
18 | max_txt_l = 32
19 |
20 | pre_text = False
21 |
22 | inputs = dict(
23 | image_res=224,
24 | video_input=dict(
25 | num_frames="${num_frames}",
26 | sample_type="rand",
27 | num_frames_test="${num_frames_test}",
28 | sample_type_test="middle",
29 | random_aug=False,
30 | ),
31 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
32 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
33 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
34 | )
35 |
36 | # ========================= model ==========================
37 | text_enc = "bert"
38 | model = dict(
39 | model_cls="VideoChat2_qformer",
40 | vision_encoder=dict(
41 | name="vit_l14",
42 | img_size=224,
43 | patch_size=16,
44 | d_model=1024,
45 | encoder_embed_dim=1024,
46 | encoder_depth=24,
47 | encoder_num_heads=16,
48 | drop_path_rate=0.,
49 | num_frames="${num_frames}",
50 | tubelet_size=1,
51 | use_checkpoint=False,
52 | checkpoint_num=12,
53 | pretrained="/mnt/petrelfs/share_data/likunchang/model/videochat2/l16_25m.pth",
54 | return_index=-2,
55 | ),
56 | text_encoder="${TextEncoders[${text_enc}]}",
57 | vit_add_ln=True,
58 | embed_dim=768,
59 | temp=0.07,
60 | qformer_num_query_tokens=32,
61 | agg_method="mean",
62 | drop_path_rate=0.2,
63 | )
64 |
65 | criterion = dict(
66 | loss_weight=dict(vtc=1.0, mlm=0.0, vtm=1.0, mvm=0.0, cap=1.0), # 0: disabled.
67 | vtm_hard_neg=True,
68 | vtm_cat_text_cls=True
69 | )
70 |
71 | optimizer = dict(
72 | opt="adamW",
73 | lr=1e-4,
74 | opt_betas=[0.9, 0.999], # default
75 | weight_decay=0.02,
76 | max_grad_norm=-1, # requires a positive float, use -1 to disable
77 | # use a different lr for some modules, e.g., larger lr for new modules
78 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
79 | )
80 |
81 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.01, warmup_epochs=0.2)
82 |
83 | evaluate = False
84 | deep_fusion = False
85 | evaluation = dict(
86 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
87 | eval_x_only=False,
88 | k_test=128,
89 | eval_offload=True, # offload gpu tensors to cpu to save memory.
90 | )
91 |
92 | fp16 = True
93 | gradient_checkpointing = True
94 |
95 | # ========================= wandb ==========================
96 | wandb = dict(
97 | enable=False,
98 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
99 | project="videochat2", # setup in your command line
100 | )
101 | dist_url = "env://"
102 | device = "cuda"
103 | mode = "pt"
104 |
105 | # ========================= others ==========================
106 | output_dir = None # output dir
107 | resume = False # if True, load optimizer and scheduler states as well
108 | debug = False
109 | log_freq = 100
110 | seed = 42
111 |
112 | save_latest = True
113 | auto_resume = True
114 | pretrained_path = "" # path to pretrained model weights, for resume only?
115 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/config_7b_stage2.py:
--------------------------------------------------------------------------------
1 | from configs.data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "webvid10m_cc14m_plus"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 8
14 | num_frames_test = 8
15 | batch_size = 4
16 | max_txt_l = 512
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_pt_vicuna",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | llama_model_path="your_model_path/llm/vicuna-7b-v0",
39 | freeze_vit=False,
40 | freeze_qformer=False,
41 | max_txt_len="${max_txt_l}",
42 | # vit
43 | low_resource=False,
44 | vision_encoder=dict(
45 | name="vit_l14",
46 | img_size=224,
47 | patch_size=16,
48 | d_model=1024,
49 | encoder_embed_dim=1024,
50 | encoder_depth=24,
51 | encoder_num_heads=16,
52 | drop_path_rate=0.,
53 | num_frames="${num_frames}",
54 | tubelet_size=1,
55 | use_checkpoint=False,
56 | checkpoint_num=0,
57 | pretrained="",
58 | return_index=-2,
59 | vit_add_ln=True,
60 | ),
61 | # prompt
62 | prompt_path="prompts/concise_description.txt",
63 | img_prompt_path="prompts/concise_image_description.txt",
64 | prompt_template="###Human: {} ###Assistant: ",
65 | end_sym="###",
66 | # qformer
67 | num_query_token=32,
68 | qformer_hidden_dropout_prob=0.1,
69 | qformer_attention_probs_dropout_prob=0.1,
70 | qformer_drop_path_rate=0.2,
71 | extra_num_query_token=64,
72 | # debug=True,
73 | )
74 |
75 | optimizer = dict(
76 | opt="adamW",
77 | lr=1e-4,
78 | opt_betas=[0.9, 0.999], # default
79 | weight_decay=0.02,
80 | max_grad_norm=-1, # requires a positive float, use -1 to disable
81 | # use a different lr for some modules, e.g., larger lr for new modules
82 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
83 | )
84 |
85 | scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2)
86 |
87 | evaluate = False
88 | deep_fusion = False
89 | evaluation = dict(
90 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
91 | eval_x_only=False,
92 | k_test=128,
93 | eval_offload=True, # offload gpu tensors to cpu to save memory.
94 | )
95 |
96 | fp16 = True
97 | gradient_checkpointing = True
98 |
99 | # ========================= wandb ==========================
100 | wandb = dict(
101 | enable=False,
102 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
103 | project="videochat2", # setup in your command line
104 | )
105 | dist_url = "env://"
106 | device = "cuda"
107 | mode = "pt"
108 |
109 | # ========================= others ==========================
110 | output_dir = None # output dir
111 | resume = False # if True, load optimizer and scheduler states as well
112 | debug = False
113 | log_freq = 100
114 | seed = 42
115 |
116 | save_latest = True
117 | auto_resume = True
118 | pretrained_path = "" # path to pretrained model weights, for resume only?
119 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/config_7b_stage3.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "videochat2_instruction"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 8
14 | num_frames_test = 8
15 | batch_size = 4
16 | max_txt_l = 512
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="VideoChat2_it_vicuna",
37 | vit_blip_model_path="your_model_path/videochat2/umt_l16_qformer.pth",
38 | llama_model_path="your_model_path/llm/vicuna-7b-v0",
39 | videochat2_model_path="your_model_path/videochat2/videochat2_7b_stage2.pth",
40 | freeze_vit=False,
41 | freeze_qformer=False,
42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
43 | # vit
44 | low_resource=False,
45 | vision_encoder=dict(
46 | name="vit_l14",
47 | img_size=224,
48 | patch_size=16,
49 | d_model=1024,
50 | encoder_embed_dim=1024,
51 | encoder_depth=24,
52 | encoder_num_heads=16,
53 | drop_path_rate=0.,
54 | num_frames="${num_frames}",
55 | tubelet_size=1,
56 | use_checkpoint=False,
57 | checkpoint_num=0,
58 | pretrained="",
59 | return_index=-2,
60 | vit_add_ln=True,
61 | ckpt_num_frame=4,
62 | ),
63 | # qformer
64 | num_query_token=32,
65 | qformer_hidden_dropout_prob=0.1,
66 | qformer_attention_probs_dropout_prob=0.1,
67 | qformer_drop_path_rate=0.2,
68 | extra_num_query_token=64,
69 | qformer_text_input=True,
70 | # prompt
71 | system="",
72 | start_token="<Video>",
73 | end_token="</Video>",
74 | add_second_msg=True,
75 | img_start_token="<Image>",
76 | img_end_token="</Image>",
77 | random_shuffle=True,
78 | use_flash_attention=True,
79 | use_lora=True,
80 | lora_r=16,
81 | lora_alpha=32,
82 | lora_dropout=0.1,
83 | # debug=True,
84 | )
85 |
86 | optimizer = dict(
87 | opt="adamW",
88 | lr=2e-5,
89 | opt_betas=[0.9, 0.999], # default
90 | weight_decay=0.02,
91 | max_grad_norm=-1, # requires a positive float, use -1 to disable
92 | # use a different lr for some modules, e.g., larger lr for new modules
93 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
94 | )
95 |
96 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
97 |
98 | evaluate = False
99 | deep_fusion = False
100 | evaluation = dict(
101 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
102 | eval_x_only=False,
103 | k_test=128,
104 | eval_offload=True, # offload gpu tensors to cpu to save memory.
105 | )
106 |
107 | fp16 = True
108 | gradient_checkpointing = True
109 |
110 | # ========================= wandb ==========================
111 | wandb = dict(
112 | enable=False,
113 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
114 | project="videochat2", # setup in your command line
115 | )
116 | dist_url = "env://"
117 | device = "cuda"
118 | mode = "it"
119 |
120 | # ========================= others ==========================
121 | output_dir = None # output dir
122 | resume = False # if True, load optimizer and scheduler states as well
123 | debug = False
124 | log_freq = 100
125 | seed = 42
126 |
127 | save_latest = True
128 | auto_resume = True
129 | pretrained_path = "" # path to pretrained model weights, for resume only?
130 |
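Note: the stage-3 configs above (Phi and Vicuna variants) enable LoRA with lora_r=16, lora_alpha=32 and lora_dropout=0.1. As a hedged sketch only, the snippet below shows how such values would typically translate into a PEFT LoraConfig wrapped around the language model; the actual wiring lives in the model classes under models/, and the target_modules listed here are an assumption for illustration, not taken from this repository.

from peft import LoraConfig, TaskType, get_peft_model

# Sketch only: mirrors lora_r/lora_alpha/lora_dropout from the configs above.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],  # assumption; check the model code for the real list
)
# language_model = get_peft_model(language_model, lora_config)  # applied inside the model class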
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/run_7b_stage1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Please modify the ${MASTER_NODE}:${MASTER_PORT}
3 | MASTER_NODE=127.0.0.1
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 | NNODE=1
6 | NUM_GPUS=8
7 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
8 |
9 |
10 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} \
11 | --nnodes=${NNODE} \
12 | --nproc_per_node=${NUM_GPUS} \
13 | --rdzv_backend=c10d \
14 | tasks/train_qformer.py \
15 | $(dirname $0)/config_7b_stage1.py \
16 | output_dir ${OUTPUT_DIR}
17 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/run_7b_stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Please modify the ${MASTER_NODE}:${MASTER_PORT}
3 | MASTER_NODE=127.0.0.1
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 | NNODE=1
6 | NUM_GPUS=8
7 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
8 |
9 |
10 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} \
11 | --nnodes=${NNODE} \
12 | --nproc_per_node=${NUM_GPUS} \
13 | --rdzv_backend=c10d \
14 | tasks/train_pt.py \
15 | $(dirname $0)/config_7b_stage2.py \
16 | output_dir ${OUTPUT_DIR}
17 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/run_7b_stage3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Please modify the ${MASTER_NODE}:${MASTER_PORT}
3 | MASTER_NODE=127.0.0.1
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 | NNODE=1
6 | NUM_GPUS=8
7 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
8 |
9 |
10 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} \
11 | --nnodes=${NNODE} \
12 | --nproc_per_node=${NUM_GPUS} \
13 | --rdzv_backend=c10d \
14 | tasks/train_it.py \
15 | $(dirname $0)/config_7b_stage3.py \
16 | output_dir ${OUTPUT_DIR}
17 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/slurm_run_7b_stage1.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage1'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_qformer.py \
27 | $(dirname $0)/config_7b_stage1.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/slurm_run_7b_stage2.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage2'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_pt.py \
27 | $(dirname $0)/config_7b_stage2.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/scripts/videochat_vicuna/slurm_run_7b_stage3.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | JOB_NAME='stage3'
11 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
12 | PARTITION='video'
13 | NNODE=4
14 | NUM_GPUS=8
15 | NUM_CPU=128
16 |
17 | srun -p ${PARTITION} \
18 | -n${NNODE} \
19 | --gres=gpu:${NUM_GPUS} \
20 | --ntasks-per-node=1 \
21 | --cpus-per-task=${NUM_CPU} \
22 | bash torchrun.sh \
23 | --nnodes=${NNODE} \
24 | --nproc_per_node=${NUM_GPUS} \
25 | --rdzv_backend=c10d \
26 | tasks/train_it.py \
27 | $(dirname $0)/config_7b_stage3.py \
28 | output_dir ${OUTPUT_DIR}
29 |
--------------------------------------------------------------------------------
/infty-VideoChat2/tasks/shared_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from utils.optimizer import create_optimizer
11 | from utils.scheduler import create_scheduler
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_media_types(datasources):
17 | """get the media types for for all the dataloaders.
18 |
19 | Args:
20 | datasources (List): List of dataloaders or datasets.
21 |
22 | Returns: List. The media_types.
23 |
24 | """
25 | if isinstance(datasources[0], DataLoader):
26 | datasets = [dataloader.dataset for dataloader in datasources]
27 | else:
28 | datasets = datasources
29 | media_types = [
30 | dataset.datasets[0].media_type
31 | if isinstance(dataset, ConcatDataset)
32 | else dataset.media_type
33 | for dataset in datasets
34 | ]
35 |
36 | return media_types
37 |
38 |
39 | def setup_model(
40 | config, model_cls, find_unused_parameters=False
41 | ):
42 | logger.info("Creating model")
43 | config = copy.deepcopy(config)
44 |
45 | model = model_cls(config=config.model)
46 |
47 | model = model.to(torch.device(config.device))
48 | model_without_ddp = model
49 | if config.distributed:
50 | model = torch.nn.parallel.DistributedDataParallel(
51 | model,
52 | device_ids=[config.gpu],
53 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
54 | )
55 |
56 | optimizer = create_optimizer(config.optimizer, model)
57 | scheduler = create_scheduler(config.scheduler, optimizer)
58 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
59 |
60 | start_epoch = 0
61 | global_step = 0
62 |
63 | # auto resume the latest checkpoint
64 | if config.get("auto_resume", False):
65 | logger.info("Auto resuming")
66 | model_latest = join(config.output_dir, "ckpt_latest.pth")
67 | model_best = join(config.output_dir, "ckpt_best.pth")
68 | large_num = -1
69 | for p in os.listdir(config.output_dir):
70 | if 'ckpt' in p:
71 | num = p.split('_')[1].split('.')[0]
72 | if str.isnumeric(num):
73 | if int(num) > large_num:
74 | large_num = int(num)
75 | if large_num != -1:
76 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
77 | if osp.isfile(model_latest):
78 | config.pretrained_path = model_latest
79 | config.resume = True
80 | elif osp.isfile(model_best):
81 | config.pretrained_path = model_best
82 | config.resume = True
83 | else:
84 | logger.info(f"Not found checkpoint in {config.output_dir}")
85 |
86 | if osp.isfile(config.pretrained_path):
87 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
88 | state_dict = checkpoint["model"]
89 |
90 | if config.resume:
91 | optimizer.load_state_dict(checkpoint["optimizer"])
92 | scheduler.load_state_dict(checkpoint["scheduler"])
93 | scaler.load_state_dict(checkpoint["scaler"])
94 | start_epoch = checkpoint["epoch"] + 1
95 | global_step = checkpoint["global_step"]
96 |
97 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
98 | logger.info(msg)
99 | logger.info(f"Loaded checkpoint from {config.pretrained_path}")
100 | else:
101 | logger.warning("No pretrained checkpoint provided, training from scratch")
102 |
103 | return (
104 | model,
105 | model_without_ddp,
106 | optimizer,
107 | scheduler,
108 | scaler,
109 | start_epoch,
110 | global_step,
111 | )
112 |
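For reference, get_media_types only relies on each dataset exposing a media_type attribute and on unwrapping a ConcatDataset to its first member. A minimal illustration with hypothetical toy datasets (assuming the repo root is on PYTHONPATH, as the run scripts arrange):

from torch.utils.data import ConcatDataset, Dataset
from tasks.shared_utils import get_media_types

class ToyDataset(Dataset):  # hypothetical, for illustration only
    def __init__(self, media_type):
        self.media_type = media_type
    def __len__(self):
        return 1
    def __getitem__(self, idx):
        return idx

datasets = [ToyDataset("image"), ConcatDataset([ToyDataset("video"), ToyDataset("video")])]
print(get_media_types(datasets))  # -> ['image', 'video']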
--------------------------------------------------------------------------------
/infty-VideoChat2/tasks/shared_utils_qformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from models.bert.tokenization_bert import BertTokenizer
11 | from utils.optimizer import create_optimizer
12 | from utils.scheduler import create_scheduler
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def get_media_types(datasources):
18 | """get the media types for for all the dataloaders.
19 |
20 | Args:
21 | datasources (List): List of dataloaders or datasets.
22 |
23 | Returns: List. The media_types.
24 |
25 | """
26 | if isinstance(datasources[0], DataLoader):
27 | datasets = [dataloader.dataset for dataloader in datasources]
28 | else:
29 | datasets = datasources
30 | media_types = [
31 | dataset.datasets[0].media_type
32 | if isinstance(dataset, ConcatDataset)
33 | else dataset.media_type
34 | for dataset in datasets
35 | ]
36 |
37 | return media_types
38 |
39 |
40 | def setup_model(
41 | config, model_cls, find_unused_parameters=False
42 | ):
43 | logger.info("Creating model")
44 | config = copy.deepcopy(config)
45 |
46 | if "bert" in config.model.text_encoder.name:
47 | tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained, local_files_only=True)
48 | else:
49 | raise ValueError(f"Not supported text encoder.")
50 |
51 | model = model_cls(config=config, tokenizer=tokenizer)
52 |
53 | model = model.to(torch.device(config.device))
54 | model_without_ddp = model
55 | if config.distributed:
56 | model = torch.nn.parallel.DistributedDataParallel(
57 | model,
58 | device_ids=[config.gpu],
59 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
60 | )
61 |
62 | optimizer = create_optimizer(config.optimizer, model)
63 | scheduler = create_scheduler(config.scheduler, optimizer)
64 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
65 |
66 | start_epoch = 0
67 | global_step = 0
68 |
69 | # auto resume the latest checkpoint
70 | if config.get("auto_resume", False):
71 | logger.info("Auto resuming")
72 | model_latest = join(config.output_dir, "ckpt_latest.pth")
73 | model_best = join(config.output_dir, "ckpt_best.pth")
74 | large_num = -1
75 | for p in os.listdir(config.output_dir):
76 | if 'ckpt' in p:
77 | num = p.split('_')[1].split('.')[0]
78 | if str.isnumeric(num):
79 | if int(num) > large_num:
80 | large_num = int(num)
81 | if large_num != -1:
82 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
83 | if osp.isfile(model_latest):
84 | config.pretrained_path = model_latest
85 | config.resume = True
86 | elif osp.isfile(model_best):
87 | config.pretrained_path = model_best
88 | config.resume = True
89 | else:
90 | logger.info(f"Not found checkpoint in {config.output_dir}")
91 |
92 | if osp.isfile(config.pretrained_path):
93 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
94 | state_dict = checkpoint["model"]
95 |
96 | if config.resume:
97 | optimizer.load_state_dict(checkpoint["optimizer"])
98 | scheduler.load_state_dict(checkpoint["scheduler"])
99 | scaler.load_state_dict(checkpoint["scaler"])
100 | start_epoch = checkpoint["epoch"] + 1
101 | global_step = checkpoint["global_step"]
102 |
103 |
104 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
105 | logger.info(msg)
106 | logger.info(f"Loaded checkpoint from {config.pretrained_path}")
107 | else:
108 | logger.warning("No pretrained checkpoint provided, training from scratch")
109 |
110 | return (
111 | model,
112 | model_without_ddp,
113 | optimizer,
114 | scheduler,
115 | scaler,
116 | tokenizer,
117 | start_epoch,
118 | global_step,
119 | )
120 |
--------------------------------------------------------------------------------
/infty-VideoChat2/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import logging
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def setup_for_distributed(is_master):
11 | import warnings
12 |
13 | builtin_warn = warnings.warn
14 |
15 | def warn(*args, **kwargs):
16 | force = kwargs.pop("force", False)
17 | if is_master or force:
18 | builtin_warn(*args, **kwargs)
19 |
20 | # Log warnings only once
21 | warnings.warn = warn
22 | warnings.simplefilter("once", UserWarning)
23 |
24 | if not is_master:
25 | logging.disable()
26 |
27 |
28 | def is_dist_avail_and_initialized():
29 | if not dist.is_available():
30 | return False
31 | if not dist.is_initialized():
32 | return False
33 | return True
34 |
35 |
36 | def get_world_size():
37 | if not is_dist_avail_and_initialized():
38 | return 1
39 | return dist.get_world_size()
40 |
41 |
42 | def get_rank():
43 | if not is_dist_avail_and_initialized():
44 | return 0
45 | return dist.get_rank()
46 |
47 |
48 | def is_main_process():
49 | return get_rank() == 0
50 |
51 |
52 | def save_on_master(*args, **kwargs):
53 | if is_main_process():
54 | torch.save(*args, **kwargs)
55 |
56 |
57 | def is_port_in_use(port):
58 | import socket
59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
60 | return s.connect_ex(('localhost', port)) == 0
61 |
62 |
63 | def init_distributed_mode(args):
64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
65 | # job started by torch.distributed.launch
66 | args.rank = int(os.environ["RANK"])
67 | args.world_size = int(os.environ['WORLD_SIZE'])
68 | args.gpu = int(os.environ['LOCAL_RANK'])
69 | elif 'SLURM_PROCID' in os.environ:
70 | # local rank on the current node / global rank
71 | local_rank = int(os.environ['SLURM_LOCALID'])
72 | global_rank = int(os.environ['SLURM_PROCID'])
73 | # number of processes / GPUs per node
74 | world_size = int(os.environ["SLURM_NNODES"]) * \
75 | int(os.environ["SLURM_TASKS_PER_NODE"][0])
76 |
77 | print(world_size)
78 |
79 | args.rank = global_rank
80 | args.gpu = local_rank
81 | args.world_size = world_size
82 | else:
83 | logger.info('Not using distributed mode')
84 | args.distributed = False
85 | return
86 |
87 | args.distributed = True
88 |
89 | torch.cuda.set_device(args.gpu)
90 | args.dist_backend = 'nccl'
91 |
92 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node
93 | dist_port = int(args.dist_url.split(":")[-1])
94 | while is_port_in_use(dist_port):
95 | dist_port += 10
96 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
97 |
98 | logger.info('| distributed init (rank {}): {}'.format(
99 | args.rank, args.dist_url))
100 | if "SLURM_JOB_ID" in os.environ:
101 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
102 | torch.distributed.init_process_group(
103 | backend=args.dist_backend, init_method=args.dist_url,
104 | world_size=args.world_size, rank=args.rank)
105 | torch.distributed.barrier()
106 | setup_for_distributed(args.rank == 0)
107 |
108 |
109 | # Copyright (c) Facebook, Inc. and its affiliates.
110 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
111 | class GatherLayer(torch.autograd.Function):
112 | """
113 | Gather tensors from all workers with support for backward propagation:
114 | This implementation does not cut the gradients as torch.distributed.all_gather does.
115 | """
116 |
117 | @staticmethod
118 | def forward(ctx, x):
119 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
120 | dist.all_gather(output, x)
121 | return tuple(output)
122 |
123 | @staticmethod
124 | def backward(ctx, *grads):
125 | all_gradients = torch.stack(grads)
126 | dist.all_reduce(all_gradients)
127 | return all_gradients[dist.get_rank()]
128 |
129 |
130 | # copied from megavlt
131 | def gather_tensor_along_batch_with_backward(tensor, dim=0):
132 | world_size = get_world_size()
133 |
134 | if world_size < 2:
135 | return tensor
136 |
137 | tensor_list = GatherLayer.apply(tensor)
138 | tensor_list = torch.cat(tensor_list, dim=dim)
139 | return tensor_list
140 |
141 |
142 | @torch.no_grad()
143 | def gather_tensor_along_batch(tensor, dim=0):
144 | """
145 | Performs all_gather operation on the provided tensors.
146 | *** Warning ***: torch.distributed.all_gather has no gradient.
147 | """
148 | world_size = get_world_size()
149 |
150 | if world_size < 2:
151 | return tensor
152 |
153 | with torch.no_grad():
154 | tensor_list = []
155 |
156 | for _ in range(world_size):
157 | tensor_list.append(torch.zeros_like(tensor))
158 |
159 | dist.all_gather(tensor_list, tensor)
160 | tensor_list = torch.cat(tensor_list, dim=dim)
161 | return tensor_list
162 |
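gather_tensor_along_batch_with_backward is the gradient-preserving counterpart of all_gather, which contrastive objectives (e.g. the vtc loss weighted in the stage-1 config) require. A minimal sketch: with a single process it degrades to a no-op, while under torchrun it concatenates the batch dimension across ranks.

import torch
from utils.distributed import gather_tensor_along_batch_with_backward

local_feats = torch.randn(4, 768, requires_grad=True)  # per-rank embeddings (toy sizes)
all_feats = gather_tensor_along_batch_with_backward(local_feats, dim=0)
loss = all_feats.pow(2).mean()
loss.backward()  # gradients flow back to local_feats on every rank via GatherLayer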
--------------------------------------------------------------------------------
/infty-VideoChat2/utils/easydict.py:
--------------------------------------------------------------------------------
1 | class EasyDict(dict):
2 | """
3 | Get attributes
4 |
5 | >>> d = EasyDict({'foo':3})
6 | >>> d['foo']
7 | 3
8 | >>> d.foo
9 | 3
10 | >>> d.bar
11 | Traceback (most recent call last):
12 | ...
13 | AttributeError: 'EasyDict' object has no attribute 'bar'
14 |
15 | Works recursively
16 |
17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
18 | >>> isinstance(d.bar, dict)
19 | True
20 | >>> d.bar.x
21 | 1
22 |
23 | Bullet-proof
24 |
25 | >>> EasyDict({})
26 | {}
27 | >>> EasyDict(d={})
28 | {}
29 | >>> EasyDict(None)
30 | {}
31 | >>> d = {'a': 1}
32 | >>> EasyDict(**d)
33 | {'a': 1}
34 |
35 | Set attributes
36 |
37 | >>> d = EasyDict()
38 | >>> d.foo = 3
39 | >>> d.foo
40 | 3
41 | >>> d.bar = {'prop': 'value'}
42 | >>> d.bar.prop
43 | 'value'
44 | >>> d
45 | {'foo': 3, 'bar': {'prop': 'value'}}
46 | >>> d.bar.prop = 'newer'
47 | >>> d.bar.prop
48 | 'newer'
49 |
50 |
51 | Values extraction
52 |
53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
54 | >>> isinstance(d.bar, list)
55 | True
56 | >>> from operator import attrgetter
57 | >>> map(attrgetter('x'), d.bar)
58 | [1, 3]
59 | >>> map(attrgetter('y'), d.bar)
60 | [2, 4]
61 | >>> d = EasyDict()
62 | >>> d.keys()
63 | []
64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
65 | >>> d.foo
66 | 3
67 | >>> d.bar.x
68 | 1
69 |
70 | Still like a dict though
71 |
72 | >>> o = EasyDict({'clean':True})
73 | >>> o.items()
74 | [('clean', True)]
75 |
76 | And like a class
77 |
78 | >>> class Flower(EasyDict):
79 | ... power = 1
80 | ...
81 | >>> f = Flower()
82 | >>> f.power
83 | 1
84 | >>> f = Flower({'height': 12})
85 | >>> f.height
86 | 12
87 | >>> f['power']
88 | 1
89 | >>> sorted(f.keys())
90 | ['height', 'power']
91 |
92 | update and pop items
93 | >>> d = EasyDict(a=1, b='2')
94 | >>> e = EasyDict(c=3.0, a=9.0)
95 | >>> d.update(e)
96 | >>> d.c
97 | 3.0
98 | >>> d['c']
99 | 3.0
100 | >>> d.get('c')
101 | 3.0
102 | >>> d.update(a=4, b=4)
103 | >>> d.b
104 | 4
105 | >>> d.pop('a')
106 | 4
107 | >>> d.a
108 | Traceback (most recent call last):
109 | ...
110 | AttributeError: 'EasyDict' object has no attribute 'a'
111 | """
112 |
113 | def __init__(self, d=None, **kwargs):
114 | if d is None:
115 | d = {}
116 | if kwargs:
117 | d.update(**kwargs)
118 | for k, v in d.items():
119 | setattr(self, k, v)
120 | # Class attributes
121 | for k in self.__class__.__dict__.keys():
122 | if not (k.startswith("__") and k.endswith("__")) and k not in ("update", "pop"):
123 | setattr(self, k, getattr(self, k))
124 |
125 | def __setattr__(self, name, value):
126 | if isinstance(value, (list, tuple)):
127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
128 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
129 | value = self.__class__(value)
130 | super(EasyDict, self).__setattr__(name, value)
131 | super(EasyDict, self).__setitem__(name, value)
132 |
133 | __setitem__ = __setattr__
134 |
135 | def update(self, e=None, **f):
136 | d = e or dict()
137 | d.update(f)
138 | for k in d:
139 | setattr(self, k, d[k])
140 |
141 | def pop(self, k, d=None):
142 | if hasattr(self, k):
143 | delattr(self, k)
144 | return super(EasyDict, self).pop(k, d)
145 |
146 |
147 | if __name__ == "__main__":
148 | import doctest
149 | doctest.testmod()
150 |
--------------------------------------------------------------------------------
/infty-VideoChat2/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | """ Optimizer Factory w/ Custom Weight Decay
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import re
5 | import torch
6 | from torch import optim as optim
7 | from utils.distributed import is_main_process
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | try:
11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
12 | has_apex = True
13 | except ImportError:
14 | has_apex = False
15 |
16 |
17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
18 | named_param_tuples = []
19 | for name, param in model.named_parameters():
20 | if not param.requires_grad:
21 | continue # frozen weights
22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
23 | named_param_tuples.append([name, param, 0])
24 | elif name in no_decay_list:
25 | named_param_tuples.append([name, param, 0])
26 | else:
27 | named_param_tuples.append([name, param, weight_decay])
28 | return named_param_tuples
29 |
30 |
31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
32 | """use lr=diff_lr for modules named found in diff_lr_names,
33 | otherwise use lr=default_lr
34 |
35 | Args:
36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
37 | diff_lr_names: List(str)
38 | diff_lr: float
39 | default_lr: float
40 | Returns:
41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr])
42 | """
43 | named_param_tuples_with_lr = []
44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
45 | for name, p, wd in named_param_tuples_or_model:
46 | use_diff_lr = False
47 | for diff_name in diff_lr_names:
48 | # if diff_name in name:
49 | if re.search(diff_name, name) is not None:
50 | logger.info(f"param {name} use different_lr: {diff_lr}")
51 | use_diff_lr = True
52 | break
53 |
54 | named_param_tuples_with_lr.append(
55 | [name, p, wd, diff_lr if use_diff_lr else default_lr]
56 | )
57 |
58 | if is_main_process():
59 | for name, _, wd, diff_lr in named_param_tuples_with_lr:
60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}")
61 |
62 | return named_param_tuples_with_lr
63 |
64 |
65 | def create_optimizer_params_group(named_param_tuples_with_lr):
66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
67 | group = {}
68 | for name, p, wd, lr in named_param_tuples_with_lr:
69 | if wd not in group:
70 | group[wd] = {}
71 | if lr not in group[wd]:
72 | group[wd][lr] = []
73 | group[wd][lr].append(p)
74 |
75 | optimizer_params_group = []
76 | for wd, lr_groups in group.items():
77 | for lr, p in lr_groups.items():
78 | optimizer_params_group.append(dict(
79 | params=p,
80 | weight_decay=wd,
81 | lr=lr
82 | ))
83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
84 | return optimizer_params_group
85 |
86 |
87 | def create_optimizer(args, model, filter_bias_and_bn=True, return_group=False):
88 | opt_lower = args.opt.lower()
89 | weight_decay = args.weight_decay
90 | # check for modules that requires different lr
91 | if hasattr(args, "different_lr") and args.different_lr.enable:
92 | diff_lr_module_names = args.different_lr.module_names
93 | diff_lr = args.different_lr.lr
94 | else:
95 | diff_lr_module_names = []
96 | diff_lr = None
97 |
98 | no_decay = {}
99 | if hasattr(model, 'no_weight_decay'):
100 | no_decay = model.no_weight_decay()
101 | named_param_tuples = add_weight_decay(
102 | model, weight_decay, no_decay, filter_bias_and_bn)
103 | named_param_tuples = add_different_lr(
104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
105 | parameters = create_optimizer_params_group(named_param_tuples)
106 |
107 | if return_group:
108 | return parameters
109 |
110 | if 'fused' in opt_lower:
111 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
112 |
113 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
114 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
115 | opt_args['eps'] = args.opt_eps
116 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
117 | opt_args['betas'] = args.opt_betas
118 | if hasattr(args, 'opt_args') and args.opt_args is not None:
119 | opt_args.update(args.opt_args)
120 |
121 | opt_split = opt_lower.split('_')
122 | opt_lower = opt_split[-1]
123 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
124 | opt_args.pop('eps', None)
125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
126 | elif opt_lower == 'momentum':
127 | opt_args.pop('eps', None)
128 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
129 | elif opt_lower == 'adam':
130 | optimizer = optim.Adam(parameters, **opt_args)
131 | elif opt_lower == 'adamw':
132 | optimizer = optim.AdamW(parameters, **opt_args)
133 | else:
134 | # unreachable for the supported optimizers; fail loudly otherwise
135 | raise ValueError(f"Invalid optimizer: {opt_lower}")
136 | return optimizer
137 |
--------------------------------------------------------------------------------
/infty-VideoChat2/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from torch.optim import Optimizer
5 | import math
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 |
9 | def create_scheduler(args, optimizer):
10 | lr_scheduler = None
11 | if args.sched == 'cosine':
12 | lr_scheduler = get_cosine_schedule_with_warmup(
13 | optimizer,
14 | num_warmup_steps=args.num_warmup_steps,
15 | num_training_steps=args.num_training_steps,
16 | num_cycles=0.5,
17 | min_lr_multi=args.min_lr_multi
18 | )
19 | return lr_scheduler
20 |
21 |
22 | def get_cosine_schedule_with_warmup(
23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
25 | ):
26 | """
27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
28 |
29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
31 | initial lr set in the optimizer.
32 | Args:
33 | optimizer ([`~torch.optim.Optimizer`]):
34 | The optimizer for which to schedule the learning rate.
35 | num_warmup_steps (`int`):
36 | The number of steps for the warmup phase.
37 | num_training_steps (`int`):
38 | The total number of training steps.
39 | num_cycles (`float`, *optional*, defaults to 0.5):
40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
41 | following a half-cosine).
42 | min_lr_multi (`float`, *optional*, defaults to 0):
43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
44 | last_epoch (`int`, *optional*, defaults to -1):
45 | The index of the last epoch when resuming training.
46 | Return:
47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
48 | """
49 |
50 | def lr_lambda(current_step):
51 | if current_step < num_warmup_steps:
52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
55 |
56 | return LambdaLR(optimizer, lr_lambda, last_epoch)
57 |
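create_scheduler expects step-level counts (num_warmup_steps, num_training_steps), while the configs above specify epoch-level values (epochs, warmup_epochs); presumably the training tasks derive one from the other using the dataloader length. A minimal sketch of that conversion, with a hypothetical steps_per_epoch:

import torch
from utils.easydict import EasyDict
from utils.scheduler import create_scheduler

steps_per_epoch = 1000  # hypothetical; in practice len(train_loader)
sched_cfg = EasyDict(
    sched="cosine",
    epochs=1,
    warmup_epochs=0.2,
    min_lr_multi=0.01,
    num_warmup_steps=int(0.2 * steps_per_epoch),
    num_training_steps=int(1 * steps_per_epoch),
)
model = torch.nn.Linear(8, 8)  # placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = create_scheduler(sched_cfg, optimizer)  # cosine decay with a floor at 1% of the base lr

for step in range(sched_cfg.num_training_steps):
    optimizer.step()
    scheduler.step()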
--------------------------------------------------------------------------------