', 'person')
30 |
31 | text = re.sub(r"\s{2,}", ' ', text)
32 | text = text.rstrip('\n').strip(' ')
33 |
34 | if max_l: # truncate to at most max_l words
35 | words = text.split(' ')
36 | if len(words) > max_l:
37 | text = ' '.join(words[:max_l])
40 | return text
41 |
42 |
--------------------------------------------------------------------------------
/dataset/video_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
3 | """
4 | import random
5 | import io
6 | import os
7 | import av
8 | import cv2
9 | import decord
10 | import imageio
11 | from decord import VideoReader
12 |
13 | # from dataloader import KVReader
14 | import torch
15 | import numpy as np
16 | import math
17 | # import tensorflow as tf
18 | decord.bridge.set_bridge("torch")
19 |
20 | import logging
21 | logger = logging.getLogger(__name__)
22 |
23 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
24 | """
25 | Converts a presentation timestamp (pts) with the given time base and start_pts offset to seconds.
26 |
27 | Returns:
28 | time_in_seconds (float): The corresponding time in seconds.
29 |
30 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
31 | """
32 | if pts == math.inf:
33 | return math.inf
34 |
35 | return int(pts - start_pts) * time_base
36 |
37 |
38 | def get_pyav_video_duration(video_reader):
39 | video_stream = video_reader.streams.video[0]
40 | video_duration = pts_to_secs(
41 | video_stream.duration,
42 | video_stream.time_base,
43 | video_stream.start_time
44 | )
45 | return float(video_duration)
46 |
47 |
48 | def get_frame_indices_by_fps():
49 | pass
50 |
51 |
52 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
53 | if sample in ["rand", "middle"]: # uniform sampling
54 | acc_samples = min(num_frames, vlen)
55 | # split the video into `acc_samples` intervals, and sample from each interval.
56 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
57 | ranges = []
58 | for idx, interv in enumerate(intervals[:-1]):
59 | ranges.append((interv, intervals[idx + 1] - 1))
60 | if sample == 'rand':
61 | try:
62 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
63 | except IndexError: # an interval can be empty when vlen is small
64 | frame_indices = np.random.permutation(vlen)[:acc_samples]
65 | frame_indices.sort()
66 | frame_indices = list(frame_indices)
67 | elif fix_start is not None:
68 | frame_indices = [x[0] + fix_start for x in ranges]
69 | elif sample == 'middle':
70 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
71 | else:
72 | raise NotImplementedError
73 |
74 | if len(frame_indices) < num_frames: # padded with last frame
75 | padded_frame_indices = [frame_indices[-1]] * num_frames
76 | padded_frame_indices[:len(frame_indices)] = frame_indices
77 | frame_indices = padded_frame_indices
78 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
79 | output_fps = float(sample[3:])
80 | duration = float(vlen) / input_fps
81 | delta = 1 / output_fps # gap between sampled frames; also the clip length each frame represents
82 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
83 | frame_indices = np.around(frame_seconds * input_fps).astype(int)
84 | frame_indices = [e for e in frame_indices if e < vlen]
85 | if max_num_frames > 0 and len(frame_indices) > max_num_frames:
86 | frame_indices = frame_indices[:max_num_frames]
87 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
88 | else:
89 | raise ValueError
90 | return frame_indices
91 |
92 |
93 | def read_frames_av(
94 | video_path, num_frames, sample='rand', fix_start=None,
95 | max_num_frames=-1, client=None, clip=None,
96 | ):
97 | reader = av.open(video_path)
98 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] # decodes the entire video into memory before sampling
99 | vlen = len(frames)
100 | duration = get_pyav_video_duration(reader)
101 | fps = vlen / float(duration)
102 | frame_indices = get_frame_indices(
103 | num_frames, vlen, sample=sample, fix_start=fix_start,
104 | input_fps=fps, max_num_frames=max_num_frames
105 | )
106 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8
107 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
108 | return frames, frame_indices, fps
109 |
110 |
111 | def read_frames_gif(
112 | video_path, num_frames, sample='rand', fix_start=None,
113 | max_num_frames=-1, client=None, clip=None,
114 | ):
115 | if video_path.startswith('s3') or video_path.startswith('p2'):
116 | video_bytes = client.get(video_path)
117 | gif = imageio.get_reader(io.BytesIO(video_bytes))
118 | else:
119 | gif = imageio.get_reader(video_path)
120 | vlen = len(gif)
121 | frame_indices = get_frame_indices(
122 | num_frames, vlen, sample=sample, fix_start=fix_start,
123 | max_num_frames=max_num_frames
124 | )
125 | frames = []
126 | for index, frame in enumerate(gif):
127 | # for index in frame_idxs:
128 | if index in frame_indices:
129 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
130 | frame = torch.from_numpy(frame).byte()
131 | # # (H x W x C) to (C x H x W)
132 | frame = frame.permute(2, 0, 1)
133 | frames.append(frame)
134 | frames = torch.stack(frames) # .float() / 255
135 |
136 | return frames, frame_indices, 25. # assume 25 fps for TGIF gifs
137 |
138 |
139 | def read_frames_hdfs(ind_file, vid, num_frames, sample='rand', fix_start=None,
140 |                      max_num_frames=-1, client=None, clip=None): # requires the tensorflow and KVReader imports commented out at the top of this file
141 | _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)}
142 | _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)}
143 | num_parallel_reader = 1
144 | filename, extension = os.path.splitext(ind_file)
145 | reader = KVReader(filename, num_parallel_reader)
146 | key = vid
147 | values = reader.read_many([key])
148 | item = values[0]
149 | contexts, sequences = tf.io.parse_single_sequence_example(
150 | serialized=item,
151 | context_features=_context_features,
152 | sequence_features=_sequence_features)
153 |
154 | # text = contexts['title'].numpy().decode("utf-8")
155 | rawframes = sequences['data']
156 | vlen = len(rawframes)
157 | sample = "rand" # the hdfs reader always samples randomly, overriding the argument
158 |
159 | frame_indices = get_frame_indices(num_frames, vlen, sample=sample,
160 | fix_start=fix_start,
161 | max_num_frames=max_num_frames)
162 | def read_image(raw_data):
163 | return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy()
164 |
165 | frames = []
166 | for index, frame in enumerate(rawframes):
167 | if index in frame_indices:
168 | frame = read_image(frame)
169 | frame = torch.as_tensor(frame)
170 | frames.append(frame)
171 |
172 | frames = torch.stack(frames)
173 | # print("in hdfs========>",frames[0])
174 | frames = frames.permute(0, 3, 1, 2)
175 | return frames, frame_indices, 25 # fps unknown for indexed frames; assume 25
176 |
177 |
178 | def read_frames_decord(
179 | video_path, num_frames, sample='rand', fix_start=None,
180 | max_num_frames=-1, client=None, clip=None
181 | ):
182 | if video_path.startswith('s3') or video_path.startswith('p2'):
183 | video_bytes = client.get(video_path)
184 | video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
185 | else:
186 | video_reader = VideoReader(video_path, num_threads=1)
187 | vlen = len(video_reader)
188 | fps = video_reader.get_avg_fps()
189 | duration = vlen / float(fps)
190 |
191 | if clip:
192 | start, end = clip
193 | duration = end - start
194 | vlen = int(duration * fps)
195 | start_index = int(start * fps)
196 |
197 | frame_indices = get_frame_indices(
198 | num_frames, vlen, sample=sample, fix_start=fix_start,
199 | input_fps=fps, max_num_frames=max_num_frames
200 | )
201 | if clip:
202 | frame_indices = [f + start_index for f in frame_indices]
203 |
204 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
205 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
206 | return frames, frame_indices, float(fps)
207 |
208 |
209 | VIDEO_READER_FUNCS = {
210 | 'av': read_frames_av,
211 | 'decord': read_frames_decord,
212 | 'gif': read_frames_gif,
213 | 'hdfs': read_frames_hdfs,
214 | }
215 |
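A minimal usage sketch of the samplers and readers above; the 100-frame clip and 8-frame budget are illustrative, and example/yoga.mp4 is one of the clips shipped under example/:

idx = get_frame_indices(num_frames=8, vlen=100, sample="middle")  # one index from the middle of each of 8 equal intervals
idx = get_frame_indices(num_frames=8, vlen=100, sample="fps2", input_fps=25)  # sequential sampling at ~2 fps
frames, frame_indices, fps = VIDEO_READER_FUNCS["decord"]("example/yoga.mp4", num_frames=8, sample="rand")
print(frames.shape)  # (8, 3, H, W), torch.uint8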
--------------------------------------------------------------------------------
/evaluate_egoschema_result.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import argparse
3 | import requests
4 |
5 | root_dir = 'test_results/pllava-7b-lora14-threshold0.8-layer10-alpha0.4-temporal-segment-ratio-0.25-cluster-ratio-0.5/egoschema'
6 |
7 | def extract_and_convert(label_string):
8 | # Map answer letters to option indices
9 | mapping = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
10 |
11 | # Extract the answer letter: index 1, e.g. the 'A' in "(A) ..."
12 | first_char = label_string[1]
13 |
14 | # Make sure the letter is a valid option
15 | if first_char in mapping:
16 | return mapping[first_char]
17 | else:
18 | raise ValueError("Input string does not contain a valid label (A-E) at index 1.")
19 |
20 | def send_post_request(data):
21 | """
22 | Sends a POST request to the specified URL with the given JSON file.
23 |
24 | Parameters:
25 | - data (dict): JSON-serializable mapping of video id to predicted option index, sent as the request body.
26 |
27 | Returns:
28 | - Response object containing the server's response.
29 | """
30 |
31 | url = "https://validation-server.onrender.com/api/upload/"
32 | headers = {
33 | "Content-Type": "application/json"
34 | }
35 |
36 | response = requests.post(url, headers=headers, json=data)
37 |
38 | return response
39 |
40 | prediction_jsonls = [f for f in os.listdir(root_dir) if 'all_results' in f]
41 |
42 | result_dict = {}
43 |
44 | for pred_jsonl in prediction_jsonls:
45 | data_list = json.load(open(os.path.join(root_dir, pred_jsonl), 'r'))['result_list']
46 | for data in data_list:
47 | pred = data['pred']
48 | pred = extract_and_convert(pred)
49 | vid = data['video_path'].split('/')[-1].split('.')[0]
50 | result_dict[vid] = pred
51 | # with open(os.path.join(root_dir, pred_jsonl), 'r') as f:
52 | # lines = f.readlines()
53 | # for line in lines:
54 | # data = json.loads(line)
55 | # result_dict[data['vid']] = extract_and_convert(data['text']['prediction'])
56 | print(result_dict)
57 | response = send_post_request(result_dict)
58 | print(f"Response Status Code: {response.status_code}")
59 | print(f"Response Content:\n{response.text}")
--------------------------------------------------------------------------------
/example/1917.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/1917.mov
--------------------------------------------------------------------------------
/example/1917.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/1917.mp4
--------------------------------------------------------------------------------
/example/bear.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/bear.jpg
--------------------------------------------------------------------------------
/example/cooking.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/cooking.mp4
--------------------------------------------------------------------------------
/example/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/dog.png
--------------------------------------------------------------------------------
/example/jesse_dance.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/jesse_dance.mp4
--------------------------------------------------------------------------------
/example/working.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/working.mp4
--------------------------------------------------------------------------------
/example/yoga.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/yoga.mp4
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available # only used by the lazy-import setup commented out below
17 |
18 | # from .modeling_pllava_flow import PllavaFlowForConditionalGeneration
19 | from .modeling_pllava import PllavaForConditionalGeneration
20 | from .modeling_pllava_SF import PllavaSFForConditionalGeneration
21 | from .processing_pllava import PllavaProcessor
22 | from .configuration_pllava import PllavaConfig
23 |
24 | # _import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]}
25 |
26 | # try:
27 | # if not is_torch_available():
28 | # raise OptionalDependencyNotAvailable()
29 | # except OptionalDependencyNotAvailable:
30 | # pass
31 | # else:
32 | # _import_structure["modeling_pllava"] = [
33 | # "PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
34 | # "PllavaForConditionalGeneration",
35 | # "PllavaPreTrainedModel",
36 | # ]
37 | # _import_structure["processing_pllava"] = ["PllavaProcessor"]
38 |
39 |
40 | # if TYPE_CHECKING:
41 | # from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig
42 |
43 | # try:
44 | # if not is_torch_available():
45 | # raise OptionalDependencyNotAvailable()
46 | # except OptionalDependencyNotAvailable:
47 | # pass
48 | # else:
49 | # from .modeling_pllava import (
50 | # PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
51 | # PllavaForConditionalGeneration,
52 | # PllavaPreTrainedModel,
53 | # )
54 | # from .processing_pllava import PllavaProcessor
55 |
56 |
57 | # else:
58 | # import sys
59 |
60 | # sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
61 |
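Since the lazy-import machinery above is commented out, the package exports its classes eagerly. A minimal import sketch; the checkpoint directory mirrors the one used in scripts/eval.sh and is otherwise hypothetical:

from models.pllava import (
    PllavaConfig,
    PllavaForConditionalGeneration,
    PllavaProcessor,
)

config = PllavaConfig.from_pretrained("MODELS/pllava-7b")  # local checkpoint dir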
--------------------------------------------------------------------------------
/models/pllava/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/configuration_pllava.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/configuration_pllava.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/elastic_cache.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/elastic_cache.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/llama.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/llama.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/modeling_clip.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_clip.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/modeling_flash_attention_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_flash_attention_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/modeling_pllava.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_pllava.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/modeling_pllava_SF.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_pllava_SF.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/modeling_pllava_flow.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_pllava_flow.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/modify_llama.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modify_llama.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/processing_pllava.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/processing_pllava.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/__pycache__/v433_modeling_llama.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/v433_modeling_llama.cpython-310.pyc
--------------------------------------------------------------------------------
/models/pllava/configuration_pllava.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ Llava model configuration"""
15 |
16 | from transformers.configuration_utils import PretrainedConfig
17 | from transformers.utils import logging
18 | from transformers.models.auto import CONFIG_MAPPING
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
25 | }
26 |
27 |
28 | class PllavaConfig(PretrainedConfig):
29 | r"""
30 | This is the configuration class to store the configuration of a [`PllavaForConditionalGeneration`]. It is used to instantiate a
31 | Pllava model according to the specified arguments, defining the model architecture. Instantiating a configuration
32 | with the defaults will yield a similar configuration to that of Llava-9B.
33 |
34 | e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
35 |
36 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37 | documentation from [`PretrainedConfig`] for more information.
38 |
39 | Args:
40 | vision_config (`LlavaVisionConfig`, *optional*):
41 | Custom vision config or dict
42 | text_config (`Union[AutoConfig, dict]`, *optional*):
43 | The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
44 | ignore_index (`int`, *optional*, defaults to -100):
45 | The ignore index for the loss function.
46 | image_token_index (`int`, *optional*, defaults to 32000):
47 | The image token index to encode the image prompt.
48 | projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
49 | The activation function used by the multimodal projector.
50 | vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
51 | The feature selection strategy used to select the vision feature from the CLIP backbone.
52 | vision_feature_layer (`int`, *optional*, defaults to -2):
53 | The index of the layer to select the vision feature.
54 | vocab_size (`int`, *optional*, defaults to 32000):
55 | Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
56 | `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
57 |
58 | Example:
59 |
60 | ```python
61 | >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
62 |
63 | >>> # Initializing a CLIP-vision config
64 | >>> vision_config = CLIPVisionConfig()
65 |
66 | >>> # Initializing a Llama config
67 | >>> text_config = LlamaConfig()
68 |
69 | >>> # Initializing a Llava llava-1.5-7b style configuration
70 | >>> configuration = LlavaConfig(vision_config, text_config)
71 |
72 | >>> # Initializing a model from the llava-1.5-7b style configuration
73 | >>> model = LlavaForConditionalGeneration(configuration)
74 |
75 | >>> # Accessing the model configuration
76 | >>> configuration = model.config
77 | ```"""
78 |
79 | model_type = "llava"
80 | is_composition = False
81 |
82 | def __init__(
83 | self,
84 | vision_config=None,
85 | text_config=None,
86 | ignore_index=-100,
87 | image_token_index=32000,
88 | projector_hidden_act="gelu",
89 | vision_feature_select_strategy="default",
90 | vision_feature_layer=-2,
91 | vocab_size=32000,
92 | pooling_method='avg',
93 | pooling_shape=(8, 16, 16),
94 | frame_shape=(24, 24), # llava 1.5 pretrained frame shape
95 | num_frames=1, # number of input video frames; llava 1.5 was pretrained on single frames
96 | use_pooling=True,
97 | gradient_checkpointing=False,
98 | selected_layer=10,
99 | alpha=0.1,
100 | head=0,
101 | softmax=1.0,
102 | tau=1.0,
103 | cluster_ratio=1.0,
104 | temporal_segment_ratio=1.0,
105 | **kwargs,
106 | ):
107 | self.ignore_index = ignore_index
108 | self.image_token_index = image_token_index
109 | self.projector_hidden_act = projector_hidden_act
110 | self.vision_feature_select_strategy = vision_feature_select_strategy
111 | self.vision_feature_layer = vision_feature_layer
112 | self.vocab_size = vocab_size
113 | self.use_pooling = use_pooling
114 | self.gradient_checkpointing = gradient_checkpointing
115 | self.selected_layer = selected_layer
116 | self.alpha = alpha
117 | self.head = head
118 | self.softmax = softmax
119 | self.tau = tau
120 | self.cluster_ratio = cluster_ratio
121 | self.temporal_segment_ratio = temporal_segment_ratio
122 |
123 | self.vision_config = vision_config
124 |
125 | self.pooling_method = pooling_method # one of 'max', 'avg'
126 | self.pooling_shape = pooling_shape # (T, H, W) target shape after pooling
127 | self.frame_shape = frame_shape # (H, W) patch grid per frame
128 | self.num_frames = num_frames
129 | if isinstance(self.vision_config, dict):
130 | vision_config["model_type"] = (
131 | vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
132 | )
133 | self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
134 | elif vision_config is None:
135 | self.vision_config = CONFIG_MAPPING["clip_vision_model"](
136 | intermediate_size=4096,
137 | hidden_size=1024,
138 | patch_size=14,
139 | image_size=336,
140 | num_hidden_layers=24,
141 | num_attention_heads=16,
142 | vocab_size=32000,
143 | projection_dim=768,
144 | )
145 | self.vocab_size = self.vocab_size # no-op; vocab_size may be overridden below by text_config
146 |
147 | self.text_config = text_config
148 |
149 | if isinstance(self.text_config, dict):
150 | text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
151 | self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
152 | self.vocab_size = self.text_config.vocab_size
153 | self.text_config.gradient_checkpointing = self.gradient_checkpointing
154 |
155 | elif text_config is None:
156 | tmp_config = {"_attn_implementation":"flash_attention_2",
157 | "gradient_checkpointing": self.gradient_checkpointing}
158 | self.text_config = CONFIG_MAPPING["llama"](**tmp_config)
159 | self.text_config.gradient_checkpointing = self.gradient_checkpointing
160 | # self.text_config["_attn_implementation"] = "flash_attention_2" # xl: temporary hard-code
161 |
162 |
163 | super().__init__(**kwargs)
164 |
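A minimal sketch constructing PllavaConfig with the PruneVid-specific arguments; the values mirror those swept in scripts/eval.sh, and the vision/text configs fall back to the CLIP and Llama defaults built in __init__ above:

from models.pllava.configuration_pllava import PllavaConfig

config = PllavaConfig(
    pooling_shape=(16, 12, 12),   # --pooling_shape 16-12-12 in scripts/eval.sh
    selected_layer=10,            # PruneVid pruning hyper-parameters
    alpha=0.4,
    tau=0.8,
    temporal_segment_ratio=0.25,
    cluster_ratio=0.5,
)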
--------------------------------------------------------------------------------
/models/pllava/convert_pllava_weights_to_hf.py:
--------------------------------------------------------------------------------
1 | # Not yet
--------------------------------------------------------------------------------
/models/pllava/modeling_flash_attention_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import inspect
17 | import os
18 | from typing import Optional, Tuple
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal # the latter is used below when setting the deterministic flag
24 |
25 | if is_flash_attn_2_available():
26 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
27 | from flash_attn import flash_attn_func, flash_attn_varlen_func
28 |
29 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
30 |
31 |
32 | def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
33 | """
34 | Retrieves indexing data required to repad unpadded (ragged) tensors.
35 |
36 | Arguments:
37 | attention_mask (`torch.Tensor`):
38 | Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
39 |
40 | Return:
41 | indices (`torch.Tensor`):
42 | The indices of non-masked tokens from the flattened input sequence.
43 | cu_seqlens (`torch.Tensor`):
44 | The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
45 | max_seqlen_in_batch (`int`):
46 | Maximum sequence length in batch.
47 | """
48 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
49 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
50 | max_seqlen_in_batch = seqlens_in_batch.max().item()
51 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
52 | return (
53 | indices,
54 | cu_seqlens,
55 | max_seqlen_in_batch,
56 | )
57 |
58 |
59 | def _upad_input(
60 | query_layer: torch.Tensor,
61 | key_layer: torch.Tensor,
62 | value_layer: torch.Tensor,
63 | attention_mask: torch.Tensor,
64 | query_length: int,
65 | ):
66 | """
67 | Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
68 |
69 | This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
70 | tensors for query, key, value tensors.
71 |
72 | Arguments:
73 | query_layer (`torch.Tensor`):
74 | Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
75 | key_layer (`torch.Tensor`):
76 | Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
77 | value_layer (`torch.Tensor`):
78 | Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
79 | attention_mask (`torch.Tensor`):
80 | Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
81 | query_length (`int`):
82 | Target length.
83 |
84 | Return:
85 | query_layer (`torch.Tensor`):
86 | Query state without padding. Shape: (total_target_length, num_heads, head_dim).
87 | key_layer (`torch.Tensor`):
88 | Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
89 | value_layer (`torch.Tensor`):
90 | Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
91 | indices_q (`torch.Tensor`):
92 | The indices of non-masked tokens from the flattened input target sequence.
93 | (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
94 | The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
95 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
96 | Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
97 | """
98 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
99 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
100 |
101 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
102 | value_layer = index_first_axis(
103 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
104 | )
105 | if query_length == kv_seq_len:
106 | query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
107 | cu_seqlens_q = cu_seqlens_k
108 | max_seqlen_in_batch_q = max_seqlen_in_batch_k
109 | indices_q = indices_k
110 | elif query_length == 1:
111 | max_seqlen_in_batch_q = 1
112 | cu_seqlens_q = torch.arange(
113 | batch_size + 1, dtype=torch.int32, device=query_layer.device
114 | ) # There is a memcpy here, that is very bad.
115 | indices_q = cu_seqlens_q[:-1]
116 | query_layer = query_layer.squeeze(1)
117 | else:
118 | # The -q_len: slice assumes left padding.
119 | attention_mask = attention_mask[:, -query_length:]
120 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
121 |
122 | return (
123 | query_layer,
124 | key_layer,
125 | value_layer,
126 | indices_q,
127 | (cu_seqlens_q, cu_seqlens_k),
128 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
129 | )
130 |
131 |
132 | def _flash_attention_forward(
133 | query_states: torch.Tensor,
134 | key_states: torch.Tensor,
135 | value_states: torch.Tensor,
136 | attention_mask: torch.Tensor,
137 | query_length: int,
138 | is_causal: bool,
139 | dropout: float = 0.0,
140 | softmax_scale: Optional[float] = None,
141 | sliding_window: Optional[int] = None,
142 | use_top_left_mask: bool = False,
143 | softcap: Optional[float] = None,
144 | deterministic: bool = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1",
145 | ):
146 | """
147 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
148 | it first unpads the input, then computes the attention scores and re-pads the final attention scores.
149 |
150 | Args:
151 | query_states (`torch.Tensor`):
152 | Input query states to be passed to Flash Attention API
153 | key_states (`torch.Tensor`):
154 | Input key states to be passed to Flash Attention API
155 | value_states (`torch.Tensor`):
156 | Input value states to be passed to Flash Attention API
157 | attention_mask (`torch.Tensor`):
158 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
159 | position of padding tokens and 1 for the position of non-padding tokens.
160 | dropout (`float`):
161 | Attention dropout
162 | softmax_scale (`float`, *optional*):
163 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
164 | use_top_left_mask (`bool`, defaults to `False`):
165 | flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
166 | softcap (`float`, *optional*):
167 | Softcap for the attention logits, used e.g. in gemma2.
168 | deterministic (`bool`, *optional*):
169 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
170 | """
171 | if not use_top_left_mask:
172 | causal = is_causal
173 | else:
174 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
175 | causal = is_causal and query_length != 1
176 |
177 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
178 | use_sliding_windows = (
179 | _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
180 | )
181 | flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
182 |
183 | if is_flash_attn_greater_or_equal("2.4.1"):
184 | flash_kwargs["deterministic"] = deterministic
185 |
186 | if softcap is not None:
187 | flash_kwargs["softcap"] = softcap
188 |
189 | # Contains at least one padding token in the sequence
190 | if attention_mask is not None:
191 | batch_size = query_states.shape[0]
192 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
193 | query_states, key_states, value_states, attention_mask, query_length
194 | )
195 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
196 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
197 |
198 | attn_output_unpad = flash_attn_varlen_func(
199 | query_states,
200 | key_states,
201 | value_states,
202 | cu_seqlens_q=cu_seqlens_q,
203 | cu_seqlens_k=cu_seqlens_k,
204 | max_seqlen_q=max_seqlen_in_batch_q,
205 | max_seqlen_k=max_seqlen_in_batch_k,
206 | dropout_p=dropout,
207 | softmax_scale=softmax_scale,
208 | causal=causal,
209 | **flash_kwargs,
210 | )
211 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
212 | else:
213 | attn_output = flash_attn_func(
214 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
215 | )
216 |
217 | return attn_output
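A toy example of what _get_unpad_data computes, which makes the varlen path above easier to read; the mask is illustrative:

import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]], dtype=torch.int32)  # two sequences, lengths 3 and 1
indices, cu_seqlens, max_len = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 2, 4])            valid-token positions in the flattened batch
# cu_seqlens -> tensor([0, 3, 4], torch.int32)  cumulative lengths, shape (batch_size + 1,)
# max_len    -> 3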
--------------------------------------------------------------------------------
/requirements.no_torch.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.26.1
3 | addict==2.4.0
4 | aiofiles==23.2.1
5 | aliyun-python-sdk-core==2.15.0
6 | aliyun-python-sdk-kms==2.16.2
7 | altair==5.2.0
8 | annotated-types==0.6.0
9 | antlr4-python3-runtime==4.9.3
10 | anyio==4.3.0
11 | anykeystore==0.2
12 | apex==0.9.10.dev0
13 | appdirs==1.4.4
14 | argcomplete==3.2.3
15 | attrs==23.2.0
16 | av==10.0.0
17 | beautifulsoup4==4.12.3
18 | blessed==1.20.0
19 | blessings==1.7
20 | boto3==1.34.63
21 | botocore==1.34.63
22 | Brotli==1.1.0
23 | cachetools==5.3.3
24 | certifi==2024.2.2
25 | cffi==1.16.0
26 | charset-normalizer==3.3.2
27 | click==8.1.7
28 | colorama==0.4.6
29 | contourpy==1.2.0
30 | crcmod==1.7
31 | cryptacular==1.6.2
32 | cryptography==42.0.5
33 | cycler==0.12.1
34 | dacite==1.7.0
35 | decorator==4.4.2
36 | decord==0.6.0
37 | deepspeed==0.14.0
38 | defusedxml==0.7.1
39 | Deprecated==1.2.14
40 | dill==0.3.8
41 | distro==1.9.0
42 | dnspython==2.6.1
43 | docker-pycreds==0.4.0
44 | einops==0.6.1
45 | exceptiongroup==1.2.0
46 | fastapi==0.110.0
47 | ffmpeg==1.4
48 | ffmpy==0.3.2
49 | fiftyone==0.23.6
50 | fiftyone-brain==0.16.1
51 | fiftyone_db==1.1.2
52 | filelock==3.9.0
53 | flash-attn==2.5.6
54 | fonttools==4.49.0
55 | fsspec==2024.2.0
56 | ftfy==6.1.3
57 | future==1.0.0
58 | fvcore==0.1.5.post20221221
59 | gdown==5.1.0
60 | gitdb==4.0.11
61 | GitPython==3.1.42
62 | glob2==0.7
63 | google-auth==2.28.2
64 | google-auth-oauthlib==1.2.0
65 | gpustat==1.1.1
66 | gradio==4.21.0
67 | gradio_client==0.12.0
68 | graphql-core==3.2.3
69 | greenlet==3.0.3
70 | grpcio==1.62.1
71 | h11==0.14.0
72 | h2==4.1.0
73 | hjson==3.1.0
74 | hpack==4.0.0
75 | httpcore==1.0.4
76 | httpx==0.27.0
77 | huggingface-hub==0.21.4
78 | humanize==4.9.0
79 | hupper==1.12.1
80 | Hypercorn==0.16.0
81 | hyperframe==6.0.1
82 | idna==3.6
83 | idscheck==2.3.0
84 | imageio==2.27.0
85 | imageio-ffmpeg==0.4.9
86 | importlib_metadata==7.0.2
87 | importlib_resources==6.3.0
88 | inflate64==1.0.0
89 | iopath==0.1.10
90 | Jinja2==3.1.2
91 | jmespath==0.10.0
92 | joblib==1.3.2
93 | jsonlines==4.0.0
94 | jsonschema==4.21.1
95 | jsonschema-specifications==2023.12.1
96 | kaleido==0.2.1
97 | kiwisolver==1.4.5
98 | lazy_loader==0.3
99 | Markdown==3.6
100 | markdown-it-py==3.0.0
101 | MarkupSafe==2.1.3
102 | matplotlib==3.8.3
103 | mdurl==0.1.2
104 | mmcv-full==1.7.2
105 | model-index==0.1.11
106 | mongoengine==0.24.2
107 | motor==3.3.2
108 | moviepy==1.0.3
109 | mpmath==1.3.0
110 | multivolumefile==0.2.3
111 | networkx==3.2.1
112 | ninja==1.11.1.1
113 | numpy
114 | oauthlib==3.2.2
115 | omegaconf==2.3.0
116 | openai==1.14.0
117 | opencv-python==4.9.0.80
118 | opencv-python-headless==4.9.0.80
119 | opendatalab==0.0.10
120 | openmim==0.3.9
121 | openxlab==0.0.36
122 | ordered-set==4.1.0
123 | orjson==3.9.15
124 | oss2==2.17.0
125 | packaging==24.0
126 | pandas==1.5.3
127 | PasteDeploy==3.1.0
128 | pathtools==0.1.2
129 | pbkdf2==1.3
130 | peft==0.10.0
131 | pillow==10.2.0
132 | plaster==1.1.2
133 | plaster-pastedeploy==1.0.1
134 | platformdirs==4.2.0
135 | plotly==5.20.0
136 | portalocker==2.8.2
137 | pprintpp==0.4.0
138 | priority==2.0.0
139 | proglog==0.1.10
140 | protobuf==4.23.4
141 | psutil==5.9.4
142 | py-cpuinfo==9.0.0
143 | py7zr==0.21.0
144 | pyasn1==0.5.1
145 | pyasn1-modules==0.3.0
146 | pybcj==1.0.2
147 | pycparser==2.21
148 | pycryptodome==3.20.0
149 | pycryptodomex==3.20.0
150 | pydantic==2.6.4
151 | pydantic_core==2.16.3
152 | pydub==0.25.1
153 | Pygments==2.17.2
154 | pymongo==4.6.2
155 | pynvml==11.5.0
156 | pyparsing==3.1.2
157 | pyppmd==1.1.0
158 | pyramid==2.0.2
159 | pyramid-mailer==0.15.1
160 | PySocks==1.7.1
161 | python-dateutil==2.9.0.post0
162 | python-multipart==0.0.9
163 | python3-openid==3.2.0
164 | pytz==2023.4
165 | PyYAML==6.0
166 | pyzstd==0.15.9
167 | rarfile==4.1
168 | referencing==0.33.0
169 | regex==2023.12.25
170 | repoze.sendmail==4.4.1
171 | requests==2.28.2
172 | requests-oauthlib==1.4.0
173 | retrying==1.3.4
174 | rich==13.4.2
175 | rpds-py==0.18.0
176 | rsa==4.9
177 | ruff==0.3.2
178 | s3transfer==0.10.1
179 | safetensors==0.4.2
180 | scikit-image==0.22.0
181 | scikit-learn==1.4.1.post1
182 | scipy==1.10.1
183 | semantic-version==2.10.0
184 | sentencepiece==0.2.0
185 | sentry-sdk==1.42.0
186 | setproctitle==1.3.3
187 | shellingham==1.5.4
188 | six==1.16.0
189 | smmap==5.0.1
190 | sniffio==1.3.1
191 | sortedcontainers==2.4.0
192 | soupsieve==2.5
193 | SQLAlchemy==2.0.28
194 | sse-starlette==0.10.3
195 | sseclient-py==1.8.0
196 | starlette==0.36.3
197 | strawberry-graphql==0.138.1
198 | sympy==1.12
199 | tabulate==0.9.0
200 | taskgroup==0.0.0a4
201 | tenacity==8.2.3
202 | tensorboard==2.15.1
203 | tensorboard-data-server==0.7.2
204 | tensorboardX==2.6.2.2
205 | termcolor==2.3.0
206 | texttable==1.7.0
207 | threadpoolctl==3.3.0
208 | tifffile==2024.2.12
209 | timm==0.6.12
210 | tokenizers==0.15.2
211 | tomli==2.0.1
212 | tomlkit==0.12.0
213 | toolz==0.12.1
214 | tqdm==4.65.2
215 | transaction==4.0
216 | transformers==4.37.1
217 | translationstring==1.4
218 | triton==2.2.0
219 | typer==0.9.0
220 | typing_extensions==4.8.0
221 | tzdata==2024.1
222 | tzlocal==5.2
223 | universal-analytics-python3==1.1.1
224 | urllib3==1.26.18
225 | uvicorn==0.28.0
226 | velruse==1.1.1
227 | venusian==3.1.0
228 | voxel51-eta==0.12.6
229 | wandb==0.14.0
230 | wcwidth==0.2.13
231 | WebOb==1.8.7
232 | websockets==11.0.3
233 | Werkzeug==3.0.1
234 | wrapt==1.16.0
235 | wsproto==1.2.0
236 | WTForms==3.1.2
237 | wtforms-recaptcha==0.3.2
238 | xmltodict==0.13.0
239 | yacs==0.1.8
240 | yapf==0.40.2
241 | zipp==3.18.1
242 | zope.deprecation==5.0
243 | zope.interface==6.2
244 | zope.sqlalchemy==3.1
245 |
--------------------------------------------------------------------------------
/requirements.torch.txt:
--------------------------------------------------------------------------------
1 | --index-url https://download.pytorch.org/whl/cu118
2 | torch==2.2.1
3 | torchaudio==2.2.1
4 | torchvision==0.17.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.26.1
3 | addict==2.4.0
4 | aiofiles==23.2.1
5 | aliyun-python-sdk-core==2.15.0
6 | aliyun-python-sdk-kms==2.16.2
7 | altair==5.2.0
8 | annotated-types==0.6.0
9 | antlr4-python3-runtime==4.9.3
10 | anyio==4.3.0
11 | anykeystore==0.2
12 | apex==0.9.10.dev0
13 | appdirs==1.4.4
14 | argcomplete==3.2.3
15 | attrs==23.2.0
16 | av==10.0.0
17 | beautifulsoup4==4.12.3
18 | blessed==1.20.0
19 | blessings==1.7
20 | boto3==1.34.63
21 | botocore==1.34.63
22 | Brotli==1.1.0
23 | cachetools==5.3.3
24 | certifi==2024.2.2
25 | cffi==1.16.0
26 | charset-normalizer==3.3.2
27 | click==8.1.7
28 | colorama==0.4.6
29 | contourpy==1.2.0
30 | crcmod==1.7
31 | cryptacular==1.6.2
32 | cryptography==42.0.5
33 | cycler==0.12.1
34 | dacite==1.7.0
35 | decorator==4.4.2
36 | decord==0.6.0
37 | deepspeed==0.14.0
38 | defusedxml==0.7.1
39 | Deprecated==1.2.14
40 | dill==0.3.8
41 | distro==1.9.0
42 | dnspython==2.6.1
43 | docker-pycreds==0.4.0
44 | einops==0.6.1
45 | exceptiongroup==1.2.0
46 | fastapi==0.110.0
47 | ffmpeg==1.4
48 | ffmpy==0.3.2
49 | fiftyone==0.23.6
50 | fiftyone-brain==0.16.1
51 | fiftyone_db==1.1.2
52 | filelock==3.9.0
53 | https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
54 | fonttools==4.49.0
55 | fsspec==2024.2.0
56 | ftfy==6.1.3
57 | future==1.0.0
58 | fvcore==0.1.5.post20221221
59 | gdown==5.1.0
60 | gitdb==4.0.11
61 | GitPython==3.1.42
62 | glob2==0.7
63 | google-auth==2.28.2
64 | google-auth-oauthlib==1.2.0
65 | gpustat==1.1.1
66 | gradio==4.21.0
67 | gradio_client==0.12.0
68 | graphql-core==3.2.3
69 | greenlet==3.0.3
70 | grpcio==1.62.1
71 | h11==0.14.0
72 | h2==4.1.0
73 | hjson==3.1.0
74 | hpack==4.0.0
75 | httpcore==1.0.4
76 | httpx==0.27.0
77 | huggingface-hub==0.21.4
78 | humanize==4.9.0
79 | hupper==1.12.1
80 | Hypercorn==0.16.0
81 | hyperframe==6.0.1
82 | idna==3.6
83 | idscheck==2.3.0
84 | imageio==2.27.0
85 | imageio-ffmpeg==0.4.9
86 | importlib_metadata==7.0.2
87 | importlib_resources==6.3.0
88 | inflate64==1.0.0
89 | iopath==0.1.10
90 | Jinja2==3.1.2
91 | jmespath==0.10.0
92 | joblib==1.3.2
93 | jsonlines==4.0.0
94 | jsonschema==4.21.1
95 | jsonschema-specifications==2023.12.1
96 | kaleido==0.2.1
97 | kiwisolver==1.4.5
98 | lazy_loader==0.3
99 | Markdown==3.6
100 | markdown-it-py==3.0.0
101 | MarkupSafe==2.1.3
102 | matplotlib==3.8.3
103 | mdurl==0.1.2
104 | mmcv-full==1.7.2
105 | model-index==0.1.11
106 | mongoengine==0.24.2
107 | motor==3.3.2
108 | moviepy==1.0.3
109 | mpmath==1.3.0
110 | multivolumefile==0.2.3
111 | networkx==3.2.1
112 | ninja==1.11.1.1
113 | numpy==1.23.5
114 | nvidia-cublas-cu11==11.11.3.6
115 | nvidia-cuda-cupti-cu11==11.8.87
116 | nvidia-cuda-nvrtc-cu11==11.8.89
117 | nvidia-cuda-runtime-cu11==11.8.89
118 | nvidia-cudnn-cu11==8.7.0.84
119 | nvidia-cufft-cu11==10.9.0.58
120 | nvidia-curand-cu11==10.3.0.86
121 | nvidia-cusolver-cu11==11.4.1.48
122 | nvidia-cusparse-cu11==11.7.5.86
123 | nvidia-ml-py==12.535.133
124 | nvidia-ml-py3==7.352.0
125 | nvidia-nccl-cu11==2.19.3
126 | nvidia-nvtx-cu11==11.8.86
127 | oauthlib==3.2.2
128 | omegaconf==2.3.0
129 | openai==1.14.0
130 | opencv-python==4.9.0.80
131 | opencv-python-headless==4.9.0.80
132 | opendatalab==0.0.10
133 | openmim==0.3.9
134 | openxlab==0.0.36
135 | ordered-set==4.1.0
136 | orjson==3.9.15
137 | oss2==2.17.0
138 | packaging==24.0
139 | pandas==1.5.3
140 | PasteDeploy==3.1.0
141 | pathtools==0.1.2
142 | pbkdf2==1.3
143 | peft==0.10.0
144 | pillow==10.2.0
145 | plaster==1.1.2
146 | plaster-pastedeploy==1.0.1
147 | platformdirs==4.2.0
148 | plotly==5.20.0
149 | portalocker==2.8.2
150 | pprintpp==0.4.0
151 | priority==2.0.0
152 | proglog==0.1.10
153 | protobuf==4.23.4
154 | psutil==5.9.4
155 | py-cpuinfo==9.0.0
156 | py7zr==0.21.0
157 | pyasn1==0.5.1
158 | pyasn1-modules==0.3.0
159 | pybcj==1.0.2
160 | pycparser==2.21
161 | pycryptodome==3.20.0
162 | pycryptodomex==3.20.0
163 | pydantic==2.6.4
164 | pydantic_core==2.16.3
165 | pydub==0.25.1
166 | Pygments==2.17.2
167 | pymongo==4.6.2
168 | pynvml==11.5.0
169 | pyparsing==3.1.2
170 | pyppmd==1.1.0
171 | pyramid==2.0.2
172 | pyramid-mailer==0.15.1
173 | PySocks==1.7.1
174 | python-dateutil==2.9.0.post0
175 | python-multipart==0.0.9
176 | python3-openid==3.2.0
177 | pytz==2023.4
178 | PyYAML==6.0
179 | pyzstd==0.15.9
180 | rarfile==4.1
181 | referencing==0.33.0
182 | regex==2023.12.25
183 | repoze.sendmail==4.4.1
184 | requests==2.28.2
185 | requests-oauthlib==1.4.0
186 | retrying==1.3.4
187 | rich==13.4.2
188 | rpds-py==0.18.0
189 | rsa==4.9
190 | ruff==0.3.2
191 | s3transfer==0.10.1
192 | safetensors==0.4.2
193 | scikit-image==0.22.0
194 | scikit-learn==1.4.1.post1
195 | scipy==1.10.1
196 | semantic-version==2.10.0
197 | sentencepiece==0.2.0
198 | sentry-sdk==1.42.0
199 | setproctitle==1.3.3
200 | shellingham==1.5.4
201 | six==1.16.0
202 | smmap==5.0.1
203 | sniffio==1.3.1
204 | sortedcontainers==2.4.0
205 | soupsieve==2.5
206 | SQLAlchemy==2.0.28
207 | sse-starlette==0.10.3
208 | sseclient-py==1.8.0
209 | starlette==0.36.3
210 | strawberry-graphql==0.138.1
211 | sympy==1.12
212 | tabulate==0.9.0
213 | taskgroup==0.0.0a4
214 | tenacity==8.2.3
215 | tensorboard==2.15.1
216 | tensorboard-data-server==0.7.2
217 | tensorboardX==2.6.2.2
218 | termcolor==2.3.0
219 | texttable==1.7.0
220 | threadpoolctl==3.3.0
221 | tifffile==2024.2.12
222 | timm==0.6.12
223 | tokenizers==0.15.2
224 | tomli==2.0.1
225 | tomlkit==0.12.0
226 | toolz==0.12.1
227 | torch==2.2.1
228 | torchaudio==2.2.1
229 | torchvision==0.17.1
230 | tqdm==4.65.2
231 | transaction==4.0
232 | transformers==4.37.1
233 | translationstring==1.4
234 | triton==2.2.0
235 | typer==0.9.0
236 | typing_extensions==4.8.0
237 | tzdata==2024.1
238 | tzlocal==5.2
239 | universal-analytics-python3==1.1.1
240 | urllib3==1.26.18
241 | uvicorn==0.28.0
242 | velruse==1.1.1
243 | venusian==3.1.0
244 | voxel51-eta==0.12.6
245 | wandb==0.14.0
246 | wcwidth==0.2.13
247 | WebOb==1.8.7
248 | websockets==11.0.3
249 | Werkzeug==3.0.1
250 | wrapt==1.16.0
251 | wsproto==1.2.0
252 | WTForms==3.1.2
253 | wtforms-recaptcha==0.3.2
254 | xmltodict==0.13.0
255 | yacs==0.1.8
256 | yapf==0.40.2
257 | zipp==3.18.1
258 | zope.deprecation==5.0
259 | zope.interface==6.2
260 | zope.sqlalchemy==3.1
261 |
--------------------------------------------------------------------------------
/scripts/accel_config_deepspeed_zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 8
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 4
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
22 |
--------------------------------------------------------------------------------
/scripts/accel_config_deepspeed_zero3_offload.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 2
5 | offload_optimizer_device: cpu
6 | offload_param_device: cpu
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 2
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | machine_rank: 0
14 | main_process_ip: fdbd:dc61:18:8::20
15 | main_process_port: 6876
16 | main_training_function: main
17 | mixed_precision: bf16
18 | num_machines: 2
19 | num_processes: 16
20 | rdzv_backend: static
21 | same_network: true
22 | tpu_env: []
23 | tpu_use_cluster: false
24 | tpu_use_sudo: false
25 | use_cpu: false
26 |
--------------------------------------------------------------------------------
/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 2
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | machine_rank: 0
14 | main_process_ip: fdbd:dc61:18:8::20
15 | main_process_port: 6876
16 | main_training_function: main
17 | mixed_precision: bf16
18 | num_machines: 2
19 | num_processes: 16
20 | rdzv_backend: static
21 | same_network: true
22 | tpu_env: []
23 | tpu_use_cluster: false
24 | tpu_use_sudo: false
25 | use_cpu: false
26 |
--------------------------------------------------------------------------------
/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 2
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | machine_rank: 1
14 | main_process_ip: fdbd:dc61:18:8::20
15 | main_process_port: 6876
16 | main_training_function: main
17 | mixed_precision: bf16
18 | num_machines: 2
19 | num_processes: 16
20 | rdzv_backend: static
21 | same_network: true
22 | tpu_env: []
23 | tpu_use_cluster: false
24 | tpu_use_sudo: false
25 | use_cpu: false
26 |
--------------------------------------------------------------------------------
/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 16
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | machine_rank: 0
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 1
17 | num_processes: 1
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
24 |
--------------------------------------------------------------------------------
/scripts/accel_config_multigpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: 0,1,2,3,4,5,6,7
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/scripts/accel_config_multinode.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 1
7 | main_process_ip: 10.193.16.150
8 | main_process_port: 6784
9 | main_training_function: main
10 | mixed_precision: bf16
11 | num_machines: 2
12 | num_processes: 16
13 | rdzv_backend: static
14 | same_network: true
15 | tpu_env: []
16 | tpu_use_cluster: false
17 | tpu_use_sudo: false
18 | use_cpu: false
19 |
--------------------------------------------------------------------------------
/scripts/accel_config_singlegpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: 'NO'
4 | downcast_bf16: 'no'
5 | gpu_ids: '0'
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 1
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
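Taken together, the accelerate configs above differ mainly in num_processes and gradient_accumulation_steps. As a sanity check, here is a minimal sketch (not part of the repo) of how those knobs combine into an effective global batch size; the per-device batch size itself comes from the training configs further down (e.g. batch_size in tasks/train/config_pllava_nframe.py), and the values below are illustrative.

def effective_global_batch(per_device_batch: int,
                           num_processes: int,
                           gradient_accumulation_steps: int) -> int:
    """Samples consumed per optimizer step, summed across all ranks."""
    return per_device_batch * num_processes * gradient_accumulation_steps

# accel_config_deepspeed_zero3_offload_multinode_2.yaml: 16 processes, accum 2
print(effective_global_batch(1, 16, 2))   # -> 32
# accel_config_deepspeed_zero3_offload_singlegpu.yaml: 1 process, accum 16
print(effective_global_batch(1, 1, 16))   # -> 16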
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | export OPENAI_API_KEY=YOUR_API_KEY
2 | num_frames=16
3 | test_ratio=1
4 |
5 | model_dir=MODELS/pllava-7b
6 | weight_dir=MODELS/pllava-7b
7 |
8 | lora_alpha=14
9 | selected_layers=(10)
10 | alphas=(0.4)
11 | taus=(0.8)
12 | temporal_segment_ratios=(0.25)
13 | cluster_ratios=(0.5)
14 |
15 | for alpha in "${alphas[@]}"; do
16 | for selected_layer in "${selected_layers[@]}"; do
17 | for tau in "${taus[@]}"; do
18 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do
19 | for cluster_ratio in "${cluster_ratios[@]}"; do
20 | # run the evaluation command
21 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio}
22 | mkdir -p "${SAVE_DIR}"
23 | conv_mode=eval_mvbench
24 | python -m tasks.eval.mvbench.pllava_eval_mvbench \
25 | --pretrained_model_name_or_path ${model_dir} \
26 | --save_path ${SAVE_DIR}/mvbench \
27 | --num_frames ${num_frames} \
28 | --use_lora \
29 | --lora_alpha ${lora_alpha} \
30 | --top_p 1.0 \
31 | --temperature 1.0 \
32 | --weight_dir ${weight_dir} \
33 | --pooling_shape 16-12-12 \
34 | --conv_mode ${conv_mode} \
35 | --selected_layer ${selected_layer} \
36 | --alpha ${alpha} \
37 | --tau ${tau} \
38 | --temporal_segment_ratio ${temporal_segment_ratio} \
39 | --cluster_ratio ${cluster_ratio}
40 | done
41 | done
42 | done
43 | done
44 | done
45 |
46 | lora_alpha=14
47 | selected_layers=(10)
48 | alphas=(0.4)
49 | taus=(0.8)
50 | temporal_segment_ratios=(0.25)
51 | cluster_ratios=(0.5)
52 |
53 | for alpha in "${alphas[@]}"; do
54 | for selected_layer in "${selected_layers[@]}"; do
55 | for tau in "${taus[@]}"; do
56 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do
57 | for cluster_ratio in "${cluster_ratios[@]}"; do
58 | # run the evaluation command
59 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio}
60 | mkdir -p "${SAVE_DIR}"
61 | conv_mode=eval_videomme
62 | python -m tasks.eval.videomme.pllava_eval_videomme \
63 | --pretrained_model_name_or_path ${model_dir} \
64 | --save_path ${SAVE_DIR}/videomme \
65 | --num_frames ${num_frames} \
66 | --use_lora \
67 | --lora_alpha ${lora_alpha} \
68 | --top_p 1.0 \
69 | --temperature 1.0 \
70 | --weight_dir ${weight_dir} \
71 | --pooling_shape 16-12-12 \
72 | --conv_mode ${conv_mode} \
73 | --selected_layer ${selected_layer} \
74 | --alpha ${alpha} \
75 | --tau ${tau} \
76 | --temporal_segment_ratio ${temporal_segment_ratio} \
77 | --cluster_ratio ${cluster_ratio}
78 | done
79 | done
80 | done
81 | done
82 | done
83 |
84 | lora_alpha=14
85 | selected_layers=(10)
86 | alphas=(0.4)
87 | taus=(0.8)
88 | temporal_segment_ratios=(0.25)
89 | cluster_ratios=(0.5)
90 |
91 | for alpha in "${alphas[@]}"; do
92 | for selected_layer in "${selected_layers[@]}"; do
93 | for tau in "${taus[@]}"; do
94 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do
95 | for cluster_ratio in "${cluster_ratios[@]}"; do
96 | # run the evaluation command
97 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio}
98 | mkdir -p "${SAVE_DIR}"
99 | conv_mode=eval_mvbench
100 | python -m tasks.eval.egoshcema.pllava_eval_egoschema \
101 | --pretrained_model_name_or_path ${model_dir} \
102 | --save_path ${SAVE_DIR}/egoschema \
103 | --num_frames ${num_frames} \
104 | --use_lora \
105 | --lora_alpha ${lora_alpha} \
106 | --top_p 1.0 \
107 | --temperature 1.0 \
108 | --weight_dir ${weight_dir} \
109 | --pooling_shape 16-12-12 \
110 | --conv_mode ${conv_mode} \
111 | --selected_layer ${selected_layer} \
112 | --alpha ${alpha} \
113 | --tau ${tau} \
114 | --temporal_segment_ratio ${temporal_segment_ratio} \
115 | --cluster_ratio ${cluster_ratio}
116 | done
117 | done
118 | done
119 | done
120 | done
121 |
122 |
123 | lora_alpha=4
124 | selected_layers=(5)
125 | alphas=(0.4)
126 | taus=(0.8)
127 | temporal_segment_ratios=(0.25)
128 | cluster_ratios=(0.5)
129 |
130 | for alpha in "${alphas[@]}"; do
131 | for selected_layer in "${selected_layers[@]}"; do
132 | for tau in "${taus[@]}"; do
133 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do
134 | for cluster_ratio in "${cluster_ratios[@]}"; do
135 | # run the evaluation command
136 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio}
137 | mkdir -p "${SAVE_DIR}"
138 | conv_mode=eval_vcgbench
139 | python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
140 | --pretrained_model_name_or_path ${model_dir} \
141 | --save_path ${SAVE_DIR}/vcgbench \
142 | --num_frames ${num_frames} \
143 | --weight_dir ${weight_dir} \
144 | --pooling_shape 16-12-12 \
145 | --test_ratio ${test_ratio} \
146 | --use_lora \
147 | --lora_alpha ${lora_alpha} \
148 | --selected_layer ${selected_layer} \
149 | --alpha ${alpha} \
150 | --tau ${tau} \
151 | --temporal_segment_ratio ${temporal_segment_ratio} \
152 | --cluster_ratio ${cluster_ratio}
153 | done
154 | done
155 | done
156 | done
157 | done
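The four loop blocks above sweep one and the same hyperparameter grid (alpha, selected layer, tau, temporal segment ratio, cluster ratio) over MVBench, Video-MME, EgoSchema, and VCGBench. Purely to make that grid explicit, here is a hedged Python sketch of the first block; flag names and the SAVE_DIR pattern mirror eval.sh, and the model/weight paths are the same placeholders used above.

import itertools
import os
import subprocess

lora_alpha = 14
grid = dict(alpha=[0.4], selected_layer=[10], tau=[0.8],
            temporal_segment_ratio=[0.25], cluster_ratio=[0.5])

for alpha, layer, tau, seg, clu in itertools.product(*grid.values()):
    save_dir = (f"test_results/pllava-7b-lora{lora_alpha}-threshold{tau}"
                f"-layer{layer}-alpha{alpha}-temporal-segment-ratio-{seg}"
                f"-cluster-ratio-{clu}")
    os.makedirs(save_dir, exist_ok=True)          # mkdir -p "${SAVE_DIR}"
    subprocess.run(
        ["python", "-m", "tasks.eval.mvbench.pllava_eval_mvbench",
         "--pretrained_model_name_or_path", "MODELS/pllava-7b",
         "--save_path", f"{save_dir}/mvbench",
         "--num_frames", "16",
         "--use_lora", "--lora_alpha", str(lora_alpha),
         "--top_p", "1.0", "--temperature", "1.0",
         "--weight_dir", "MODELS/pllava-7b",
         "--pooling_shape", "16-12-12",
         "--conv_mode", "eval_mvbench",
         "--selected_layer", str(layer), "--alpha", str(alpha),
         "--tau", str(tau),
         "--temporal_segment_ratio", str(seg),
         "--cluster_ratio", str(clu)],
        check=True)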
--------------------------------------------------------------------------------
/tasks/eval/__pycache__/eval_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/__pycache__/eval_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/__pycache__/eval_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/__pycache__/eval_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/tasks/eval/__pycache__/model_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/__pycache__/model_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/demo/__init__.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from gradio.themes.utils import colors, fonts, sizes
3 |
4 |
5 | pllava_theme = gr.themes.Monochrome(
6 | text_size="sm",
7 | spacing_size="sm",
8 | primary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"),
9 | secondary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"),
10 | neutral_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"),
11 | ).set(
12 | background_fill_primary_dark='*primary_950',
13 | background_fill_secondary_dark='*neutral_950'
14 | )
15 |
16 |
--------------------------------------------------------------------------------
/tasks/eval/demo/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/demo/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/demo/__pycache__/pllava_demo.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/demo/__pycache__/pllava_demo.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/demo/pllava_demo.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import copy
3 | import gradio as gr
4 | from gradio.themes.utils import colors, fonts, sizes
5 |
6 | from utils.easydict import EasyDict
7 | from tasks.eval.model_utils import load_pllava
8 | from tasks.eval.eval_utils import (
9 | ChatPllava,
10 | conv_plain_v1,
11 | Conversation,
12 | conv_templates
13 | )
14 | from tasks.eval.demo import pllava_theme
15 |
16 | SYSTEM="""You are Pllava, a large vision-language assistant.
17 | You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
18 | Follow the instructions carefully and explain your answers in detail based on the provided video.
19 | """
20 | INIT_CONVERSATION: Conversation = conv_plain_v1.copy()
21 |
22 |
23 | # ========================================
24 | # Model Initialization
25 | # ========================================
26 | def init_model(args):
27 |
28 | print('Initializing PLLaVA')
29 | model, processor = load_pllava(
30 | args.pretrained_model_name_or_path, args.num_frames,
31 | use_lora=args.use_lora,
32 | weight_dir=args.weight_dir,
33 | lora_alpha=args.lora_alpha,
34 | use_multi_gpus=args.use_multi_gpus)
35 | if not args.use_multi_gpus:
36 | model = model.to('cuda')
37 | chat = ChatPllava(model, processor)
38 | return chat
39 |
40 |
41 | # ========================================
42 | # Gradio Setting
43 | # ========================================
44 | def gradio_reset(chat_state, img_list):
45 | if chat_state is not None:
46 | chat_state = INIT_CONVERSATION.copy()
47 | if img_list is not None:
48 | img_list = []
49 | return (
50 | None,
51 | gr.update(value=None, interactive=True),
52 | gr.update(value=None, interactive=True),
53 | gr.update(placeholder='Please upload your video first', interactive=False),
54 | gr.update(value="Upload & Start Chat", interactive=True),
55 | chat_state,
56 | img_list
57 | )
58 |
59 |
60 | def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None):
61 | print(gr_img, gr_video)
62 | chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state
63 | img_list = [] if img_list is None else img_list
64 |
65 | if gr_img is None and gr_video is None:
66 | return None, None, gr.update(interactive=True),gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None
67 | if gr_video:
68 | llm_message, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments)
69 | return (
70 | gr.update(interactive=True),
71 | gr.update(interactive=True),
72 | gr.update(interactive=True, placeholder='Type and press Enter'),
73 | gr.update(value="Start Chatting", interactive=False),
74 | chat_state,
75 | img_list,
76 | )
77 | if gr_img:
78 | llm_message, img_list,chat_state = chat.upload_img(gr_img, chat_state, img_list)
79 | return (
80 | gr.update(interactive=True),
81 | gr.update(interactive=True),
82 | gr.update(interactive=True, placeholder='Type and press Enter'),
83 | gr.update(value="Start Chatting", interactive=False),
84 | chat_state,
85 | img_list
86 | )
87 |
88 |
89 | def gradio_ask(user_message, chatbot, chat_state, system):
90 | if len(user_message) == 0:
91 | return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
92 | chat_state = chat.ask(user_message, chat_state, system)
93 | chatbot = chatbot + [[user_message, None]]
94 | return '', chatbot, chat_state
95 |
96 |
97 | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
98 | llm_message, llm_message_token, chat_state = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=200, num_beams=num_beams, temperature=temperature)
99 | llm_message = llm_message.replace("<s>", "") # handle stray <s> tokens in the decoded output
100 | chatbot[-1][1] = llm_message
101 | print(chat_state)
102 | print(f"Answer: {llm_message}")
103 | return chatbot, chat_state, img_list
104 |
105 |
106 | def parse_args():
107 | parser = ArgumentParser()
108 | parser.add_argument(
109 | "--pretrained_model_name_or_path",
110 | type=str,
111 | required=True,
112 | default='llava-hf/llava-1.5-7b-hf'
113 | )
114 | parser.add_argument(
115 | "--num_frames",
116 | type=int,
117 | required=True,
118 | default=4,
119 | )
120 | parser.add_argument(
121 | "--use_lora",
122 | action='store_true'
123 | )
124 | parser.add_argument(
125 | "--use_multi_gpus",
126 | action='store_true'
127 | )
128 | parser.add_argument(
129 | "--weight_dir",
130 | type=str,
131 | required=False,
132 | default=None,
133 | )
134 | parser.add_argument(
135 | "--conv_mode",
136 | type=str,
137 | required=False,
138 | default=None,
139 | )
140 | parser.add_argument(
141 | "--lora_alpha",
142 | type=int,
143 | required=False,
144 | default=None,
145 | )
146 | parser.add_argument(
147 | "--server_port",
148 | type=int,
149 | required=False,
150 | default=7868,
151 | )
152 | args = parser.parse_args()
153 | return args
154 |
155 |
156 | title = """ """
157 | description = (
158 | """
159 | # PLLAVA!
160 | 
161 | - Upload A Video
162 | - Press Upload
163 | - Start Chatting
164 | """
165 | )
166 |
167 | args = parse_args()
168 |
169 | model_description = f"""
170 | # MODEL INFO
171 | - pretrained_model_name_or_path:{args.pretrained_model_name_or_path}
172 | - use_lora:{args.use_lora}
173 | - weight_dir:{args.weight_dir}
174 | """
175 |
176 | # with gr.Blocks(title="InternVideo-VideoChat!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
177 | with gr.Blocks(title="PLLaVA",
178 | theme=pllava_theme,
179 | css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
180 | gr.Markdown(title)
181 | gr.Markdown(description)
182 | gr.Markdown(model_description)
183 | with gr.Row():
184 | with gr.Column(scale=0.5, visible=True) as video_upload:
185 | # with gr.Column(elem_id="image", scale=0.5) as img_part:
186 | with gr.Tab("Video", elem_id='video_tab'):
187 | up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
188 | with gr.Tab("Image", elem_id='image_tab'):
189 | up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360)
190 | upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
191 | clear = gr.Button("Restart")
192 |
193 | # num_segments = gr.Slider(
194 | # minimum=8,
195 | # maximum=64,
196 | # value=8,
197 | # step=1,
198 | # interactive=True,
199 | # label="Video Segments",
200 | # )
201 |
202 | with gr.Column(visible=True) as input_raws:
203 | system_string = gr.Textbox(SYSTEM, interactive=True, label='system')
204 | num_beams = gr.Slider(
205 | minimum=1,
206 | maximum=5,
207 | value=1,
208 | step=1,
209 | interactive=True,
210 | label="beam search numbers",
211 | )
212 | temperature = gr.Slider(
213 | minimum=0.1,
214 | maximum=2.0,
215 | value=1.0,
216 | step=0.1,
217 | interactive=True,
218 | label="Temperature",
219 | )
220 |
221 | chat_state = gr.State()
222 | img_list = gr.State()
223 | chatbot = gr.Chatbot(elem_id="chatbot",label='Conversation')
224 | with gr.Row():
225 | with gr.Column(scale=0.7):
226 | text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
227 | with gr.Column(scale=0.15, min_width=0):
228 | run = gr.Button("💭Send")
229 | with gr.Column(scale=0.15, min_width=0):
230 | clear = gr.Button("🔄Clear")
231 |
232 | with gr.Row():
233 | examples = gr.Examples(
234 | examples=[
235 | ['example/jesse_dance.mp4', 'What is the man doing?'],
236 | ['example/yoga.mp4', 'What is the woman doing?'],
237 | ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'],
238 | # ['example/cooking.mp4', 'What is happening in the video?'],
239 | ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'],
240 | ['example/1917.mov', 'Describe the background, characters and the actions in the provided video.'],
241 | ],
242 | inputs=[up_video, text_input]
243 | )
244 |
245 |
246 | chat = init_model(args)
247 | INIT_CONVERSATION = conv_templates[args.conv_mode]
248 | upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list])
249 |
250 | text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
251 | gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
252 | )
253 | run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
254 | gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
255 | )
256 | run.click(lambda: "", None, text_input)
257 | clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list], queue=False)
258 |
259 | # demo.queue(max_size=5)
260 | demo.launch(share=True,server_port=args.server_port)
261 | # demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
262 |
--------------------------------------------------------------------------------
/tasks/eval/demo/show_compare.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import argparse
4 | import json
5 | import os
6 | import os.path as osp
7 | import gradio as gr
8 | import numpy as np
9 |
10 | from tasks.eval.recaption import load_results as load_results_recaption
11 | from tasks.eval.mvbench import load_results as load_results_mvbench
12 | from tasks.eval.vcgbench import load_results as load_results_vcgbench
13 | from tasks.eval.videoqabench import load_results as load_results_videoqabench
14 | from tasks.eval.demo import pllava_theme
15 |
16 |
17 | load_results_funcs = [
18 | load_results_recaption,
19 | load_results_mvbench,
20 | load_results_vcgbench,
21 | load_results_videoqabench,
22 | ]
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument(
28 | '--root_dir',
29 | required=True,
30 | )
31 | args = parser.parse_args()
32 | return args
33 |
34 | args = parse_args()
35 | root_dir = args.root_dir
36 |
37 | def show(result_list_first, result_list_second, result_index):
38 | sample2index_second = {}
39 |
40 | for i, result in enumerate(result_list_second):
41 | if 'video_path' not in result:
42 | continue
43 |
44 | question = result['question'] if 'question' in result else ''
45 | video_path = result['video_path']
46 | samplehash = question + '--' + video_path
47 | sample2index_second[samplehash] = i
48 |
49 | info = result_list_first[result_index]
50 | info_str_first = json.dumps(info, indent=4, ensure_ascii=False)
51 | video_path = info['video_path']
52 | question = info['question'] if 'question' in info else ''
53 | samplehash = question + '--' + video_path
54 | if samplehash in sample2index_second:
55 | info = result_list_second[sample2index_second[samplehash]]
56 | info_str_second = json.dumps(info, indent=4, ensure_ascii=False)
57 | else:
58 | info_str_second = f"NO {video_path} IN THE SECOND RESULT DIR"
59 | return video_path, info_str_first, info_str_second
60 |
61 | def reload_results_dirs():
62 | result_dirs = []
63 | # load result dir paths
64 | for dirpath, dirnames, filenames in os.walk(args.root_dir):
65 | if len(dirnames) == 0 and len(filenames) != 0:
66 | result_dirs.append(dirpath)
67 | return gr.Dropdown(result_dirs, value=result_dirs[0])
68 |
69 | def reload_results(result_dir):
70 | # if isinstance(result_dir, list):
71 | # result_dir = result_dir[0]
72 |
73 | if result_dir is None or not osp.exists(result_dir):
74 | return None
75 |
76 | for fn in load_results_funcs:
77 | result_list = fn(result_dir)
78 | if result_list is not None:
79 | np.random.shuffle(result_list)
80 | break
81 | result_index = gr.Slider(0, len(result_list), step=1)
82 |
83 | return result_list, result_index
84 |
85 |
86 |
87 | with gr.Blocks(title="PLLAVA RESULTS", theme=pllava_theme) as demo:
88 | result_list_first = gr.State()
89 | result_list_second = gr.State()
90 |
91 | with gr.Row():
92 | with gr.Column():
93 | gr.Markdown("# Showing Off the Model's Outputs.")
94 | gr.Markdown(
95 | "You can find all our results, including:\n"
96 | "1. results on the captioned Inter4k dataset\n"
97 | "2. inference outputs on the different benchmarks.\n"
98 | "Choose a directory to see the corresponding output variant.\n"
99 | "You can also choose a secondary directory (as long as it is from the same dataset) to compare the results.\n"
100 | )
101 |
102 | with gr.Row():
103 | with gr.Column():
104 | show_video = gr.Video(interactive=False)
105 |
106 | with gr.Column():
107 | button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory')
108 | result_index = gr.Slider(0, 0, step=1, label="Index")
109 |
110 | result_dir_first = gr.Dropdown(label='Test Result Path')
111 | info_first = gr.Text(interactive=False, label='Detailed Output Information')
112 | result_dir_second = gr.Dropdown(label='Test Result Path')
113 | info_second = gr.Text(interactive=False, label='Detailed Output Information')
114 |
115 |
116 | button_reload.click(reload_results_dirs, [], [result_dir_first])
117 | button_reload.click(reload_results_dirs, [], [result_dir_second])
118 | result_dir_first.change(reload_results, [result_dir_first], [result_list_first, result_index])
119 | result_dir_second.change(reload_results, [result_dir_second], [result_list_second, result_index])
120 | result_index.change(show, [result_list_first, result_list_second, result_index], [show_video, info_first, info_second])
121 | demo.load(reload_results_dirs, [], [result_dir_first])
122 | demo.load(reload_results_dirs, [], [result_dir_second])
123 |
124 | demo.launch(share=True)
--------------------------------------------------------------------------------
/tasks/eval/demo/show_gallery.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import argparse
4 | import json
5 | import os
6 | import os.path as osp
7 | import gradio as gr
8 |
9 | from tasks.eval.recaption import load_results as load_results_recaption
10 | from tasks.eval.mvbench import load_results as load_results_mvbench
11 | from tasks.eval.vcgbench import load_results as load_results_vcgbench
12 | from tasks.eval.videoqabench import load_results as load_results_videoqabench
13 |
14 | load_results_funcs = [
15 | load_results_recaption,
16 | load_results_mvbench,
17 | load_results_vcgbench,
18 | load_results_videoqabench,
19 | ]
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument(
25 | '--root_dir',
26 | required=True,
27 | )
28 | args = parser.parse_args()
29 | return args
30 |
31 | args = parse_args()
32 | root_dir = args.root_dir
33 |
34 | def show(result_list, result_index):
35 | info = result_list[result_index]
36 | video_path = info['video_path']
37 | info_str = json.dumps(info, indent=4)
38 | return video_path, info_str
39 |
40 | def reload_results_dirs():
41 | result_dirs = []
42 | # load result dir paths
43 | for dirpath, dirnames, filenames in os.walk(args.root_dir):
44 | if len(dirnames) == 0 and len(filenames) != 0:
45 | result_dirs.append(dirpath)
46 | return gr.Dropdown(result_dirs, value=result_dirs[0])
47 |
48 | def reload_results(result_dir):
49 | # if isinstance(result_dir, list):
50 | # result_dir = result_dir[0]
51 |
52 | if result_dir is None or not osp.exists(result_dir):
53 | return None
54 |
55 | for fn in load_results_funcs:
56 | result_list = fn(result_dir)
57 | if result_list is not None:
58 | break
59 |
60 | result_index = gr.Slider(0, len(result_list), step=1)
61 |
62 | return result_list, result_index
63 |
64 | with gr.Blocks() as demo:
65 | result_list = gr.State()
66 |
67 | with gr.Row():
68 | gr.Markdown("# Showing what has come out.")
69 |
70 | with gr.Row():
71 | with gr.Column(scale=1):
72 | gr.Markdown(f"### From Saved Results Directory {args.root_dir}")
73 |
74 | with gr.Column(scale=2):
75 | result_dir = gr.Dropdown(label='Test Result Path')
76 | button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory')
77 |
78 |
79 |
80 | with gr.Row():
81 | with gr.Column():
82 | show_video = gr.Video(interactive=False)
83 |
84 | with gr.Column():
85 | result_index = gr.Slider(0, 0, step=1, label="Index")
86 | info = gr.Text(interactive=False, label='Detailed Output Information')
87 |
88 |
89 | button_reload.click(reload_results_dirs, [], [result_dir])
90 | result_dir.change(reload_results, [result_dir], [result_list, result_index])
91 | result_index.change(show, [result_list, result_index], [show_video, info])
92 | demo.load(reload_results_dirs, [], [result_dir])
93 |
94 | demo.launch(share=True)
--------------------------------------------------------------------------------
/tasks/eval/egoshcema/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from tasks.eval.eval_utils import (
5 | dump_json,
6 | load_json,
7 | EvalDataset,
8 | )
9 |
10 | def check_ans(pred, gt):
11 | flag = False
12 | pred_list = pred.lower().split(' ')
13 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
14 | gt_list = gt.lower().split(' ')
15 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
16 | if gt_content[-1] == '.':
17 | gt_content = gt_content[:-1]
18 |
19 | if not any([c in pred_option for c in 'abcdefgABCDEFG']):
20 | print(f"model doesn't follow instructions: {pred}")
21 | elif pred_option.replace('.', '') in gt_option:
22 | flag = True
23 | elif gt_option in pred_option:
24 | flag = True
25 |
26 | return flag
27 |
28 | def save_results(result_list, save_path):
29 | final_res, acc_dict = {}, {}
30 | correct, total = 0, 0
31 | for res in result_list:
32 | task_type = res['task_type']
33 | if task_type not in acc_dict:
34 | acc_dict[task_type] = [0, 0] # correct, total
35 | acc_dict[task_type][1] += 1
36 | total += 1
37 | pred = res['pred']
38 | gt = res['gt']
39 | if check_ans(pred=pred, gt=gt):
40 | acc_dict[task_type][0] += 1
41 | correct += 1
42 |
43 | for k, v in acc_dict.items():
44 | final_res[k] = v[0] / v[1] * 100
45 | correct += v[0] # re-adds counts tallied above; numerator and denominator double together, so Avg is unchanged
46 | total += v[1]
47 | final_res['Avg'] = correct / total * 100
48 |
49 | all_results = {
50 | "acc_dict": acc_dict,
51 | "result_list": result_list
52 | }
53 | dump_json(all_results, save_path, 'all_results.json')
54 | dump_json(final_res, save_path, 'upload_leaderboard.json')
55 |
56 | def load_results(save_path):
57 | all_results = load_json(save_path, 'all_results.json')
58 | if all_results is not None:
59 | result_list = all_results['result_list']
60 | else:
61 | result_list = None
62 | # json_data = load_json(save_path, 'all_results.json')['result_list']
63 | return result_list
64 |
65 | class EgoSchemaDataset(EvalDataset):
66 | data_list_info = {
67 | "FullSet": ("egoschema_fullset.json", "DATAS/ego_schema/videos", "video", False), # no start & end bound
68 | }
69 | data_dir = "DATAS/ego_schema/json"
70 |
71 | def __init__(self, *args, **kwargs):
72 | super().__init__(*args, **kwargs)
73 |
74 | data_list_info = self.data_list_info
75 | data_dir = self.data_dir
76 |
77 | self.data_list = []
78 | for k, v in data_list_info.items():
79 | with open(os.path.join(data_dir, v[0]), 'r') as f:
80 | json_data = json.load(f)
81 | for data in json_data:
82 | self.data_list.append({
83 | 'task_type': k,
84 | 'prefix': v[1],
85 | 'data_type': v[2],
86 | 'bound': v[3],
87 | 'data': data
88 | })
89 | # self.data_list = self.data_list[:100] # for debug
90 | self.decord_method = {
91 | 'video': self.read_video,
92 | 'gif': self.read_gif,
93 | 'frame': self.read_frame,
94 | 'npy': self.read_npy,
95 | }
96 |
97 | # # transform
98 | # crop_size = resolution
99 | # scale_size = resolution
100 | # input_mean = [0.48145466, 0.4578275, 0.40821073]
101 | # input_std = [0.26862954, 0.26130258, 0.27577711]
102 | # self.transform = T.Compose([
103 | # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
104 | # GroupCenterCrop(crop_size),
105 | # Stack(),
106 | # ToTorchFormatTensor(),
107 | # GroupNormalize(input_mean, input_std)
108 | # ])
109 |
110 | def __getitem__(self, idx):
111 | question, answer = self.qa_template(self.data_list[idx]['data'])
112 | task_type = self.data_list[idx]['task_type']
113 | decord_method = self.decord_method[self.data_list[idx]['data_type']]
114 | bound = None
115 | if self.data_list[idx]['bound']:
116 | bound = (
117 | self.data_list[idx]['data']['start'],
118 | self.data_list[idx]['data']['end'],
119 | )
120 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
121 |
122 |
123 | # images_group = decord_method(video_path, bound)
124 | images_group = decord_method(video_path, bound)
125 | # try: # might be problem with decord
126 | # images_group = decord_method(video_path, bound)
127 | # except Exception as e:
128 | # print(f'error decoding {video_path}', e)
129 | # task_type = 'error_reading_video'
130 | # images_group = None
131 |
132 | return {
133 | 'video_path': video_path,
134 | 'video_pils': images_group, # some might use the original pils and do their own transforms
135 | 'question': question,
136 | 'answer': answer,
137 | 'task_type': task_type,
138 | }
139 |
140 |
141 | def qa_template(self, data):
142 | question = f"Question: {data['question']}\n"
143 | question += "Options:\n"
144 | answer = data['answer']
145 | answer_idx = -1
146 | for idx, c in enumerate(data['candidates']):
147 | question += f"({chr(ord('A') + idx)}) {c}\n"
148 | if c == answer:
149 | answer_idx = idx
150 | question = question.rstrip()
151 | answer = f"({chr(ord('A') + answer_idx)}) {answer}"
152 | return question, answer
153 |
154 |
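Note that check_ans above inspects only the first whitespace-separated token of the prediction, so an answer that leads with filler text is scored as wrong even if it names the right option. A worked example (a sketch, not repo code; the import path keeps the repo's own `egoshcema` spelling):

from tasks.eval.egoshcema import check_ans

gt = "(A) The man is cooking."
print(check_ans("(A) The man is cooking.", gt))  # True: "(a)" matches "(a)"
print(check_ans("A. the man is cooking", gt))    # True: "a" is contained in "(a)"
print(check_ans("The answer is (A).", gt))       # False: only the first token
                                                 # of the prediction is checked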
--------------------------------------------------------------------------------
/tasks/eval/egoshcema/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/egoshcema/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/egoshcema/__pycache__/pllava_eval_egoschema.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/egoshcema/__pycache__/pllava_eval_egoschema.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/mvbench/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from tasks.eval.eval_utils import (
5 | dump_json,
6 | load_json,
7 | EvalDataset,
8 | )
9 |
10 |
11 | def check_ans(pred, gt):
12 | flag = False
13 |
14 | pred_list = pred.lower().split(' ')
15 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
16 | gt_list = gt.lower().split(' ')
17 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
18 | if gt_content[-1] == '.':
19 | gt_content = gt_content[:-1]
20 |
21 | if not any([c in pred_option for c in 'abcdefgABCDEFG']):
22 | print(f"model doesn't follow instructions: {pred}")
23 | elif pred_option.replace('.', '') in gt_option:
24 | flag = True
25 | elif gt_option in pred_option:
26 | flag = True
27 |
28 | return flag
29 |
30 | def save_results(result_list, save_path):
31 |
32 | final_res, acc_dict = {}, {}
33 | correct, total = 0, 0
34 | for res in result_list:
35 | task_type = res['task_type']
36 | if task_type not in acc_dict:
37 | acc_dict[task_type] = [0, 0] # correct, total
38 | acc_dict[task_type][1] += 1
39 | total += 1
40 | pred = res['pred']
41 | gt = res['gt']
42 | if check_ans(pred=pred, gt=gt):
43 | acc_dict[task_type][0] += 1
44 | correct += 1
45 |
46 | for k, v in acc_dict.items():
47 | final_res[k] = v[0] / v[1] * 100
48 | correct += v[0] # re-adds counts tallied above; numerator and denominator double together, so Avg is unchanged
49 | total += v[1]
50 | final_res['Avg'] = correct / total * 100
51 |
52 | all_results = {
53 | "acc_dict": acc_dict,
54 | "result_list": result_list
55 | }
56 | dump_json(all_results, save_path, 'all_results.json')
57 | dump_json(final_res, save_path, 'upload_leaderboard.json')
58 |
59 | def load_results(save_path):
60 | all_results = load_json(save_path, 'all_results.json')
61 | if all_results is not None:
62 | result_list = all_results['result_list']
63 | else:
64 | result_list = None
65 | # json_data = load_json(save_path, 'all_results.json')['result_list']
66 | return result_list
67 |
68 | class MVBenchDataset(EvalDataset):
69 | data_list_info = {
70 | # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound")
71 | "Action Sequence": ("action_sequence.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
72 | "Action Prediction": ("action_prediction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
73 | "Action Antonym": ("action_antonym.json", "DATAS/MVBench/video/ssv2_video/", "video", False),
74 | "Fine-grained Action": ("fine_grained_action.json", "DATAS/MVBench/video/Moments_in_Time_Raw/videos/", "video", False),
75 | "Unexpected Action": ("unexpected_action.json", "DATAS/MVBench/video/FunQA_test/test/", "video", False),
76 | "Object Existence": ("object_existence.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
77 | "Object Interaction": ("object_interaction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
78 | "Object Shuffle": ("object_shuffle.json", "DATAS/MVBench/video/perception/videos/", "video", False),
79 | "Moving Direction": ("moving_direction.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
80 | "Action Localization": ("action_localization.json", "DATAS/MVBench/video/sta/sta_video/", "video", True), # has start & end
81 | "Scene Transition": ("scene_transition.json", "DATAS/MVBench/video/scene_qa/video/", "video", False),
82 | "Action Count": ("action_count.json", "DATAS/MVBench/video/perception/videos/", "video", False),
83 | "Moving Count": ("moving_count.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
84 | "Moving Attribute": ("moving_attribute.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
85 | "State Change": ("state_change.json", "DATAS/MVBench/video/perception/videos/", "video", False),
86 | "Fine-grained Pose": ("fine_grained_pose.json", "DATAS/MVBench/video/nturgbd/", "video", False),
87 | "Character Order": ("character_order.json", "DATAS/MVBench/video/perception/videos/", "video", False),
88 | "Egocentric Navigation": ("egocentric_navigation.json", "DATAS/MVBench/video/vlnqa/", "video", False),
89 | "Episodic Reasoning": ("episodic_reasoning.json", "DATAS/MVBench/video/tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame
90 | "Counterfactual Inference": ("counterfactual_inference.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
91 | }
92 | data_dir = "DATAS/MVBench/json"
93 |
94 | def __init__(self, *args, **kwargs):
95 | super().__init__(*args, **kwargs)
96 |
97 | data_list_info = self.data_list_info
98 | data_dir = self.data_dir
99 |
100 | self.data_list = []
101 | for k, v in data_list_info.items():
102 | with open(os.path.join(data_dir, v[0]), 'r') as f:
103 | json_data = json.load(f)
104 | for data in json_data:
105 | self.data_list.append({
106 | 'task_type': k,
107 | 'prefix': v[1],
108 | 'data_type': v[2],
109 | 'bound': v[3],
110 | 'data': data
111 | })
112 | # self.data_list = self.data_list[:100] # for debug
113 | self.decord_method = {
114 | 'video': self.read_video,
115 | 'gif': self.read_gif,
116 | 'frame': self.read_frame,
117 | 'npy': self.read_npy,
118 | }
119 |
120 | # # transform
121 | # crop_size = resolution
122 | # scale_size = resolution
123 | # input_mean = [0.48145466, 0.4578275, 0.40821073]
124 | # input_std = [0.26862954, 0.26130258, 0.27577711]
125 | # self.transform = T.Compose([
126 | # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
127 | # GroupCenterCrop(crop_size),
128 | # Stack(),
129 | # ToTorchFormatTensor(),
130 | # GroupNormalize(input_mean, input_std)
131 | # ])
132 |
133 | def __getitem__(self, idx):
134 | question, answer = self.qa_template(self.data_list[idx]['data'])
135 | task_type = self.data_list[idx]['task_type']
136 | decord_method = self.decord_method[self.data_list[idx]['data_type']]
137 | bound = None
138 | if self.data_list[idx]['bound']:
139 | bound = (
140 | self.data_list[idx]['data']['start'],
141 | self.data_list[idx]['data']['end'],
142 | )
143 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
144 |
145 |
146 | images_group = decord_method(video_path, bound)
147 |
148 | return {
149 | 'video_path': video_path,
150 | 'video_pils': images_group, # some might use the original pils and do their own transforms
151 | 'question': question,
152 | 'answer': answer,
153 | 'task_type': task_type,
154 | }
155 |
156 |
157 | def qa_template(self, data):
158 | question = f"Question: {data['question']}\n"
159 | question += "Options:\n"
160 | answer = data['answer']
161 | answer_idx = -1
162 | for idx, c in enumerate(data['candidates']):
163 | question += f"({chr(ord('A') + idx)}) {c}\n"
164 | if c == answer:
165 | answer_idx = idx
166 | question = question.rstrip()
167 | answer = f"({chr(ord('A') + answer_idx)}) {answer}"
168 | return question, answer
169 |
170 |
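qa_template above turns a candidate list into the lettered multiple-choice prompt that check_ans later matches against. A small illustration (sketch only: the `data` dict is made up, and since qa_template never reads `self`, it is called unbound here):

from tasks.eval.mvbench import MVBenchDataset

data = {"question": "What does the person do first?",
        "candidates": ["open the door", "pick up the cup"],
        "answer": "pick up the cup"}
question, answer = MVBenchDataset.qa_template(None, data)
print(question)
# Question: What does the person do first?
# Options:
# (A) open the door
# (B) pick up the cup
print(answer)  # (B) pick up the cup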
--------------------------------------------------------------------------------
/tasks/eval/mvbench/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/mvbench/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/tasks/eval/mvbench/__pycache__/llava_next_video_mvbench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/llava_next_video_mvbench.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/mvbench/__pycache__/pllava_eval_mvbench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/pllava_eval_mvbench.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/mvbench/__pycache__/tarsier_eval_mvbench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/tarsier_eval_mvbench.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/recaption/show_recaption.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import gradio as gr
4 |
5 | from tasks.eval.recaption import load_results
6 | import json
7 |
8 | # example = videogallery().example_inputs()
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | '--save_path',
15 | required=True,
16 | )
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | args = parse_args()
22 | result_list = load_results(args.save_path)
23 |
24 |
25 | def show(result_index, ):
26 | info = result_list[result_index]
27 | video_path = info['video_path']
28 | info_str = json.dumps(info, indent=4)
29 | return video_path, info_str
30 |
31 |
32 |
33 | from tasks.eval.recaption import load_results
34 |
35 | with gr.Blocks() as demo:
36 | gr.Markdown("# Showing what has come out.")
37 | gr.Markdown(f"From Saved Results {args.save_path}")
38 | with gr.Row():
39 | with gr.Column(1):
40 | show_video = gr.Video(interactive=False)
41 |
42 | with gr.Column():
43 | result_index = gr.Slider(0, len(result_list), step=1)
44 | info = gr.Text(interactive=False)
45 |
46 | result_index.change(show, [result_index], [show_video, info])
47 |
48 |
49 |
50 |
51 |
52 | demo.launch(share=True)
53 |
--------------------------------------------------------------------------------
/tasks/eval/vcgbench/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/vcgbench/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/vcgbench/__pycache__/pllava_eval_vcgbench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/vcgbench/__pycache__/pllava_eval_vcgbench.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/vcgbench/show_vcg.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import gradio as gr
4 |
5 | from tasks.eval.vcgbench import load_results
6 | import json
7 |
8 | # example = videogallery().example_inputs()
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | '--save_path',
15 | required=True,
16 | )
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | args = parse_args()
22 | result_list = load_results(args.save_path)
23 |
24 |
25 | def show(result_index, ):
26 | info = result_list[result_index]
27 | video_path = info['video_path']
28 | info_str = json.dumps(info, indent=4)
29 | return video_path, info_str
30 |
31 | with gr.Blocks() as demo:
32 | gr.Markdown(
33 | f"# Showing The Results from {args.save_path}"
34 | )
35 | with gr.Row():
36 | with gr.Column():
37 | show_video = gr.Video(interactive=False)
38 |
39 | with gr.Column():
40 | result_index = gr.Slider(0, len(result_list), step=1)
41 | info = gr.Text(interactive=False)
42 |
43 | result_index.change(show, [result_index], [show_video, info])
44 |
45 | demo.launch(share=True)
46 |
--------------------------------------------------------------------------------
/tasks/eval/videomme/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from tasks.eval.eval_utils import (
5 | dump_json,
6 | load_json,
7 | EvalDataset,
8 | )
9 |
10 |
11 | def check_ans(pred, gt):
12 | flag = False
13 |
14 | pred_list = pred.lower().split(' ')
15 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
16 | gt_list = gt.lower().split(' ')
17 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
18 | if gt_content[-1] == '.':
19 | gt_content = gt_content[:-1]
20 |
21 | if not any([c in pred_option for c in 'abcdefgABCDEFG']):
22 | print(f"model doesn't follow instructions: {pred}")
23 | elif pred_option.replace('.', '') in gt_option:
24 | flag = True
25 | elif gt_option in pred_option:
26 | flag = True
27 |
28 | return flag
29 |
30 | def save_results(result_list, save_path):
31 |
32 | final_res, acc_dict = {}, {}
33 | correct, total = 0, 0
34 | for res in result_list:
35 | task_type = res['task_type']
36 | if task_type not in acc_dict:
37 | acc_dict[task_type] = [0, 0] # correct, total
38 | acc_dict[task_type][1] += 1
39 | total += 1
40 | pred = res['pred']
41 | gt = res['gt']
42 | if check_ans(pred=pred, gt=gt):
43 | acc_dict[task_type][0] += 1
44 | correct += 1
45 |
46 | for k, v in acc_dict.items():
47 | final_res[k] = v[0] / v[1] * 100
48 | correct += v[0] # re-adds counts tallied above; numerator and denominator double together, so Avg is unchanged
49 | total += v[1]
50 | final_res['Avg'] = correct / total * 100
51 |
52 | all_results = {
53 | "acc_dict": acc_dict,
54 | "result_list": result_list
55 | }
56 | dump_json(all_results, save_path, 'all_results.json')
57 | dump_json(final_res, save_path, 'upload_leaderboard.json')
58 |
59 | def load_results(save_path):
60 | all_results = load_json(save_path, 'all_results.json')
61 | if all_results is not None:
62 | result_list = all_results['result_list']
63 | else:
64 | result_list = None
65 | # json_data = load_json(save_path, 'all_results.json')['result_list']
66 | return result_list
67 |
68 | class VideoMMEDataset(EvalDataset):
69 | data_list_info = {
70 | # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound")
71 | "Short Video": ("short.json", "DATAS/Video-MME/data", "video", False), # no start & end bound
72 | "Medium Video": ("medium.json", "DATAS/Video-MME/data", "video", False), # no start & end bound
73 | "Long Video": ("long.json", "DATAS/Video-MME/data", "video", False),
74 | }
75 | data_dir = "DATAS/Video-MME/json"
76 |
77 | def __init__(self, *args, **kwargs):
78 | super().__init__(*args, **kwargs)
79 |
80 | data_list_info = self.data_list_info
81 | data_dir = self.data_dir
82 |
83 | self.data_list = []
84 | for k, v in data_list_info.items():
85 | with open(os.path.join(data_dir, v[0]), 'r') as f:
86 | json_data = json.load(f)
87 | for data in json_data:
88 | self.data_list.append({
89 | 'task_type': k,
90 | 'prefix': v[1],
91 | 'data_type': v[2],
92 | 'bound': v[3],
93 | 'data': data
94 | })
95 | # self.data_list = self.data_list[:100] # for debug
96 | self.decord_method = {
97 | 'video': self.read_video,
98 | 'gif': self.read_gif,
99 | 'frame': self.read_frame,
100 | 'npy': self.read_npy,
101 | }
102 |
103 | # # transform
104 | # crop_size = resolution
105 | # scale_size = resolution
106 | # input_mean = [0.48145466, 0.4578275, 0.40821073]
107 | # input_std = [0.26862954, 0.26130258, 0.27577711]
108 | # self.transform = T.Compose([
109 | # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
110 | # GroupCenterCrop(crop_size),
111 | # Stack(),
112 | # ToTorchFormatTensor(),
113 | # GroupNormalize(input_mean, input_std)
114 | # ])
115 |
116 | def __getitem__(self, idx):
117 | question, answer = self.qa_template(self.data_list[idx]['data'])
118 | task_type = self.data_list[idx]['task_type']
119 | decord_method = self.decord_method[self.data_list[idx]['data_type']]
120 | bound = None
121 | if self.data_list[idx]['bound']:
122 | bound = (
123 | self.data_list[idx]['data']['start'],
124 | self.data_list[idx]['data']['end'],
125 | )
126 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
127 |
128 |
129 | # images_group = decord_method(video_path, bound)
130 | images_group = decord_method(video_path, bound)
131 | # try: # might be problem with decord
132 | # images_group = decord_method(video_path, bound)
133 | # except Exception as e:
134 | # print(f'error decoding {video_path}', e)
135 | # task_type = 'error_reading_video'
136 | # images_group = None
137 |
138 | return {
139 | 'video_path': video_path,
140 | 'video_pils': images_group, # some might use the original pils and do their own transforms
141 | 'question': question,
142 | 'answer': answer,
143 | 'task_type': task_type,
144 | }
145 |
146 |
147 | def qa_template(self, data):
148 | question = f"Question: {data['question']}\n"
149 | question += "Options:\n"
150 | answer = data['answer']
151 | answer_idx = -1
152 | for idx, c in enumerate(data['candidates']):
153 | question += f"({chr(ord('A') + idx)}) {c}\n"
154 | if c == answer:
155 | answer_idx = idx
156 | question = question.rstrip()
157 | answer = f"({chr(ord('A') + answer_idx)}) {answer}"
158 | return question, answer
159 |
160 |
--------------------------------------------------------------------------------
/tasks/eval/videomme/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videomme/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/videomme/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videomme/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/tasks/eval/videomme/__pycache__/pllava_eval_videomme.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videomme/__pycache__/pllava_eval_videomme.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/videoqabench/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videoqabench/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/eval/videoqabench/__pycache__/pllava_eval_videoqabench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videoqabench/__pycache__/pllava_eval_videoqabench.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/shared_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from utils.optimizer import create_optimizer
11 | from utils.scheduler import create_scheduler
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_media_types(datasources):
17 | """Get the media types for all the dataloaders.
18 |
19 | Args:
20 | datasources (List): List of dataloaders or datasets.
21 |
22 | Returns: List. The media_types.
23 |
24 | """
25 | if isinstance(datasources[0], DataLoader):
26 | datasets = [dataloader.dataset for dataloader in datasources]
27 | else:
28 | datasets = datasources
29 | media_types = [
30 | dataset.datasets[0].media_type
31 | if isinstance(dataset, ConcatDataset)
32 | else dataset.media_type
33 | for dataset in datasets
34 | ]
35 |
36 | return media_types
37 |
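get_media_types accepts either DataLoaders or bare datasets and reports one media_type per source, with a ConcatDataset contributing its first member's type. A usage sketch (the `Dummy` class is hypothetical; `media_type` is the attribute the repo's datasets expose):

from torch.utils.data import ConcatDataset
from tasks.shared_utils import get_media_types

class Dummy:
    """Hypothetical stand-in for a dataset carrying a media_type attribute."""
    def __init__(self, media_type):
        self.media_type = media_type
    def __len__(self):
        return 1
    def __getitem__(self, idx):
        return idx

sources = [Dummy("image"), ConcatDataset([Dummy("video"), Dummy("video")])]
print(get_media_types(sources))  # -> ['image', 'video']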
--------------------------------------------------------------------------------
/tasks/train/__pycache__/instruction_data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/train/__pycache__/instruction_data.cpython-310.pyc
--------------------------------------------------------------------------------
/tasks/train/clever_process.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | # flatten the CLEVRER download: move every file from its subfolder up into the dataset root
4 | dataset_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/CLEVRER'
5 | dir_list = os.listdir(dataset_path)
6 |
7 | for dir in dir_list:
8 | dir_path = os.path.join(dataset_path, dir)
9 | file_list = os.listdir(dir_path)
10 | for file in file_list:
11 | file_path = os.path.join(dir_path, file)
12 | shutil.move(file_path, dataset_path)
--------------------------------------------------------------------------------
/tasks/train/config_pllava_nframe.py:
--------------------------------------------------------------------------------
1 | from tasks.train.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | # train_corpus = "videochat2_instruction"
5 | train_corpus = "videochat2_instruction_full"
6 |
7 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
8 | test_file = dict()
9 | test_types = []
10 | num_workers = 8
11 | save_steps=1000
12 | ckpt_steps=1000
13 | stop_key = None
14 | deepspeed=False
15 | # ========================= input ==========================
16 | num_frames = 16
17 | num_frames_test = 1
18 | batch_size = 8
19 | gradient_accumulation_steps=1
20 | max_txt_l = 512
21 | max_train_steps=None
22 | pre_text = False
23 | inputs = dict(
24 | image_res=336,
25 | video_input=dict(
26 | num_frames="${num_frames}",
27 | sample_type="rand",
28 | num_frames_test="${num_frames_test}",
29 | sample_type_test="middle",
30 | random_aug=False,
31 | ),
32 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
33 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
34 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
35 | )
36 |
37 | # ========================= model ==========================
38 | model = dict(
39 | # repo_id="llava-hf/llava-v1.6-vicuna-7b-hf",
40 | repo_id="MODELS/llava-1.6",
41 | # repo_id="MODELS/llava-1.6-7b-next-video-dpo",
42 | # repo_id="MODELS/tarsier",
43 | pretrained_path=None,
44 | load_from_origin=False,
45 | origin_vision="",
46 | origin_llm="",
47 | vision_encoder=dict(
48 | name="vit_l14", # needed so the dataset knows the mean/std of the pretrained vision model
49 | ),
50 | torch_dtype='bfloat16',
51 | freeze_projector=False,
52 | projector_unfreeze_modules = ['all'],
53 | freeze_lm=True,
54 | lm_unfreeze_modules=['all'],
55 | # lm_unfreeze_modules=['layernorm', 'embed_tokens', 'norm', 'lm_head'],
56 | freeze_vision_tower=True,
57 | lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma
58 | use_lora=True,
59 | lora_r=128,
60 | lora_alpha=32,
61 | lora_dropout=0.05,
62 | num_frames="${num_frames}",
63 | pooling_method='avg',
64 | use_pooling=True,
65 | frame_shape=(24,24),
66 | pooling_shape=(16,8,8),
67 | )
68 |
69 | preprocess = dict(
70 | system="",
71 | mm_alone=True,
72 | random_shuffle=True,
73 | add_second_msg=True,
74 | roles=['USER:', 'ASSISTANT:'],
75 | end_signal=(' ', '</s>'),
76 | begin_signal='',
77 | dataset_image_placeholder='<Image></Image>',
78 | dataset_video_placeholder='<Video></Video>',
79 | image_token_index=32000,
80 | max_txt_l = "${max_txt_l}",
81 | ignore_index=-100, # default ignore_index of torch cross-entropy loss
82 | center_pad=False,
83 | longest_edge=762,
84 | shortest_edge=336,
85 | clip_transform=False,
86 | num_frames="${num_frames}",
87 | )
88 |
89 |
90 | optimizer = dict(
91 | opt="adamW",
92 | lr=2e-5,
93 | opt_betas=[0.9, 0.999], # default
94 | weight_decay=0.02,
95 | max_grad_norm=-1, # requires a positive float, use -1 to disable
96 | # use a different lr for some modules, e.g., larger lr for new modules
97 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
98 | )
99 |
100 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
101 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
102 | scheduler = dict(
103 | is_videochat2_custom=False,
104 | sched="cosine",
105 | epochs=2,
106 | warmup_ratio=0.2,
107 | min_lr_multi=0.25)
108 |
109 | evaluate = False
110 | deep_fusion = False
111 | evaluation = dict(
112 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
113 | eval_x_only=False,
114 | k_test=128,
115 | eval_offload=True, # offload gpu tensors to cpu to save memory.
116 | )
117 |
118 | fp16 = True
119 | gradient_checkpointing = True
120 |
121 | # ========================= wandb ==========================
122 | wandb = dict(
123 | enable=False,
124 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
125 | project="DE_LLAVA", # setup in your command line
126 | )
127 | dist_url = "env://"
128 | device = "cuda"
129 | mode = "it"
130 |
131 | # ========================= others ==========================
132 | output_dir = None # output dir
133 | resume = False # if True, load optimizer and scheduler states as well
134 | debug = False
135 | log_freq = 5
136 | metric_window_size=10 # window size for metric
137 | seed = 42
138 | report_to='tensorboard'
139 | save_latest = True
140 | auto_resume = True
141 | pretrained_path = "" # path to pretrained model weights, for resume only?
142 |
--------------------------------------------------------------------------------
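A note on the pooling settings above: the ViT-L/14 encoder at 336px produces a 24x24 patch grid per frame (frame_shape), i.e. 576 vision tokens, and pooling_shape=(16,8,8) average-pools a 16-frame clip down to 16*8*8 tokens before they reach the language model. A minimal arithmetic sketch (variable names are illustrative, not from this repo):

    num_frames = 16
    frame_shape = (24, 24)      # per-frame patch grid: 336 / 14 = 24
    pooling_shape = (16, 8, 8)  # (t, h, w) after adaptive average pooling

    tokens_before = num_frames * frame_shape[0] * frame_shape[1]           # 16 * 576 = 9216
    tokens_after = pooling_shape[0] * pooling_shape[1] * pooling_shape[2]  # 16 * 64 = 1024
    print(tokens_before, tokens_after)  # 9216 1024
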
/tasks/train/config_pllava_nframe_yiprompt.py:
--------------------------------------------------------------------------------
1 | from tasks.train.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | # train_corpus = "videochat2_instruction"
5 | train_corpus = "videochat2_instruction_full"
6 |
7 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
8 | test_file = dict()
9 | test_types = []
10 | num_workers = 8
11 | save_steps=10000
12 | ckpt_steps=1000
13 | stop_key = None
14 | deepspeed=False
15 | highres=None
16 | # ========================= input ==========================
17 | num_frames = 16
18 | num_frames_test = 1
19 | batch_size = 1
20 | gradient_accumulation_steps=16
21 | max_txt_l = 512
22 | max_train_steps=None
23 | pre_text = False
24 | inputs = dict(
25 | image_res=336,
26 | video_input=dict(
27 | num_frames="${num_frames}",
28 | sample_type="rand",
29 | num_frames_test="${num_frames_test}",
30 | sample_type_test="middle",
31 | random_aug=False,
32 | ),
33 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
34 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
35 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
36 | )
37 |
38 | model = dict(
39 | repo_id="llava-hf/llava-1.5-7b-hf",
40 | pretrained_path=None,
41 | load_from_origin=False,
42 | origin_vision="",
43 | origin_llm="",
44 | vision_encoder=dict(
45 |         name="vit_l14", # tells the dataset pipeline which normalization mean/std the pretrained vision encoder expects
46 | ),
47 | torch_dtype='bfloat16',
48 | freeze_projector=False,
49 | freeze_lm=True,
50 | freeze_vision_tower=True,
51 | lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma
52 | use_lora=True,
53 | lora_r=128,
54 | lora_alpha=32,
55 | lora_dropout=0.05,
56 | num_frames="${num_frames}",
57 | pooling_method='avg',
58 | use_pooling=True,
59 | frame_shape=(24,24),
60 | pooling_shape=(16,8,8),
61 | )
62 | preprocess = dict(
63 | system="",
64 | mm_alone=True,
65 | image_token_index=64002,
66 | random_shuffle=True,
67 | add_second_msg=True,
68 | roles=['<|im_start|>user\n', '<|im_start|>assistant\n'],
69 | end_signal=('<|im_end|>\n', '<|im_end|>\n'),
70 | begin_signal='',
71 | dataset_image_placeholder='',
72 | dataset_video_placeholder='',
73 | max_txt_l = "${max_txt_l}",
74 |     ignore_index=-100, # matches the default ignore_index of torch cross-entropy losses
75 | center_pad=False,
76 | longest_edge=762,
77 | shortest_edge=336,
78 | clip_transform=False,
79 | num_frames="${num_frames}",
80 | )
81 |
82 |
83 | optimizer = dict(
84 | opt="adamW",
85 | lr=2e-5,
86 | opt_betas=[0.9, 0.999], # default
87 | weight_decay=0.02,
88 |     max_grad_norm=-1, # gradient clipping threshold: set a positive float to enable, -1 to disable
89 | # use a different lr for some modules, e.g., larger lr for new modules
90 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
91 | )
92 |
93 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
94 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
95 | scheduler = dict(
96 | is_videochat2_custom=False,
97 | sched="cosine",
98 | epochs=2,
99 | warmup_ratio=0.2,
100 | min_lr_multi=0.25)
101 |
102 | evaluate = False
103 | deep_fusion = False
104 | evaluation = dict(
105 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
106 | eval_x_only=False,
107 | k_test=128,
108 | eval_offload=True, # offload gpu tensors to cpu to save memory.
109 | )
110 |
111 | fp16 = True
112 | gradient_checkpointing = True
113 |
114 | # ========================= wandb ==========================
115 | wandb = dict(
116 | enable=False,
117 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
118 | project="videochat2", # setup in your command line
119 | )
120 | dist_url = "env://"
121 | device = "cuda"
122 | mode = "it"
123 |
124 | # ========================= others ==========================
125 | output_dir = None # output dir
126 | resume = False # if True, load optimizer and scheduler states as well
127 | debug = False
128 | log_freq = 5
129 | metric_window_size=10 # window size for metric
130 | seed = 42
131 | report_to='tensorboard'
132 | save_latest = True
133 | auto_resume = True
134 | pretrained_path = "" # path to pretrained model weights, for resume only?
135 |
--------------------------------------------------------------------------------
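This config differs from the previous one mainly in the chat template: the Vicuna-style USER:/ASSISTANT: roles are replaced by Yi-style <|im_start|>/<|im_end|> markers, and image_token_index moves from 32000 to 64002, presumably to match the Yi tokenizer. It also trades per-step batch size for accumulation (batch_size=1 with gradient_accumulation_steps=16, an effective per-GPU batch of 16). A hedged sketch of how one turn would plausibly be assembled from the preprocess fields above (the actual assembly lives in the dataset code, not shown here):

    roles = ['<|im_start|>user\n', '<|im_start|>assistant\n']
    end_signal = ('<|im_end|>\n', '<|im_end|>\n')

    def render_turn(question: str, answer: str) -> str:
        # role marker + message + end marker, first for the user, then the assistant
        return (roles[0] + question + end_signal[0]
                + roles[1] + answer + end_signal[1])

    print(render_turn('What happens in the video?', 'A person opens a door.'))
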
/tasks/train/ego_process.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | anno_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/vqa/ego_qa/train.json'
3 | video_root_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/ego4d_data/split_videos'
4 |
5 | annos = json.load(open(anno_file, 'r'))
6 | for anno in annos:
7 | video_path = anno['video']
8 | video_path = os.path.join(video_root_path, video_path)
9 | if not os.path.exists(video_path):
10 | print(video_path)
--------------------------------------------------------------------------------
/tasks/train/ffmpeg_tgif.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 |
4 | # Source and target folder paths
5 | source_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/video_gif'
6 | target_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/videos_mp4'
7 | # Path to the text file listing the filenames to convert
8 | file_list_path = 'not_have.txt'
9 |
10 | if not os.path.exists(target_folder):
11 | os.makedirs(target_folder)
12 |
13 | # Read the file list and convert each entry
14 | with open(file_list_path, 'r') as file_list:
15 | for line in file_list:
16 |         # Filename with leading/trailing whitespace stripped
17 | gif_filename = line.strip()
18 |         # Full path to the source file
19 | source_path = os.path.join(source_folder, gif_filename)
20 |         # Full path to the target file; assumes the filename is well-formed and swaps the extension for .mp4
21 | target_path = os.path.join(target_folder, os.path.splitext(gif_filename)[0] + '.mp4')
22 |
23 |         # Build the ffmpeg command
24 | cmd = ['ffmpeg', '-i', source_path, '-movflags', 'faststart', target_path]
25 |
26 |         # Run the command
27 | try:
28 | subprocess.run(cmd, check=True)
29 | print(f'Successfully converted {gif_filename} to MP4.')
30 | except subprocess.CalledProcessError as e:
31 | print(f'Failed to convert {gif_filename}. Error: {e}')
32 |
33 | print('All files have been processed.')
--------------------------------------------------------------------------------
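The -movflags faststart option makes ffmpeg move the MP4 moov atom to the front of the file after encoding, so decoders can start playback without seeking to the end, which helps when the converted clips are later read from a network mount. One hardening the loop above could use is skipping files that were already converted; a sketch under the same paths (the helper name is illustrative):

    import os
    import subprocess

    def convert_one(source_folder, target_folder, gif_filename):
        source_path = os.path.join(source_folder, gif_filename)
        target_path = os.path.join(target_folder, os.path.splitext(gif_filename)[0] + '.mp4')
        if not os.path.exists(source_path):
            return f'missing source: {gif_filename}'
        if os.path.exists(target_path):
            return f'already converted: {gif_filename}'
        # -y overwrites partial outputs left by interrupted runs
        cmd = ['ffmpeg', '-y', '-i', source_path, '-movflags', 'faststart', target_path]
        subprocess.run(cmd, check=True)
        return f'converted: {gif_filename}'
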
/tasks/train/k710_print.py:
--------------------------------------------------------------------------------
1 | dataset_path = {
2 | # 'k400': '/root/paddlejob/workspace/env_run/output/xiaohu/data/k400/train',
3 | # 'k600': '/root/paddlejob/workspace/env_run/output/xiaohu/data/k600/Kinetics600/videos',
4 | # 'k700': '/root/paddlejob/workspace/env_run/data_afs_3/zhouhao14/intern/xiaohu/k700_dir/Kinetics_700/videos/'
5 | 'k710': '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/k710'
6 | }
7 |
8 | import os
9 | from tqdm import tqdm
10 |
11 | f = open('k710_files_filter.txt', 'w')
12 | for dataset, path in dataset_path.items():
13 | # dir_list = os.listdir(path)
14 | # for dir in tqdm(dir_list):
15 | # dir_path = os.path.join(path, dir)
16 | file_list = os.listdir(path)
17 | for file in file_list:
18 | file_path = os.path.join(path, file)
19 | f.write(file+' '+file_path+'\n')
--------------------------------------------------------------------------------
/tasks/train/k710_process.py:
--------------------------------------------------------------------------------
1 | annotation_file = "/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/classification/k710/train_new.json"
2 | annotation_file_new = "/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/classification/k710/train_new_1.json"
3 | file_list = open('k710_files_filter.txt', 'r').readlines()
4 | file_list = [file.strip().split(' ') for file in file_list]
5 | file_dict = {}
6 | for file, path in file_list:
7 | file = file[:11].lower()
8 | file_dict[file] = path
9 | import os
10 | import json
11 |
12 | annotations = json.load(open(annotation_file))
13 | print('annotation length:', len(annotations))
14 | annotations_new = []
15 | count = 0
16 | for anno in annotations:
17 | video_path = anno['video']
18 | video_path = video_path.split('/')[-1].split('.')[0]
19 | if len(video_path) > 15:
20 | video_path = video_path[:11]
21 | video_path = video_path.lower()
22 | if video_path in file_dict:
23 | # anno['video'] = file_dict[video_path.lower()]
24 | anno['video'] = anno['video'].split('/')[-1]
25 | annotations_new.append(anno)
26 | else:
27 | count += 1
28 | json.dump(annotations_new, open(annotation_file_new, 'w'))
29 | print('miss number:', count)
30 | # for file, file_path in file_list:
31 | # if video_path in file:
32 | # continue
33 | # else:
34 | # print(video_path)
--------------------------------------------------------------------------------
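The matching above normalizes both sides to the first 11 characters, lowercased, which treats filenames as YouTube-style 11-character IDs. Real YouTube IDs are case-sensitive, so lowercasing can in principle collide two distinct IDs; the script accepts that risk to survive inconsistent casing in the annotation files. The shared normalization, factored out for clarity (an illustrative helper, not in the repo; the script itself truncates the annotation side only when the stem exceeds 15 characters):

    import os

    def normalize_video_id(path: str) -> str:
        # basename without extension, cut to the 11-char ID, lowercased
        stem = os.path.basename(path).split('.')[0]
        return stem[:11].lower()
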
/tasks/train/mk_710.py:
--------------------------------------------------------------------------------
1 | annotation_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/classification/k710/train_new.json'
2 | dst_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/k710'
3 |
4 | import os, json, shutil
5 | from tqdm import tqdm
6 |
7 | f = open(annotation_file)
8 | annotations = json.load(f)
9 |
10 | for anno in tqdm(annotations):
11 | video_path = anno['video']
12 | video_name = os.path.basename(video_path)
13 | shutil.copyfile(video_path, dst_path + '/' + video_name)
--------------------------------------------------------------------------------
/tasks/train/not_have.txt:
--------------------------------------------------------------------------------
1 | tumblr_nb2b36uj4V1skspzwo1_250.gif
2 | tumblr_nnwhmq7vx91r76agyo1_250.gif
3 | tumblr_nq92uaE0Ws1u97vumo1_500.gif
4 | tumblr_mx7h7rT5VW1qd80wyo1_400.gif
5 | tumblr_ncfcoj7Cjk1qd80wyo1_400.gif
6 | tumblr_n7kvhniZkO1qd80wyo1_400.gif
7 | tumblr_nk56bs75dd1u7oomho1_400.gif
8 | tumblr_nk8cx74VZI1r88jv8o1_250.gif
9 | tumblr_npkvwxAmBw1ux8xe0o1_250.gif
10 | tumblr_nnwhsh94Cz1r76agyo1_100.gif
11 | tumblr_nqrmt7MmEi1ux8xe0o1_400.gif
12 | tumblr_nauyg863cl1tdjuqvo1_400.gif
13 | tumblr_n9gq572Eil1qd80wyo1_400.gif
14 | tumblr_np7f9w4gb61s4vkvgo1_250.gif
15 | tumblr_naemimnRQ21qj7ohio1_500.gif
16 | tumblr_npu4nvnG8y1ux8xe0o1_250.gif
17 | tumblr_ne372wjN501tmgpxuo1_250.gif
18 | tumblr_nqo7ly0WTQ1sgafh8o1_400.gif
19 | tumblr_nc669vjChQ1s7nakbo1_400.gif
20 | tumblr_njs0cnWbe11tgetb4o1_250.gif
21 | tumblr_nkils8vvnN1tk2dvro1_400.gif
22 | tumblr_n9vlgvJRfR1qd80wyo1_400.gif
23 | tumblr_n068ybhVsN1rkm4f7o1_400.gif
24 | tumblr_n3v3xrtaGc1r8go1ao1_250.gif
25 | tumblr_niibjr3cfO1u8uroco1_250.gif
26 | tumblr_nfa9ofxIIb1sk96t7o1_400.gif
27 | tumblr_nbjg4bukeX1raaknro1_250.gif
28 | tumblr_n8qummpIqS1sfcnmao1_250.gif
29 | tumblr_nnwk6xiIOh1r76agyo1_250.gif
30 | tumblr_n92cixcsvI1r88jv8o1_400.gif
31 | tumblr_nnwhoxZmzr1r76agyo1_100.gif
32 | tumblr_ncfslbyWaf1trjw2xo1_400.gif
33 | tumblr_nofiupp70K1tsywajo1_250.gif
34 | tumblr_nq3gnhRRgi1r09l2vo1_400.gif
35 | tumblr_naohj0KLo81sw0250o1_400.gif
36 | tumblr_nr13cnatIH1ux8xe0o1_250.gif
37 | tumblr_np8vddxRS81uw8t6bo1_400.gif
38 | tumblr_mtg3j27bd61s7nakbo1_400.gif
39 | tumblr_nh9ey9djGZ1s7nakbo1_400.gif
40 | tumblr_ngyd0yHHkR1s6jpovo1_400.gif
41 | tumblr_nk8tjfoZ5v1u7oomho1_400.gif
42 | tumblr_nf7mvh6bsr1qd80wyo1_400.gif
43 | tumblr_marc0jPgsb1qkq2eno1_400.gif
44 |
--------------------------------------------------------------------------------
/tasks/train/output.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/train/output.mp4
--------------------------------------------------------------------------------
/tasks/train/print_all_files.py:
--------------------------------------------------------------------------------
1 | data_root = '/root/paddlejob/workspace/env_run/data_afs_3/zhouhao14/intern/xiaohu/webvid/webvid'
2 |
3 | import os
4 | from tqdm import tqdm
5 | f = open('webvid_list.txt', 'w')
6 | dir_list = os.listdir(data_root)
7 | for dir in tqdm(dir_list):
8 | dir_path = os.path.join(data_root, dir)
9 | file_list = os.listdir(dir_path)
10 | for file in file_list:
11 | file_path = os.path.join(dir_path, file)
12 | f.write(dir +'/'+ file + '\n')
--------------------------------------------------------------------------------
/tasks/train/tgif_corrupt.txt:
--------------------------------------------------------------------------------
1 | tumblr_nkps43Kmm31unykvpo1_400.mp4
2 | tumblr_nhsbt0k1TK1rqx3tso1_500.mp4
3 | tumblr_nqidiuBkaX1tpg4boo1_250.mp4
4 | tumblr_nd8ikzEuRs1r4tjm5o1_400.mp4
5 | tumblr_no5qys6Zv91so4o4wo1_400.mp4
6 | tumblr_ne5vc0elB31slj978o1_400.mp4
7 | tumblr_nocmsxthAQ1tsqdy0o1_400.mp4
8 | tumblr_nqrdphoLSG1t95h1uo1_250.mp4
9 | tumblr_n9kbl6J6Tg1tgy7r4o1_400.mp4
10 | tumblr_nl0eirZXIH1tm778fo1_400.mp4
11 | tumblr_npoca1RD9h1s71nvbo1_400.mp4
12 | tumblr_nqkw6wcy7u1tg815ro1_400.mp4
13 | tumblr_nhglxePcKs1u713vko1_250.mp4
14 | tumblr_mlxk44iKda1rlryi1o1_400.mp4
15 | tumblr_nc7ou88Xda1tdmffyo1_250.mp4
16 | tumblr_npax48inl61tf42s3o1_400.mp4
17 | tumblr_n8uzopnOQA1tbgcpko1_400.mp4
18 | tumblr_njn1lj3sCB1unrob4o1_500.mp4
19 | tumblr_nozp1udlit1t95h1uo1_250.mp4
20 | tumblr_nmierrf5ZH1tnos68o1_250.mp4
21 | tumblr_nfm47yIKUE1rtequ6o1_400.mp4
22 | tumblr_navaivague1re06l8o1_400.mp4
23 | tumblr_nriiqfSIWQ1uaoehqo1_400.mp4
24 | tumblr_nfnj91Sxrc1tv4d9wo1_250.mp4
25 | tumblr_n8o5meR3HR1te77izo1_400.mp4
26 | tumblr_nq5yq0EmZB1u8gd00o1_250.mp4
27 | tumblr_nb2ol4NniN1tkpzw0o1_250.mp4
28 | tumblr_nh04cdAUje1sm9b1po2_400.mp4
29 | tumblr_nnjavclA4f1utipxro1_250.mp4
30 | tumblr_mv3ld4n2ri1sfprkzo1_400.mp4
31 | tumblr_nn91q1Yoa01uqzp8co1_500.mp4
32 | tumblr_ncl7lvFWpg1u04f66o1_400.mp4
33 | tumblr_nkwlv8Hg0U1qfq2gno1_400.mp4
34 | tumblr_no28osQSix1twfmf3o1_250.mp4
35 | tumblr_ne3958o3xv1qhrx75o1_250.mp4
36 | tumblr_nav9toudpf1re06l8o1_400.mp4
37 | tumblr_nfbuufZpod1u4068wo1_400.mp4
38 | tumblr_nonucoHIFu1tpg4boo1_250.mp4
39 | tumblr_naf1akCsYV1ts0kzio1_400.mp4
40 | tumblr_na6j8jESjb1thpigwo1_400.mp4
41 | tumblr_nkz4w2utsF1sm7eoto1_400.mp4
42 | tumblr_nigio1CXBZ1u8uroco1_400.mp4
43 | tumblr_noiq0clHCg1qzhjh2o1_400.mp4
44 | tumblr_nh74fcyMFu1slj978o1_250.mp4
45 | tumblr_nfyul0NhhS1tzs6b2o1_500.mp4
46 | tumblr_nf2oli6DJm1slw55qo1_400.mp4
47 | tumblr_npmwjgWiIA1uvie7bo1_400.mp4
48 | tumblr_nkq2vhomVm1twnkudo1_r1_400.mp4
49 | tumblr_niubadIAbL1tqviovo1_500.mp4
50 | tumblr_np3dvwowfN1up68h4o1_500.mp4
51 | tumblr_ngssnyMJqn1slj978o1_400.mp4
52 | tumblr_nfyqdyhWDn1sx7xv7o1_500.mp4
53 | tumblr_noe054KxAw1tpg4boo1_400.mp4
54 | tumblr_npj2p3o4361tx8mn0o1_400.mp4
55 | tumblr_ncv7bmm1lG1tf01j4o1_250.mp4
56 | tumblr_nlzaj75aY51s85u2fo1_500.mp4
57 | tumblr_mvz12eJxAB1rbf9bno1_500.mp4
58 | tumblr_n9oltoBCjr1t0ohh1o1_500.mp4
59 | tumblr_nm128eKJ7n1r9yho8o1_540.mp4
60 | tumblr_nejwjzrE2I1spote4o1_500.mp4
61 | tumblr_nrfz25aGMZ1ual9cno1_250.mp4
62 | tumblr_npdasiJEUc1sht3fmo1_400.mp4
63 | tumblr_njvhuklV531u2muk4o1_400.mp4
64 | tumblr_no566xiPbF1ttvor4o1_500.mp4
65 | tumblr_nd3sy96uW81tdvc4qo1_400.mp4
66 | tumblr_ne9hgkuZlD1twfpc5o1_500.mp4
67 | tumblr_m6qbi2ZSEE1qj7lb4o1_r3_500.mp4
68 | tumblr_naysx8YTzn1tzl1owo1_400.mp4
69 | tumblr_nk7hxlH63e1u2b31do1_400.mp4
70 | tumblr_nb6ejhQiUn1sl27r8o1_250.mp4
71 | tumblr_nqih1fBKzz1r0s2r6o1_250.mp4
72 | tumblr_nkb5fbXoNb1syz358o1_400.mp4
73 | tumblr_ngns8uQo831rq9gtvo1_400.mp4
74 | tumblr_nf9tyzpthC1tmddexo1_250.mp4
75 | tumblr_na7nn7XEI41rd6gi7o1_500.mp4
76 | tumblr_nknuza4Res1r3mh0to1_400.mp4
77 | tumblr_nfoxrvvztQ1tk8ub5o1_400.mp4
78 | tumblr_npzl3w7VUC1spbq2fo1_400.mp4
79 | tumblr_nlgaudRZPW1rw95g7o1_400.mp4
80 | tumblr_nnzysjVJR51r59fn4o1_400.mp4
81 | tumblr_nbv7fwPUde1shxl87o1_250.mp4
82 | tumblr_ni34ysM1ri1u3ztwyo1_250.mp4
83 | tumblr_nhq0b5Ukbk1tcof18o1_500.mp4
84 | tumblr_nqmswiqk121r8h6u4o1_500.mp4
85 | tumblr_n8z70012w81tgbvgqo1_500.mp4
86 | tumblr_nn61aaKcNA1qhqb9no1_r1_500.mp4
87 | tumblr_npjuc3Hc9r1qhrx75o1_400.mp4
88 |
--------------------------------------------------------------------------------
/tasks/train/tgif_mp4.py:
--------------------------------------------------------------------------------
1 | import os
2 | from moviepy.editor import VideoFileClip
3 | from concurrent.futures import ThreadPoolExecutor, as_completed
4 | from tqdm import tqdm
5 | from decord import VideoReader
6 | from decord import cpu, gpu
7 |
8 |
9 | # Source and target folders
10 | source_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/video_gif'
11 | target_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/videos_mp4'
12 | used_list = 'tgif_used.txt'
13 | corrupt_list = 'tgif_corrupt.txt'
14 | f_corrupt = open(corrupt_list, 'w')
15 | # Read the list of files actually used
16 | with open(used_list, 'r') as f:
17 | used_files = [line.strip() for line in f.readlines()]
18 |
19 | count = 0
20 | not_count = 0
21 | miss_list = []
22 | for file in tqdm(used_files):
23 | # file_new = file.replace('.gif', '.mp4')
24 | video_path = os.path.join(target_folder, file)
25 | if os.path.exists(video_path):
26 | try:
27 | vr = VideoReader(video_path, ctx=cpu(0))
28 | except:
29 | count += 1
30 | miss_list.append(file)
31 | print('Error processing {}'.format(file))
32 | f_corrupt.write(file + '\n')
33 | else:
34 | not_count += 1
35 | print(count, not_count)
36 | # print(os.path.join(source_folder, file))
37 | # print(count, len(used_files))
38 | # miss_list = []
39 |
40 | # def convert_gif_to_mp4(file):
41 | # if file.endswith('.gif'):
42 | # source_path = os.path.join(source_folder, file)
43 | # target_path = os.path.join(target_folder, file.replace('.gif', '.mp4'))
44 |
45 | # if os.path.exists(target_path):
46 | # return f'{file} already converted. Skipping...'
47 | # try:
48 | # clip = VideoFileClip(source_path)
49 | # clip.write_videofile(target_path, codec="libx264", fps=24)
50 | # clip.close()
51 |
52 | # return f'Saved {file.replace(".gif", ".mp4")} to {target_folder}'
53 | # except Exception as e:
54 | # print('Error processing {}'.format(file), e, sep='\n')
55 | # # raise e
56 | # return f'Error processing {file}: {e}'
57 | # else:
58 | # return f'{file} is not a GIF. Skipping...'
59 |
60 | # # Cap the number of worker threads in the pool
61 | # max_threads = 8
62 |
63 | # with ThreadPoolExecutor(max_workers=max_threads) as executor:
64 | #     # Use executor.map to process the tasks in parallel
65 | #     # Note: to keep the progress bar updating while tasks run, use executor.submit with as_completed
66 | # futures = [executor.submit(convert_gif_to_mp4, file) for file in used_files]
67 |
68 | #     # To drive the progress bar, consume completed futures via as_completed
69 | # for future in tqdm(as_completed(futures), total=len(futures)):
70 | # print(future.result())
71 |
72 | # print("All videos processed!")
--------------------------------------------------------------------------------
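The decord integrity check above runs serially; opening thousands of MP4s is I/O-bound, so the thread-pool pattern from the commented-out converter applies to the check as well. A sketch of the same check parallelized (assumes the target_folder and used_files defined above):

    import os
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from decord import VideoReader, cpu
    from tqdm import tqdm

    def check_one(file):
        video_path = os.path.join(target_folder, file)
        if not os.path.exists(video_path):
            return file, 'missing'
        try:
            VideoReader(video_path, ctx=cpu(0))
            return file, 'ok'
        except Exception:
            return file, 'corrupt'

    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(check_one, f) for f in used_files]
        results = [fut.result() for fut in tqdm(as_completed(futures), total=len(futures))]
    corrupt = [f for f, status in results if status == 'corrupt']
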
/tasks/train/vcg_process.py:
--------------------------------------------------------------------------------
1 | anno_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/conversation/videochatgpt/train_new.json'
2 | anno_new_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/conversation/videochatgpt/train_new_1.json'
3 |
4 | data_root = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/panda'
5 | import os, json
6 | from decord import VideoReader
7 | from decord import cpu, gpu
8 | from tqdm import tqdm
9 |
10 |
11 | miss_count = 0
12 | count = 0
13 | # annos = json.load(open(anno_file))
14 | # annos_new = []
15 | # for anno in tqdm(annos):
16 | # video_path = os.path.join(data_root, anno['video'])
17 | # if not os.path.exists(video_path):
18 | # continue
19 | # try:
20 | # vr = VideoReader(video_path, ctx=cpu(0))
21 | # annos_new.append(anno)
22 | # except:
23 | # count += 1
24 | # json.dump(annos_new, open(anno_new_file, 'w'))
25 | # print(count)
26 |
27 | files = os.listdir(data_root)
28 | for file in tqdm(files):
29 | video_path = os.path.join(data_root, file)
30 | try:
31 | count += 1
32 | vr = VideoReader(video_path, ctx=cpu(0))
33 | except:
34 | miss_count += 1
35 | print(video_path)
36 | print(miss_count, count)
--------------------------------------------------------------------------------
/tasks/train/vcg_read.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | # Video file path
4 | video_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/Video_ChatGPT/v_Z0eBz6QsI-c.mp4'
5 |
6 | # Open the video file
7 | cap = cv2.VideoCapture(video_path)
8 |
9 | while cap.isOpened():
10 |     # Read one frame
11 | ret, frame = cap.read()
12 |
13 |     # ret is True if the frame was read successfully
14 | if not ret:
15 | print("Can't receive frame (stream end?). Exiting ...")
16 | break
17 |
18 |     # Display the current frame
19 | # cv2.imshow('frame', frame)
20 |
21 |     # Press 'q' to quit
22 | if cv2.waitKey(1) == ord('q'):
23 | break
24 |
25 | # Release the capture object
26 | cap.release()
27 | cv2.destroyAllWindows()
--------------------------------------------------------------------------------
/tasks/train/webvid_process.py:
--------------------------------------------------------------------------------
1 | file_list = 'webvid_list.txt'
2 | anno_root_it = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons'
3 | train_list_files = [f"{anno_root_it}/video/caption/videochat/train.json", f"{anno_root_it}/video/caption/webvid/train.json", f"{anno_root_it}/video/conversation/videochat1/train.json", f"{anno_root_it}/video/vqa/webvid_qa/train.json"]
4 |
5 | import os
6 | import json
7 | from tqdm import tqdm
8 |
9 | files = open(file_list).readlines()
10 | files = set(file.strip() for file in files)  # set: O(1) membership tests in the loop below
11 | f_missed = open('missing_files.txt', 'w')
12 | for file in train_list_files:
13 | f = open(file, 'r')
14 | item_list = []
15 | data = json.load(f)
16 | for item in tqdm(data):
17 | video_id = item['video']
18 | if video_id not in files:
19 | f_missed.write(video_id + '\n')
20 | else:
21 | item_list.append(item)
22 | new_file = file.replace('train', 'train_new')
23 | with open(new_file, 'w') as f_new:
24 | json.dump(item_list, f_new)
--------------------------------------------------------------------------------
/utils/__pycache__/basic_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/basic_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/config_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/distributed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/distributed.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/easydict.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/easydict.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/logger.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/optimizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/optimizer.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/scheduler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/scheduler.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/basic_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import os
4 | import json
5 | import logging
6 | import random
7 | import time
8 | from collections import defaultdict, deque
9 | import datetime
10 | from pathlib import Path
11 | from typing import List, Union
12 |
13 | import torch
14 | import torch.distributed as dist
15 | from .distributed import is_dist_avail_and_initialized
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class SmoothedValue(object):
22 | """Track a series of values and provide access to smoothed values over a
23 | window or the global series average.
24 | """
25 |
26 | def __init__(self, window=20, fmt=None):
27 | if fmt is None:
28 | fmt = "{median:.4f} ({global_avg:.4f})"
29 | self.deque = deque(maxlen=window)
30 | self.total = 0.0
31 | self.count = 0
32 | self.fmt = fmt
33 |
34 | def update(self, value, n=1):
35 | self.deque.append(value)
36 | self.count += n
37 | self.total += value * n
38 |
39 | def synchronize_between_processes(self):
40 | """
41 | Warning: does not synchronize the deque!
42 | """
43 | if not is_dist_avail_and_initialized():
44 | return
45 | t = torch.tensor([self.count, self.total],
46 | dtype=torch.float64, device='cuda')
47 | dist.barrier()
48 | dist.all_reduce(t)
49 | t = t.tolist()
50 | self.count = int(t[0])
51 | self.total = t[1]
52 |
53 | @property
54 | def median(self):
55 | d = torch.tensor(list(self.deque))
56 | return d.median().item()
57 |
58 | @property
59 | def avg(self):
60 | d = torch.tensor(list(self.deque), dtype=torch.float32)
61 | return d.mean().item()
62 |
63 | @property
64 | def global_avg(self):
65 | return self.total / self.count
66 |
67 | @property
68 | def max(self):
69 | return max(self.deque)
70 |
71 | @property
72 | def value(self):
73 | return self.deque[-1]
74 |
75 | def __str__(self):
76 | return self.fmt.format(
77 | median=self.median,
78 | avg=self.avg,
79 | global_avg=self.global_avg,
80 | max=self.max,
81 | value=self.value)
82 |
83 |
84 | class MetricLogger(object):
85 | def __init__(self, delimiter="\t"):
86 | self.meters = defaultdict(SmoothedValue)
87 | self.delimiter = delimiter
88 |
89 | def update(self, **kwargs):
90 | for k, v in kwargs.items():
91 | if isinstance(v, torch.Tensor):
92 | v = v.item()
93 | assert isinstance(v, (float, int))
94 | self.meters[k].update(v)
95 |
96 | def __getattr__(self, attr):
97 | if attr in self.meters:
98 | return self.meters[attr]
99 | if attr in self.__dict__:
100 | return self.__dict__[attr]
101 | raise AttributeError("'{}' object has no attribute '{}'".format(
102 | type(self).__name__, attr))
103 |
104 | def __str__(self):
105 | loss_str = []
106 | for name, meter in self.meters.items():
107 |         if meter.count == 0:  # meter has received no data
108 | loss_str.append(
109 | "{}: {}".format(name, "No data")
110 | )
111 | else:
112 | loss_str.append(
113 | "{}: {}".format(name, str(meter))
114 | )
115 | return self.delimiter.join(loss_str)
116 |
117 | def global_avg(self):
118 | loss_str = []
119 | for name, meter in self.meters.items():
120 | if meter.count == 0:
121 | loss_str.append(
122 | "{}: {}".format(name, "No data")
123 | )
124 | else:
125 | loss_str.append(
126 | "{}: {:.4f}".format(name, meter.global_avg)
127 | )
128 | return self.delimiter.join(loss_str)
129 |
130 | def get_global_avg_dict(self, prefix=""):
131 | """include a separator (e.g., `/`, or "_") at the end of `prefix`"""
132 | d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
133 | return d
134 |
135 | def synchronize_between_processes(self):
136 | for meter in self.meters.values():
137 | meter.synchronize_between_processes()
138 |
139 | def add_meter(self, name, meter):
140 | self.meters[name] = meter
141 |
142 | def log_every(self, iterable, log_freq, header=None):
143 | i = 0
144 | if not header:
145 | header = ''
146 | start_time = time.time()
147 | end = time.time()
148 | iter_time = SmoothedValue(fmt='{avg:.4f}')
149 | data_time = SmoothedValue(fmt='{avg:.4f}')
150 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
151 | log_msg = [
152 | header,
153 | '[{0' + space_fmt + '}/{1}]',
154 | 'eta: {eta}',
155 | '{meters}',
156 | 'time: {time}',
157 | 'data: {data}'
158 | ]
159 | if torch.cuda.is_available():
160 | log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
161 | log_msg = self.delimiter.join(log_msg)
162 | MB = 1024.0 * 1024.0
163 | for obj in iterable:
164 | data_time.update(time.time() - end)
165 | yield obj
166 | iter_time.update(time.time() - end)
167 | if i % log_freq == 0 or i == len(iterable) - 1:
168 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
169 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
170 | if torch.cuda.is_available():
171 | logger.info(log_msg.format(
172 | i, len(iterable), eta=eta_string,
173 | meters=str(self),
174 | time=str(iter_time), data=str(data_time),
175 | memory=torch.cuda.max_memory_allocated() / MB,
176 | res_mem=torch.cuda.max_memory_reserved() / MB,
177 | ))
178 | else:
179 | logger.info(log_msg.format(
180 | i, len(iterable), eta=eta_string,
181 | meters=str(self),
182 | time=str(iter_time), data=str(data_time)))
183 | i += 1
184 | end = time.time()
185 | total_time = time.time() - start_time
186 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
187 | logger.info('{} Total time: {} ({:.4f} s / it)'.format(
188 | header, total_time_str, total_time / len(iterable)))
189 |
190 |
191 | class AttrDict(dict):
192 | def __init__(self, *args, **kwargs):
193 | super(AttrDict, self).__init__(*args, **kwargs)
194 | self.__dict__ = self
195 |
196 |
197 | def compute_acc(logits, label, reduction='mean'):
198 | ret = (torch.argmax(logits, dim=1) == label).float()
199 | if reduction == 'none':
200 | return ret.detach()
201 | elif reduction == 'mean':
202 | return ret.mean().item()
203 |
204 |
205 | def compute_n_params(model, return_str=True):
206 | tot = 0
207 | for p in model.parameters():
208 | w = 1
209 | for x in p.shape:
210 | w *= x
211 | tot += w
212 | if return_str:
213 | if tot >= 1e6:
214 | return '{:.1f}M'.format(tot / 1e6)
215 | else:
216 | return '{:.1f}K'.format(tot / 1e3)
217 | else:
218 | return tot
219 |
220 |
221 | def setup_seed(seed):
222 | torch.manual_seed(seed)
223 | np.random.seed(seed)
224 | random.seed(seed)
225 |
226 |
227 | def remove_files_if_exist(file_paths):
228 | for fp in file_paths:
229 | if os.path.isfile(fp):
230 | os.remove(fp)
231 |
232 |
233 | def save_json(data, filename, save_pretty=False, sort_keys=False):
234 | with open(filename, "w") as f:
235 | if save_pretty:
236 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
237 | else:
238 | json.dump(data, f)
239 |
240 |
241 | def load_json(filename):
242 | with open(filename, "r") as f:
243 | return json.load(f)
244 |
245 |
246 | def flat_list_of_lists(l):
247 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
248 | return [item for sublist in l for item in sublist]
249 |
250 |
251 | def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
252 | """
253 | Args:
254 | root: path to the directory to start search files
255 | suffix: any str as suffix, or can match multiple such strings
256 | when input is List[str].
257 | Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`]
258 |            Example 2, e.g., use a `*` in the `suffix`: `START*.jpg`.
259 | """
260 | if isinstance(suffix, str):
261 | suffix = [suffix, ]
262 | filepaths = flat_list_of_lists(
263 | [list(Path(root).rglob(f"*{e}")) for e in suffix])
264 | return filepaths
265 |
266 |
267 | def match_key_and_shape(state_dict1, state_dict2):
268 | keys1 = set(state_dict1.keys())
269 | keys2 = set(state_dict2.keys())
270 | print(f"keys1 - keys2: {keys1 - keys2}")
271 | print(f"keys2 - keys1: {keys2 - keys1}")
272 |
273 | mismatch = 0
274 | for k in list(keys1):
275 | if state_dict1[k].shape != state_dict2[k].shape:
276 | print(
277 | f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
278 | mismatch += 1
279 | print(f"mismatch {mismatch}")
280 |
281 |
282 | def merge_dicts(list_dicts):
283 | merged_dict = list_dicts[0].copy()
284 | for i in range(1, len(list_dicts)):
285 | merged_dict.update(list_dicts[i])
286 | return merged_dict
287 |
--------------------------------------------------------------------------------
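A minimal usage sketch for SmoothedValue/MetricLogger above (the loop body is illustrative): log_every wraps any sized iterable and times both data loading and the full iteration, while update feeds named meters that are smoothed over a window.

    import logging
    from utils.basic_utils import MetricLogger, SmoothedValue

    logging.basicConfig(level=logging.INFO)

    metric_logger = MetricLogger(delimiter='  ')
    metric_logger.add_meter('lr', SmoothedValue(window=1, fmt='{value:.6f}'))

    for step in metric_logger.log_every(range(100), log_freq=5, header='Train Epoch: [0]'):
        loss = 1.0 / (step + 1)  # placeholder for a real forward/backward pass
        metric_logger.update(loss=loss, lr=2e-5)

    print('averaged stats:', metric_logger.global_avg())
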
/utils/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import ast
5 | import json
6 | import os
7 | import os.path as osp
8 | import re
9 | import shutil
10 | import sys
11 | import tempfile
12 | from copy import deepcopy
13 | from importlib import import_module
14 |
15 | import yaml
16 |
17 | from .easydict import EasyDict
18 |
19 | __all__ = ["Config", "pretty_text"]
20 |
21 |
22 | BASE_KEY = "_base_"
23 | # BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
24 | BASE_CONFIG = {}
25 |
26 | cfg = None
27 |
28 |
29 | class Config(object):
30 | """config"""
31 |
32 | @classmethod
33 | def pretty_text(cls, cfg: dict, indent=2) -> str:
34 | """format dict to a string
35 |
36 | Args:
37 | cfg (EasyDict): the params.
38 |
39 | Returns: The string to display.
40 |
41 | """
42 | msg = "{\n"
43 | for i, (k, v) in enumerate(cfg.items()):
44 | if isinstance(v, dict):
45 | v = cls.pretty_text(v, indent + 4)
46 | spaces = " " * indent
47 | msg += spaces + "{}: {}".format(k, v)
48 | if i == len(cfg) - 1:
49 | msg += " }"
50 | else:
51 | msg += "\n"
52 | return msg
53 |
54 | @classmethod
55 | def dump(cls, cfg, savepath=None):
56 | """dump cfg to `json` file.
57 |
58 | Args:
59 | cfg (dict): The dict to dump.
60 | savepath (str): The filepath to save the dumped dict.
61 |
62 | Returns: TODO
63 |
64 | """
65 | if savepath is None:
66 | savepath = osp.join(cfg.WORKSPACE, "config.json")
67 | json.dump(cfg, open(savepath, "w"), indent=2)
68 |
69 | @classmethod
70 | def get_config(cls, default_config: dict = None):
71 | """get a `Config` instance.
72 |
73 | Args:
74 |             default_config (dict): The default config. `default_config` will be overridden
75 |                 by the config file, which is in turn overridden by commandline args.
76 |
77 | Returns: an EasyDict.
78 | """
79 | global cfg
80 | if cfg is not None:
81 | return cfg
82 |
83 | # define arg parser.
84 | parser = argparse.ArgumentParser()
85 | # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
86 | parser.add_argument(
87 | "config_file", help="the configuration file to load. support: .yaml, .json, .py"
88 | )
89 | parser.add_argument(
90 | "opts",
91 | default=None,
92 | nargs="*",
93 |         help="config overrides. List. Format: 'key1 value1 key2 value2'",
94 | )
95 | args = parser.parse_args()
96 |
97 | cfg = EasyDict(BASE_CONFIG)
98 | if osp.isfile(args.config_file):
99 | cfg_from_file = cls.from_file(args.config_file)
100 | cfg = merge_a_into_b(cfg_from_file, cfg)
101 | cfg = cls.merge_list(cfg, args.opts)
102 | cfg = eval_dict_leaf(cfg)
103 |
104 | # update some keys to make them show at the last
105 | for k in BASE_CONFIG:
106 | cfg[k] = cfg.pop(k)
107 | return cfg
108 |
109 | @classmethod
110 | def from_file(cls, filepath: str) -> EasyDict:
111 | """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
112 |
113 | Args:
114 | filepath (str): The config file path.
115 |
116 | Returns: TODO
117 |
118 | """
119 | filepath = osp.abspath(osp.expanduser(filepath))
120 | if not osp.isfile(filepath):
121 | raise IOError(f"File does not exist: {filepath}")
122 | if filepath.endswith(".py"):
123 | with tempfile.TemporaryDirectory() as temp_config_dir:
124 |
125 | shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config"))
126 | sys.path.insert(0, temp_config_dir)
127 | mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0])
128 | # mod = import_module(temp_module_name)
129 | sys.path.pop(0)
130 | cfg_dict = {
131 | name: value
132 | for name, value in mod.__dict__.items()
133 | if not name.startswith("__")
134 | }
135 | for k in list(sys.modules.keys()):
136 | if "tmp_config" in k:
137 | del sys.modules[k]
138 | elif filepath.endswith((".yml", ".yaml")):
139 | cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
140 | elif filepath.endswith(".json"):
141 | cfg_dict = json.load(open(filepath, "r"))
142 | else:
143 |         raise IOError("Only py/yml/yaml/json types are supported now!")
144 |
145 | cfg_text = filepath + "\n"
146 | with open(filepath, "r") as f:
147 | cfg_text += f.read()
148 |
149 | if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
150 | cfg_dir = osp.dirname(filepath)
151 | base_filename = cfg_dict.pop(BASE_KEY)
152 | base_filename = (
153 | base_filename if isinstance(base_filename, list) else [base_filename]
154 | )
155 |
156 | cfg_dict_list = list()
157 | for f in base_filename:
158 | _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
159 | cfg_dict_list.append(_cfg_dict)
160 |
161 | base_cfg_dict = dict()
162 | for c in cfg_dict_list:
163 | if len(base_cfg_dict.keys() & c.keys()) > 0:
164 | raise KeyError("Duplicate key is not allowed among bases")
165 | base_cfg_dict.update(c)
166 |
167 | cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
168 |
169 | return EasyDict(cfg_dict)
170 |
171 | @classmethod
172 | def merge_list(cls, cfg, opts: list):
173 | """merge commandline opts.
174 |
175 | Args:
176 | cfg: (dict): The config to be merged.
177 | opts (list): The list to merge. Format: [key1, name1, key2, name2,...].
178 | The keys can be nested. For example, ["a.b", v] will be considered
179 | as `dict(a=dict(b=v))`.
180 |
181 | Returns: dict.
182 |
183 | """
184 | assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
185 | for i in range(0, len(opts), 2):
186 | full_k, v = opts[i], opts[i + 1]
187 | keys = full_k.split(".")
188 | sub_d = cfg
189 | for i, k in enumerate(keys):
190 | if not hasattr(sub_d, k):
191 |                 raise ValueError(f"The key {k} does not exist in the config. Full key: {full_k}")
192 | if i != len(keys) - 1:
193 | sub_d = sub_d[k]
194 | else:
195 | sub_d[k] = v
196 | return cfg
197 |
198 |
199 | def merge_a_into_b(a, b, inplace=False):
200 | """The values in a will override values in b.
201 |
202 | Args:
203 | a (dict): source dict.
204 | b (dict): target dict.
205 |
206 | Returns: dict. recursively merge dict a into dict b.
207 |
208 | """
209 | if not inplace:
210 | b = deepcopy(b)
211 | for key in a:
212 | if key in b:
213 | if isinstance(a[key], dict) and isinstance(b[key], dict):
214 | b[key] = merge_a_into_b(a[key], b[key], inplace=True)
215 | else:
216 | b[key] = a[key]
217 | else:
218 | b[key] = a[key]
219 | return b
220 |
221 |
222 | def eval_dict_leaf(d, orig_dict=None):
223 | """eval values of dict leaf.
224 |
225 | Args:
226 | d (dict): The dict to eval.
227 |
228 | Returns: dict.
229 |
230 | """
231 | if orig_dict is None:
232 | orig_dict = d
233 | for k, v in d.items():
234 | if not isinstance(v, dict):
235 | d[k] = eval_string(v, orig_dict)
236 | else:
237 | eval_dict_leaf(v, orig_dict)
238 | return d
239 |
240 |
241 | def eval_string(string, d):
242 | """automatically evaluate string to corresponding types.
243 |
244 | For example:
245 | not a string -> return the original input
246 | '0' -> 0
247 | '0.2' -> 0.2
248 | '[0, 1, 2]' -> [0,1,2]
249 | 'eval(1+2)' -> 3
250 | 'eval(range(5))' -> [0,1,2,3,4]
251 | '${a}' -> d.a
252 |
253 |
254 |
255 | Args:
256 | string (str): The value to evaluate.
257 |         d (dict): The dict used to resolve `${...}` references in `string`.
258 |
259 | Returns: the corresponding type
260 |
261 | """
262 | if not isinstance(string, str):
263 | return string
264 | # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
265 | # return eval(string)
266 | if string[0:5] == "eval(":
267 | return eval(string[5:-1])
268 |
269 | s0 = string
270 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
271 | if s1 != s0:
272 | while s1 != s0:
273 | s0 = s1
274 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
275 | return eval(s1)
276 |
277 | try:
278 | v = ast.literal_eval(string)
279 | except:
280 | v = string
281 | return v
282 |
--------------------------------------------------------------------------------
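eval_string above is what makes lines like train_file = "${available_corpus[${train_corpus}]}" in the training configs work: the ${...} substitution is re-applied until it reaches a fixed point, and the result is eval-ed with the config itself bound to d. A worked trace with illustrative values:

    from utils.easydict import EasyDict
    from utils.config import eval_string

    d = EasyDict(dict(
        train_corpus='videochat2_instruction_full',
        available_corpus=dict(videochat2_instruction_full=['anno.json', 'videos/']),
    ))
    s = '${available_corpus[${train_corpus}]}'
    # pass 1: d.available_corpus[${train_corpus}]
    # pass 2: d.available_corpus[d.train_corpus]
    # eval   -> d.available_corpus['videochat2_instruction_full']
    print(eval_string(s, d))  # ['anno.json', 'videos/']
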
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from os.path import dirname, join
5 |
6 | from utils.config import Config
7 | from utils.distributed import init_distributed_mode, is_main_process
8 | from utils.logger import setup_logger
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_config():
14 |     """Combine the config file and command line overrides via utils.config.Config.
15 | Also converts types, e.g., `'None'` (str) --> `None` (None)
16 | """
17 | config = Config.get_config()
18 | if config.debug:
19 | config.wandb.enable = False
20 | return config
21 |
22 |
23 | def setup_evaluate_config(config):
24 | """setup evaluation default settings, e.g., disable wandb"""
25 | assert config.evaluate
26 | config.wandb.enable = False
27 | if config.output_dir is None:
28 | config.output_dir = join(dirname(config.pretrained_path), "eval")
29 | return config
30 |
31 |
32 | def setup_output_dir(output_dir, excludes=["code"]):
33 |     """ensure we are not overwriting an existing/non-empty output dir"""
34 | if not os.path.exists(output_dir):
35 | os.makedirs(output_dir, exist_ok=False)
36 | else:
37 | existing_dirs_files = os.listdir(output_dir) # list
38 | remaining = set(existing_dirs_files) - set(excludes)
39 | remaining = [e for e in remaining if "slurm" not in e]
40 | remaining = [e for e in remaining if ".out" not in e]
41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
42 |         logger.warning(f"remaining dirs or files: {remaining}")
43 |
44 |
45 | def setup_main():
46 | """
47 | Setup config, logger, output_dir, etc.
48 | Shared for pretrain and all downstream tasks.
49 | """
50 | config = setup_config()
51 | if hasattr(config, "evaluate") and config.evaluate:
52 | config = setup_evaluate_config(config)
53 | init_distributed_mode(config)
54 |
55 | if is_main_process():
56 | setup_output_dir(config.output_dir, excludes=["code"])
57 | setup_logger(output=config.output_dir, color=True, name="vindlu")
58 | logger.info(f"config: {Config.pretty_text(config)}")
59 | Config.dump(config, os.path.join(config.output_dir, "config.json"))
60 | return config
61 |
--------------------------------------------------------------------------------
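setup_config above ends up in Config.get_config, which reads the positional config file and then applies dotted-key overrides from the remaining command line arguments (merge_list in utils/config.py), with eval_dict_leaf converting the string values afterwards. A self-contained sketch of that override path (values illustrative):

    from utils.easydict import EasyDict
    from utils.config import Config, eval_dict_leaf

    cfg = EasyDict(dict(optimizer=dict(lr=2e-5), scheduler=dict(epochs=2)))
    cfg = Config.merge_list(cfg, ['optimizer.lr', '1e-5', 'scheduler.epochs', '3'])
    cfg = eval_dict_leaf(cfg)  # '1e-5' -> 1e-05, '3' -> 3
    print(cfg.optimizer.lr, cfg.scheduler.epochs)  # 1e-05 3
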
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import logging
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def setup_for_distributed(is_master):
11 | import warnings
12 |
13 | builtin_warn = warnings.warn
14 |
15 | def warn(*args, **kwargs):
16 | force = kwargs.pop("force", False)
17 | if is_master or force:
18 | builtin_warn(*args, **kwargs)
19 |
20 | # Log warnings only once
21 | warnings.warn = warn
22 | warnings.simplefilter("once", UserWarning)
23 |
24 | if not is_master:
25 | logging.disable()
26 |
27 |
28 | def is_dist_avail_and_initialized():
29 | if not dist.is_available():
30 | return False
31 | if not dist.is_initialized():
32 | return False
33 | return True
34 |
35 |
36 | def get_world_size():
37 | if not is_dist_avail_and_initialized():
38 | return 1
39 | return dist.get_world_size()
40 |
41 |
42 | def get_rank():
43 | if not is_dist_avail_and_initialized():
44 | return 0
45 | return dist.get_rank()
46 |
47 |
48 | def is_main_process():
49 | return get_rank() == 0
50 |
51 |
52 | def save_on_master(*args, **kwargs):
53 | if is_main_process():
54 | torch.save(*args, **kwargs)
55 |
56 |
57 | def is_port_in_use(port):
58 | import socket
59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
60 | return s.connect_ex(('localhost', port)) == 0
61 |
62 |
63 | def init_distributed_mode(args):
64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
65 | # job started by torch.distributed.launch
66 | args.rank = int(os.environ["RANK"])
67 | args.world_size = int(os.environ['WORLD_SIZE'])
68 | args.gpu = int(os.environ['LOCAL_RANK'])
69 | elif 'SLURM_PROCID' in os.environ:
70 | # local rank on the current node / global rank
71 | local_rank = int(os.environ['SLURM_LOCALID'])
72 | global_rank = int(os.environ['SLURM_PROCID'])
73 | # number of processes / GPUs per node
74 | world_size = int(os.environ["SLURM_NNODES"]) * \
75 |             int(os.environ["SLURM_TASKS_PER_NODE"][0])  # first char only, e.g. "8(x2)" -> 8; breaks for >9 tasks per node
76 |
77 | print(world_size)
78 |
79 | args.rank = global_rank
80 | args.gpu = local_rank
81 | args.world_size = world_size
82 | else:
83 | logger.info('Not using distributed mode')
84 | args.distributed = False
85 | return
86 |
87 | args.distributed = True
88 |
89 | torch.cuda.set_device(args.gpu)
90 | args.dist_backend = 'nccl'
91 |
92 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node
93 | dist_port = int(args.dist_url.split(":")[-1])
94 | while is_port_in_use(dist_port):
95 | dist_port += 10
96 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
97 | print(args.dist_url)
98 |
99 | logger.info('| distributed init (rank {}): {}'.format(
100 | args.rank, args.dist_url))
101 | if "SLURM_JOB_ID" in os.environ:
102 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
103 | torch.distributed.init_process_group(
104 | backend=args.dist_backend, init_method=args.dist_url,
105 | world_size=args.world_size, rank=args.rank)
106 | torch.distributed.barrier()
107 | setup_for_distributed(args.rank == 0)
108 |
109 |
110 | # Copyright (c) Facebook, Inc. and its affiliates.
111 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
112 | class GatherLayer(torch.autograd.Function):
113 | """
114 | Gather tensors from all workers with support for backward propagation:
115 | This implementation does not cut the gradients as torch.distributed.all_gather does.
116 | """
117 |
118 | @staticmethod
119 | def forward(ctx, x):
120 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
121 | dist.all_gather(output, x)
122 | return tuple(output)
123 |
124 | @staticmethod
125 | def backward(ctx, *grads):
126 | all_gradients = torch.stack(grads)
127 | dist.all_reduce(all_gradients)
128 | return all_gradients[dist.get_rank()]
129 |
130 |
131 | # copied from megavlt
132 | def gather_tensor_along_batch_with_backward(tensor, dim=0):
133 | world_size = get_world_size()
134 |
135 | if world_size < 2:
136 | return tensor
137 |
138 | tensor_list = GatherLayer.apply(tensor)
139 | tensor_list = torch.cat(tensor_list, dim=dim)
140 | return tensor_list
141 |
142 |
143 | @torch.no_grad()
144 | def gather_tensor_along_batch(tensor, dim=0):
145 | """
146 | Performs all_gather operation on the provided tensors.
147 | *** Warning ***: torch.distributed.all_gather has no gradient.
148 | """
149 | world_size = get_world_size()
150 |
151 | if world_size < 2:
152 | return tensor
153 |
154 | with torch.no_grad():
155 | tensor_list = []
156 |
157 | for _ in range(world_size):
158 | tensor_list.append(torch.zeros_like(tensor))
159 |
160 | dist.all_gather(tensor_list, tensor)
161 | tensor_list = torch.cat(tensor_list, dim=dim)
162 | return tensor_list
163 |
--------------------------------------------------------------------------------
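GatherLayer above exists so that losses comparing samples across the whole global batch can gather embeddings from every rank without losing gradients, which plain dist.all_gather would. A hedged sketch of typical use with a CLIP-style contrastive loss (the loss choice is illustrative, not from this repo):

    import torch
    import torch.nn.functional as F
    from utils.distributed import gather_tensor_along_batch_with_backward

    def contrastive_loss(image_feats, text_feats, temperature=0.07):
        # each rank holds (B, D); gathering yields (B * world_size, D) with grads intact
        image_all = gather_tensor_along_batch_with_backward(image_feats, dim=0)
        text_all = gather_tensor_along_batch_with_backward(text_feats, dim=0)
        logits = F.normalize(image_all, dim=-1) @ F.normalize(text_all, dim=-1).t()
        logits = logits / temperature
        targets = torch.arange(logits.size(0), device=logits.device)  # positives on the diagonal
        return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
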
/utils/easydict.py:
--------------------------------------------------------------------------------
1 | class EasyDict(dict):
2 | """
3 | Get attributes
4 |
5 | >>> d = EasyDict({'foo':3})
6 | >>> d['foo']
7 | 3
8 | >>> d.foo
9 | 3
10 | >>> d.bar
11 | Traceback (most recent call last):
12 | ...
13 | AttributeError: 'EasyDict' object has no attribute 'bar'
14 |
15 | Works recursively
16 |
17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
18 | >>> isinstance(d.bar, dict)
19 | True
20 | >>> d.bar.x
21 | 1
22 |
23 | Bullet-proof
24 |
25 | >>> EasyDict({})
26 | {}
27 | >>> EasyDict(d={})
28 | {}
29 | >>> EasyDict(None)
30 | {}
31 | >>> d = {'a': 1}
32 | >>> EasyDict(**d)
33 | {'a': 1}
34 |
35 | Set attributes
36 |
37 | >>> d = EasyDict()
38 | >>> d.foo = 3
39 | >>> d.foo
40 | 3
41 | >>> d.bar = {'prop': 'value'}
42 | >>> d.bar.prop
43 | 'value'
44 | >>> d
45 | {'foo': 3, 'bar': {'prop': 'value'}}
46 | >>> d.bar.prop = 'newer'
47 | >>> d.bar.prop
48 | 'newer'
49 |
50 |
51 | Values extraction
52 |
53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
54 | >>> isinstance(d.bar, list)
55 | True
56 | >>> from operator import attrgetter
57 | >>> map(attrgetter('x'), d.bar)
58 | [1, 3]
59 | >>> map(attrgetter('y'), d.bar)
60 | [2, 4]
61 | >>> d = EasyDict()
62 | >>> d.keys()
63 | []
64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
65 | >>> d.foo
66 | 3
67 | >>> d.bar.x
68 | 1
69 |
70 | Still like a dict though
71 |
72 | >>> o = EasyDict({'clean':True})
73 | >>> o.items()
74 | [('clean', True)]
75 |
76 | And like a class
77 |
78 | >>> class Flower(EasyDict):
79 | ... power = 1
80 | ...
81 | >>> f = Flower()
82 | >>> f.power
83 | 1
84 | >>> f = Flower({'height': 12})
85 | >>> f.height
86 | 12
87 | >>> f['power']
88 | 1
89 | >>> sorted(f.keys())
90 | ['height', 'power']
91 |
92 | update and pop items
93 | >>> d = EasyDict(a=1, b='2')
94 | >>> e = EasyDict(c=3.0, a=9.0)
95 | >>> d.update(e)
96 | >>> d.c
97 | 3.0
98 | >>> d['c']
99 | 3.0
100 | >>> d.get('c')
101 | 3.0
102 | >>> d.update(a=4, b=4)
103 | >>> d.b
104 | 4
105 | >>> d.pop('a')
106 | 4
107 | >>> d.a
108 | Traceback (most recent call last):
109 | ...
110 | AttributeError: 'EasyDict' object has no attribute 'a'
111 | """
112 |
113 | def __init__(self, d=None, **kwargs):
114 | if d is None:
115 | d = {}
116 | if kwargs:
117 | d.update(**kwargs)
118 | for k, v in d.items():
119 | setattr(self, k, v)
120 | # Class attributes
121 | for k in self.__class__.__dict__.keys():
122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
123 | setattr(self, k, getattr(self, k))
124 |
125 | def __setattr__(self, name, value):
126 | if isinstance(value, (list, tuple)):
127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
128 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
129 | value = self.__class__(value)
130 | super(EasyDict, self).__setattr__(name, value)
131 | super(EasyDict, self).__setitem__(name, value)
132 |
133 | __setitem__ = __setattr__
134 |
135 | def update(self, e=None, **f):
136 | d = e or dict()
137 | d.update(f)
138 | for k in d:
139 | setattr(self, k, d[k])
140 |
141 | def pop(self, k, d=None):
142 | if hasattr(self, k):
143 | delattr(self, k)
144 | return super(EasyDict, self).pop(k, d)
145 |
146 |
147 | if __name__ == "__main__":
148 | import doctest
149 |
150 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import functools
5 | import logging
6 | import os
7 | import sys
8 | import time
9 | import wandb
10 | from typing import Any, Dict, Union
11 |
12 | import torch
13 | from .distributed import get_rank, is_main_process
14 | from termcolor import colored
15 |
16 |
17 | def log_dict_to_wandb(log_dict, step, prefix=""):
18 | """include a separator `/` at the end of `prefix`"""
19 | if not is_main_process():
20 | return
21 |
22 | log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()}
23 | wandb.log(log_dict, step)
24 |
25 |
26 | def setup_wandb(config):
27 | if not (config.wandb.enable and is_main_process()):
28 | return
29 |
30 | run = wandb.init(
31 | config=config,
32 | project=config.wandb.project,
33 | entity=config.wandb.entity,
34 | name=os.path.basename(config.output_dir),
35 | reinit=True
36 | )
37 | return run
38 |
39 |
40 | def setup_output_folder(save_dir: str, folder_only: bool = False):
41 | """Sets up and returns the output file where the logs will be placed
42 |     based on the configuration passed. Usually "save_dir/logs/train_<timestamp>.log".
43 | If env.log_dir is passed, logs will be directly saved in this folder.
44 | Args:
45 | folder_only (bool, optional): If folder should be returned and not the file.
46 | Defaults to False.
47 | Returns:
48 | str: folder or file path depending on folder_only flag
49 | """
50 | log_filename = "train_"
51 | log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
52 | log_filename += ".log"
53 |
54 | log_folder = os.path.join(save_dir, "logs")
55 |
56 | if not os.path.exists(log_folder):
57 |         os.makedirs(log_folder, exist_ok=True)  # os.path has no mkdirs; os.makedirs is the correct call
58 |
59 | if folder_only:
60 | return log_folder
61 |
62 | log_filename = os.path.join(log_folder, log_filename)
63 |
64 | return log_filename
65 |
66 |
67 | def setup_logger(
68 | output: str = None,
69 | color: bool = True,
70 | name: str = "mmf",
71 | disable: bool = False,
72 | clear_handlers=True,
73 | *args,
74 | **kwargs,
75 | ):
76 | """
77 | Initialize the MMF logger and set its verbosity level to "INFO".
78 |     Outside libraries shouldn't call this in case they have set their
79 | own logging handlers and setup. If they do, and don't want to
80 | clear handlers, pass clear_handlers options.
81 | The initial version of this function was taken from D2 and adapted
82 | for MMF.
83 | Args:
84 | output (str): a file name or a directory to save log.
85 | If ends with ".txt" or ".log", assumed to be a file name.
86 | Default: Saved to file
87 | color (bool): If false, won't log colored logs. Default: true
88 | name (str): the root module name of this logger. Defaults to "mmf".
89 | disable: do not use
90 | clear_handlers (bool): If false, won't clear existing handlers.
91 | Returns:
92 | logging.Logger: a logger
93 | """
94 | if disable:
95 | return None
96 | logger = logging.getLogger(name)
97 | logger.propagate = False
98 |
99 | logging.captureWarnings(True)
100 | warnings_logger = logging.getLogger("py.warnings")
101 |
102 | plain_formatter = logging.Formatter(
103 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
104 | datefmt="%Y-%m-%dT%H:%M:%S",
105 | )
106 |
107 | distributed_rank = get_rank()
108 | handlers = []
109 |
110 | logging_level = logging.INFO
111 | # logging_level = logging.DEBUG
112 |
113 | if distributed_rank == 0:
114 | logger.setLevel(logging_level)
115 | ch = logging.StreamHandler(stream=sys.stdout)
116 | ch.setLevel(logging_level)
117 | if color:
118 | formatter = ColorfulFormatter(
119 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
120 | datefmt="%Y-%m-%dT%H:%M:%S",
121 | )
122 | else:
123 | formatter = plain_formatter
124 | ch.setFormatter(formatter)
125 | logger.addHandler(ch)
126 | warnings_logger.addHandler(ch)
127 | handlers.append(ch)
128 |
129 | # file logging: all workers
130 | if output is None:
131 |         output = setup_output_folder(os.getcwd())  # save_dir is required; assume the CWD
132 |
133 | if output is not None:
134 | if output.endswith(".txt") or output.endswith(".log"):
135 | filename = output
136 | else:
137 | filename = os.path.join(output, "train.log")
138 | if distributed_rank > 0:
139 | filename = filename + f".rank{distributed_rank}"
140 | os.makedirs(os.path.dirname(filename), exist_ok=True)
141 |
142 | fh = logging.StreamHandler(_cached_log_stream(filename))
143 | fh.setLevel(logging_level)
144 | fh.setFormatter(plain_formatter)
145 | logger.addHandler(fh)
146 | warnings_logger.addHandler(fh)
147 | handlers.append(fh)
148 |
149 | # Slurm/FB output, only log the main process
150 |     # also mirror the main-process logs into a canonical train.log file
151 |     if "train.log" not in filename and distributed_rank == 0:
152 |         filename = os.path.join(os.path.dirname(filename), "train.log")
153 | sh = logging.StreamHandler(_cached_log_stream(filename))
154 | sh.setLevel(logging_level)
155 | sh.setFormatter(plain_formatter)
156 | logger.addHandler(sh)
157 | warnings_logger.addHandler(sh)
158 | handlers.append(sh)
159 |
160 | logger.info(f"Logging to: {filename}")
161 |
162 | # Remove existing handlers to add MMF specific handlers
163 | if clear_handlers:
164 | for handler in logging.root.handlers[:]:
165 | logging.root.removeHandler(handler)
166 | # Now, add our handlers.
167 | logging.basicConfig(level=logging_level, handlers=handlers)
168 |
169 | return logger
170 |
171 |
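# Usage sketch (illustrative, not part of the original file). The final
# logging.basicConfig(...) call above attaches the same handlers to the
# root logger, so loggers created elsewhere via logging.getLogger(__name__)
# write to the same sinks:
#
#     logger = setup_logger(output="exps/run1/logs", name="mmf")
#     logger.info("training started")
#     # rank 0: colored stdout + exps/run1/logs/train.log
#     # rank k > 0: exps/run1/logs/train.log.rank{k} only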
172 | def setup_very_basic_config(color=True):
173 | plain_formatter = logging.Formatter(
174 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
175 | datefmt="%Y-%m-%dT%H:%M:%S",
176 | )
177 | ch = logging.StreamHandler(stream=sys.stdout)
178 | ch.setLevel(logging.INFO)
179 | if color:
180 | formatter = ColorfulFormatter(
181 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
182 | datefmt="%Y-%m-%dT%H:%M:%S",
183 | )
184 | else:
185 | formatter = plain_formatter
186 | ch.setFormatter(formatter)
187 | # Setup a minimal configuration for logging in case something tries to
188 | # log a message even before logging is setup by MMF.
189 | logging.basicConfig(level=logging.INFO, handlers=[ch])
190 |
191 |
192 | # cache the opened file object, so that different calls to `setup_logger`
193 | # with the same file name can safely write to the same file.
194 | @functools.lru_cache(maxsize=None)
195 | def _cached_log_stream(filename):
196 | return open(filename, "a")
197 |
198 |
199 | # ColorfulFormatter is adopted from Detectron2 and adapted for MMF
200 | class ColorfulFormatter(logging.Formatter):
201 | def __init__(self, *args, **kwargs):
202 | super().__init__(*args, **kwargs)
203 |
204 | def formatMessage(self, record):
205 | log = super().formatMessage(record)
206 | if record.levelno == logging.WARNING:
207 | prefix = colored("WARNING", "red", attrs=["blink"])
208 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
209 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
210 | else:
211 | return log
212 | return prefix + " " + log
213 |
214 |
215 | class TensorboardLogger:
216 | def __init__(self, log_folder="./logs", iteration=0):
217 |         # importing here surfaces a clear error if tensorboard is not installed
218 | from torch.utils.tensorboard import SummaryWriter
219 |
220 | self.summary_writer = None
221 | self._is_master = is_main_process()
222 | # self.timer = Timer()
223 | self.log_folder = log_folder
224 |
225 | if self._is_master:
226 |             # timestamp the folder name so repeated runs don't overwrite logs
227 |             current_time = time.strftime("%Y-%m-%dT%H:%M:%S")
228 |
229 | tensorboard_folder = os.path.join(
230 | self.log_folder, f"tensorboard_{current_time}"
231 | )
232 | self.summary_writer = SummaryWriter(tensorboard_folder)
233 |
234 | def __del__(self):
235 | if getattr(self, "summary_writer", None) is not None:
236 | self.summary_writer.close()
237 |
238 | def _should_log_tensorboard(self):
239 | if self.summary_writer is None or not self._is_master:
240 | return False
241 | else:
242 | return True
243 |
244 | def add_scalar(self, key, value, iteration):
245 | if not self._should_log_tensorboard():
246 | return
247 |
248 | self.summary_writer.add_scalar(key, value, iteration)
249 |
250 | def add_scalars(self, scalar_dict, iteration):
251 | if not self._should_log_tensorboard():
252 | return
253 |
254 | for key, val in scalar_dict.items():
255 | self.summary_writer.add_scalar(key, val, iteration)
256 |
257 | def add_histogram_for_model(self, model, iteration):
258 | if not self._should_log_tensorboard():
259 | return
260 |
261 | for name, param in model.named_parameters():
262 | np_param = param.clone().cpu().data.numpy()
263 | self.summary_writer.add_histogram(name, np_param, iteration)
264 |
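# Usage sketch (illustrative, not part of the original file):
#
#     tb = TensorboardLogger(log_folder="exps/run1/logs")
#     tb.add_scalar("train/loss", 0.5, iteration=100)
#     tb.add_scalars({"train/loss": 0.5, "train/acc": 0.8}, iteration=100)
#     # non-main processes never build a SummaryWriter, so add_* calls no-op;
#     # inspect with: tensorboard --logdir exps/run1/logs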
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | """ Optimizer Factory w/ Custom Weight Decay
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import re
5 | import torch
6 | from torch import optim as optim
7 | from utils.distributed import is_main_process
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | try:
11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
12 | has_apex = True
13 | except ImportError:
14 | has_apex = False
15 |
16 |
17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
18 | named_param_tuples = []
19 | for name, param in model.named_parameters():
20 | if not param.requires_grad:
21 | continue # frozen weights
22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
23 | named_param_tuples.append([name, param, 0])
24 | elif name in no_decay_list:
25 | named_param_tuples.append([name, param, 0])
26 | else:
27 | named_param_tuples.append([name, param, weight_decay])
28 | return named_param_tuples
29 |
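# Sketch of the output (illustrative, not part of the original file): with
# filter_bias_and_bn=True, biases and 1-D parameters (e.g. LayerNorm/BN
# weights) get weight decay 0, everything else gets `weight_decay`:
#
#     import torch.nn as nn
#     model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
#     add_weight_decay(model, weight_decay=0.01)
#     # -> [["0.weight", ..., 0.01], ["0.bias", ..., 0],
#     #     ["1.weight", ..., 0], ["1.bias", ..., 0]]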
30 |
31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
32 |     """Use lr=diff_lr for parameters whose names match a regex pattern in
33 |     diff_lr_names (tested with re.search); otherwise use lr=default_lr.
34 |
35 |     Args:
36 |         named_param_tuples_or_model: List([name, param, weight_decay])
37 |         diff_lr_names: List(str), regex patterns matched against parameter names
38 |         diff_lr: float
39 |         default_lr: float
40 |     Returns:
41 |         named_param_tuples_with_lr: List([name, param, weight_decay, lr])
42 |     """
43 | named_param_tuples_with_lr = []
44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
45 | for name, p, wd in named_param_tuples_or_model:
46 | use_diff_lr = False
47 | for diff_name in diff_lr_names:
48 | # if diff_name in name:
49 | if re.search(diff_name, name) is not None:
50 | logger.info(f"param {name} use different_lr: {diff_lr}")
51 | use_diff_lr = True
52 | break
53 |
54 | named_param_tuples_with_lr.append(
55 | [name, p, wd, diff_lr if use_diff_lr else default_lr]
56 | )
57 |
58 | if is_main_process():
59 |         for name, _, wd, lr in named_param_tuples_with_lr:
60 |             logger.info(f"param {name}: wd: {wd}, lr: {lr}")
61 |
62 | return named_param_tuples_with_lr
63 |
64 |
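# For example (illustrative, not part of the original file), to train a
# module named text_encoder at a lower lr than the rest of the model:
#
#     tuples = add_weight_decay(model, weight_decay=0.01)
#     tuples = add_different_lr(
#         tuples, diff_lr_names=[r"text_encoder\."], diff_lr=1e-5, default_lr=1e-4)
#     # parameters whose names match the regex get lr=1e-5, all others lr=1e-4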
65 | def create_optimizer_params_group(named_param_tuples_with_lr):
66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
67 | group = {}
68 | for name, p, wd, lr in named_param_tuples_with_lr:
69 | if wd not in group:
70 | group[wd] = {}
71 | if lr not in group[wd]:
72 | group[wd][lr] = []
73 | group[wd][lr].append(p)
74 |
75 | optimizer_params_group = []
76 | for wd, lr_groups in group.items():
77 | for lr, p in lr_groups.items():
78 | optimizer_params_group.append(dict(
79 | params=p,
80 | weight_decay=wd,
81 | lr=lr
82 | ))
83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
84 | return optimizer_params_group
85 |
86 |
87 | def create_optimizer(args, model, filter_bias_and_bn=True):
88 | opt_lower = args.opt.lower()
89 | weight_decay = args.weight_decay
90 | # check for modules that requires different lr
91 | if hasattr(args, "different_lr") and args.different_lr.enable:
92 | diff_lr_module_names = args.different_lr.module_names
93 | diff_lr = args.different_lr.lr
94 | else:
95 | diff_lr_module_names = []
96 | diff_lr = None
97 |
98 |     no_decay = set()
99 | if hasattr(model, 'no_weight_decay'):
100 | no_decay = model.no_weight_decay()
101 | named_param_tuples = add_weight_decay(
102 | model, weight_decay, no_decay, filter_bias_and_bn)
103 | named_param_tuples = add_different_lr(
104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
105 | parameters = create_optimizer_params_group(named_param_tuples)
106 |
107 | if 'fused' in opt_lower:
108 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
109 |
110 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
111 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
112 | opt_args['eps'] = args.opt_eps
113 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
114 | opt_args['betas'] = args.opt_betas
115 | if hasattr(args, 'opt_args') and args.opt_args is not None:
116 | opt_args.update(args.opt_args)
117 |
118 | opt_split = opt_lower.split('_')
119 | opt_lower = opt_split[-1]
120 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
121 | opt_args.pop('eps', None)
122 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
123 | elif opt_lower == 'momentum':
124 | opt_args.pop('eps', None)
125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
126 | elif opt_lower == 'adam':
127 | optimizer = optim.Adam(parameters, **opt_args)
128 | elif opt_lower == 'adamw':
129 | optimizer = optim.AdamW(parameters, **opt_args)
130 |     else:
131 |         # unknown optimizer name; fail with an informative error
132 |         raise ValueError(f"Invalid optimizer: {opt_lower}")
133 | return optimizer
134 |
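# End-to-end sketch (illustrative, not part of the original file); `args`
# is any attribute-style config providing the fields read above:
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(
#         opt="adamw", lr=1e-4, weight_decay=0.02,
#         opt_eps=1e-8, opt_betas=(0.9, 0.98),
#         different_lr=SimpleNamespace(enable=False, module_names=[], lr=None),
#     )
#     optimizer = create_optimizer(args, model)
#     # yields one param group per distinct (weight_decay, lr) pair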
--------------------------------------------------------------------------------
/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from torch.optim import Optimizer
5 | import math
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 |
9 | def create_scheduler(args, optimizer):
10 | lr_scheduler = None
11 | if args.sched == 'cosine':
12 | lr_scheduler = get_cosine_schedule_with_warmup(
13 | optimizer,
14 | num_warmup_steps=args.num_warmup_steps,
15 | num_training_steps=args.num_training_steps,
16 | num_cycles=0.5,
17 | min_lr_multi=args.min_lr_multi
18 | )
19 | return lr_scheduler
20 |
21 |
22 | def get_cosine_schedule_with_warmup(
23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
25 | ):
26 | """
27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
28 |
29 |     Create a schedule with a learning rate that decreases following the values of the cosine function from the
30 |     initial lr set in the optimizer down to 0, after a warmup period during which it increases linearly from 0 to
31 |     the initial lr set in the optimizer.
32 | Args:
33 | optimizer ([`~torch.optim.Optimizer`]):
34 | The optimizer for which to schedule the learning rate.
35 | num_warmup_steps (`int`):
36 | The number of steps for the warmup phase.
37 | num_training_steps (`int`):
38 | The total number of training steps.
39 | num_cycles (`float`, *optional*, defaults to 0.5):
40 |             The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
41 | following a half-cosine).
42 | min_lr_multi (`float`, *optional*, defaults to 0):
43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
44 | last_epoch (`int`, *optional*, defaults to -1):
45 | The index of the last epoch when resuming training.
46 | Return:
47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
48 | """
49 |
50 | def lr_lambda(current_step):
51 | if current_step < num_warmup_steps:
52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
55 |
56 | return LambdaLR(optimizer, lr_lambda, last_epoch)
57 |
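# Usage sketch (illustrative, not part of the original file): the lr
# multiplier rises linearly to 1 over num_warmup_steps, then follows a
# half-cosine down to min_lr_multi; step() is called once per update:
#
#     from types import SimpleNamespace
#     sched_args = SimpleNamespace(
#         sched="cosine", num_warmup_steps=100,
#         num_training_steps=1000, min_lr_multi=0.01)
#     scheduler = create_scheduler(sched_args, optimizer)
#     for step in range(1000):
#         ...  # forward/backward + optimizer.step()
#         scheduler.step()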
--------------------------------------------------------------------------------