├── open_flamingo-main ├── open_flamingo │ ├── src │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── factory.py │ │ ├── flamingo_lm.py │ │ ├── helpers.py │ │ └── flamingo.py │ ├── eval │ │ ├── __init__.py │ │ ├── coco_metric.py │ │ ├── rices.py │ │ ├── eval_model.py │ │ ├── utils.py │ │ ├── models │ │ │ ├── blip.py │ │ │ └── open_flamingo.py │ │ ├── README.md │ │ ├── eval_datasets.py │ │ └── ok_vqa_utils.py │ ├── train │ │ ├── __init__.py │ │ ├── custom_files │ │ │ ├── custom_utils.py │ │ │ ├── custom_factory.py │ │ │ ├── custom_flamingo_lm.py │ │ │ ├── custom_helpers.py │ │ │ └── custom_flamingo.py │ │ ├── distributed.py │ │ ├── data_utils.py │ │ ├── data.py │ │ └── train_utils.py │ ├── __init__.py │ └── scripts │ │ ├── run_train.sh │ │ ├── convert_mmc4_to_wds.py │ │ ├── run_eval.sh │ │ ├── fill_vqa_testdev_results.py │ │ └── cache_rices_features.py └── LICENSE ├── mdic └── mdic.pkl ├── LICENSE ├── README.md └── environment.yml /open_flamingo-main/open_flamingo/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mdic/mdic.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hee-suk-yoon/BI-MDRG/HEAD/mdic/mdic.pkl -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.src.flamingo import Flamingo 2 | from .src.factory import create_model_and_transforms 3 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/coco_metric.py: -------------------------------------------------------------------------------- 1 | from pycocoevalcap.eval import COCOEvalCap 2 | from pycocotools.coco import COCO 3 | 4 | 5 | def compute_cider( # Score the captions in result_path against the reference annotations in annotations_path using pycocoevalcap 6 | result_path, 7 | annotations_path, 8 | ): 9 | # create coco object and coco_result object 10 | coco = COCO(annotations_path) 11 | coco_result = coco.loadRes(result_path) 12 | 13 | # create coco_eval object by taking coco and coco_result 14 | coco_eval = COCOEvalCap(coco, coco_result) 15 | coco_eval.params["image_id"] = coco_result.getImgIds() # presumably limits scoring to the images present in the results file 16 | coco_eval.evaluate() 17 | 18 | return coco_eval.eval # metrics filled in by evaluate() (presumably {metric_name: score}, including "CIDEr" — confirm against pycocoevalcap) 19 | 20 | 21 | def postprocess_captioning_generation(predictions): # keep only the text before the first "Output" (presumably where the model begins another in-context example — confirm against the caption prompt) 22 | return predictions.split("Output", 1)[0] 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Hee Suk Yoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes 1 3 | #SBATCH --ntasks-per-node=8 4 | #SBATCH --gpus-per-task=1 5 | # Slurm request above: 1 node, 8 tasks on that node, 1 GPU per task 6 | export PYTHONFAULTHANDLER=1 7 | export CUDA_LAUNCH_BLOCKING=0 8 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 9 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | export MASTER_PORT=15000 11 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 12 | # MASTER_ADDR is the first host of the job's node list and MASTER_PORT is fixed; presumably consumed by the distributed setup of train.py — confirm 13 | export PYTHONPATH="$PYTHONPATH:open_flamingo" 14 | srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \ 15 | --lm_path anas-awadalla/mpt-1b-redpajama-200b \ 16 | --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \ 17 | --cross_attn_every_n_layers 1 \ 18 | --dataset_resampled \ 19 | --batch_size_mmc4 32 \ 20 | --batch_size_laion 64 \ 21 | --train_num_samples_mmc4 125000\ 22 | --train_num_samples_laion 250000 \ 23 | --loss_multiplier_laion 0.2 \ 24 | --workers=4 \ 25 | --run_name OpenFlamingo-3B-vitl-mpt1b \ 26 | --num_epochs 480 \ 27 | --warmup_steps 1875 \ 28 | --mmc4_textsim_threshold 0.24 \ 29 | --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ 30 | --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ 31 | --gradient_checkpointing \ 32 | --report_to_wandb \ 33 |
-------------------------------------------------------------------------------- /open_flamingo-main/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/src/utils.py: -------------------------------------------------------------------------------- 1 | def extend_instance(obj, mixin): 2 | """Apply mixins to a class instance after creation""" 3 | base_cls = obj.__class__ 4 | base_cls_name = obj.__class__.__name__ 5 | obj.__class__ = type( 6 | base_cls_name, (mixin, base_cls), {} 7 | ) # mixin needs to go first for our forward() logic to work 8 | 9 | 10 | def getattr_recursive(obj, att): 11 | """ 12 | Return nested attribute of obj 13 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 14 | """ 15 | if att == "": 16 | return obj 17 | i = att.find(".") 18 | if i < 0: 19 | return getattr(obj, att) 20 | else: 21 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 22 | 23 | 24 | def setattr_recursive(obj, att, val): 25 | """ 26 | Set nested attribute of obj 27 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 28 | """ 29 | if "." 
in att: 30 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 31 | setattr(obj, att.split(".")[-1], val) 32 | 33 | 34 | def apply_with_stopping_condition( 35 | module, apply_fn, apply_condition=None, stopping_condition=None, **other_args 36 | ): 37 | if stopping_condition(module): 38 | return 39 | if apply_condition(module): 40 | apply_fn(module, **other_args) 41 | for child in module.children(): 42 | apply_with_stopping_condition( 43 | child, 44 | apply_fn, 45 | apply_condition=apply_condition, 46 | stopping_condition=stopping_condition, 47 | **other_args 48 | ) 49 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/custom_files/custom_utils.py: -------------------------------------------------------------------------------- 1 | def extend_instance(obj, mixin): 2 | """Apply mixins to a class instance after creation""" 3 | base_cls = obj.__class__ 4 | base_cls_name = obj.__class__.__name__ 5 | obj.__class__ = type( 6 | base_cls_name, (mixin, base_cls), {} 7 | ) # mixin needs to go first for our forward() logic to work 8 | 9 | 10 | def getattr_recursive(obj, att): 11 | """ 12 | Return nested attribute of obj 13 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 14 | """ 15 | if att == "": 16 | return obj 17 | i = att.find(".") 18 | if i < 0: 19 | return getattr(obj, att) 20 | else: 21 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 22 | 23 | 24 | def setattr_recursive(obj, att, val): 25 | """ 26 | Set nested attribute of obj 27 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 28 | """ 29 | if "." 
in att: 30 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 31 | setattr(obj, att.split(".")[-1], val) 32 | 33 | 34 | def apply_with_stopping_condition( 35 | module, apply_fn, apply_condition=None, stopping_condition=None, **other_args 36 | ): 37 | if stopping_condition(module): 38 | return 39 | if apply_condition(module): 40 | apply_fn(module, **other_args) 41 | for child in module.children(): 42 | apply_with_stopping_condition( 43 | child, 44 | apply_fn, 45 | apply_condition=apply_condition, 46 | stopping_condition=stopping_condition, 47 | **other_args 48 | ) 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BI-MDRG: Bridging Image History in Multimodal Dialogue Response Generation (ECCV 2024) 2 | 3 | This repository provides the official implementation of our ECCV 2024 paper: 4 | > BI-MDRG: Bridging Image History in Multimodal Dialogue Response Generation 5 | > Authors: Hee Suk Yoon*, Eunseop Yoon*, Joshua Tian Jin Tee*, Kang Zhang, Yu-Jung Heo, Du-Seong Chang, Chang D. Yoo 6 | 7 | The implementation is built upon [openflamingo](https://github.com/mlfoundations/open_flamingo). 8 | 9 | [[Paper Link]()] 10 | 11 | ## Installation 12 | ```bash 13 | # Clone this repo 14 | git clone https://github.com/hee-suk-yoon/BI-MDRG.git 15 | cd BI-MDRG 16 | 17 | # Create a conda enviroment 18 | 1. conda env create -f environment.yml 19 | 2. conda activate bimdrg 20 | ``` 21 | 22 | ## Datasets 23 | 1. Download the [MMDialog](https://github.com/victorsungo/MMDialog) dataset and prepare using the following preprocessing code 24 | 25 | 2. Prepare Citation Augmented Data 26 | 27 | 3. Multimodal Dialogue Image Consistency (MDIC) Dataset 28 | 29 | To evaluate the image consistency in multimodal dialogue, we have created a curated set of 300 dialogues annotated to track object consistency across conversations based on the MMDialog dataset. 
30 | 31 | You can find the dataset at: `mdic/mdic.pkl` 32 | 33 | The dataset format is: `{dialogue_id: [citation_tags]}` 34 | 35 | 36 | ## Training 37 | 38 | ## Evaluation 39 | 40 | 41 | ## Acknowledgement 42 | This work was supported by a grant of the KAIST-KT joint research project through AI2X Lab., Tech Innovation Group, funded by KT (No. D23000019, Developing Visual and Language Capabilities for AI-Based Dialogue Systems), and by Institute for Information \& communications Technology Planning \& Evaluation (IITP) grant funded by the Korea government(MSIT) (No. 2021-0-01381, Development of Causal AI through Video Understanding and Reinforcement Learning, and Its Applications to Real Environments). 43 | 44 | Also, we thank the authors of the [OpenFlamingo](https://github.com/mlfoundations/open_flamingo), [Subject-Diffusion](https://github.com/OPPO-Mente-Lab/Subject-Diffusion), [MMDialog](https://github.com/victorsungo/MMDialog) for their open-source contributions. 45 | 46 | 47 | ## Contact 48 | If you have any questions, please feel free to email hskyoon@kaist.ac.kr -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/rices.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | from tqdm import tqdm 4 | import torch 5 | from utils import custom_collate_fn 6 | 7 | 8 | class RICES: 9 | def __init__( 10 | self, 11 | dataset, 12 | device, 13 | batch_size, 14 | vision_encoder_path="ViT-B-32", 15 | vision_encoder_pretrained="openai", 16 | cached_features=None, 17 | ): 18 | self.dataset = dataset 19 | self.device = device 20 | self.batch_size = batch_size 21 | 22 | # Load the model and processor 23 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 24 | vision_encoder_path, 25 | pretrained=vision_encoder_pretrained, 26 | ) 27 | self.model = vision_encoder.to(self.device) 28 | self.image_processor = 
image_processor 29 | 30 | # Precompute features 31 | if cached_features is None: 32 | self.features = self._precompute_features() 33 | else: 34 | self.features = cached_features 35 | 36 | def _precompute_features(self): 37 | features = [] 38 | 39 | # Switch to evaluation mode 40 | self.model.eval() 41 | 42 | # Set up loader 43 | loader = torch.utils.data.DataLoader( 44 | self.dataset, 45 | batch_size=self.batch_size, 46 | collate_fn=custom_collate_fn, 47 | ) 48 | 49 | with torch.no_grad(): 50 | for batch in tqdm( 51 | loader, 52 | desc="Precomputing features for RICES", 53 | ): 54 | batch = batch["image"] 55 | inputs = torch.stack( 56 | [self.image_processor(image) for image in batch] 57 | ).to(self.device) 58 | image_features = self.model.encode_image(inputs) 59 | image_features /= image_features.norm(dim=-1, keepdim=True) 60 | features.append(image_features.detach()) 61 | 62 | features = torch.cat(features) 63 | return features 64 | 65 | def find(self, batch, num_examples): 66 | """ 67 | Get the top num_examples most similar examples to the images. 
68 | """ 69 | # Switch to evaluation mode 70 | self.model.eval() 71 | 72 | with torch.no_grad(): 73 | inputs = torch.stack([self.image_processor(image) for image in batch]).to( 74 | self.device 75 | ) 76 | 77 | # Get the feature of the input image 78 | query_feature = self.model.encode_image(inputs) 79 | query_feature /= query_feature.norm(dim=-1, keepdim=True) 80 | query_feature = query_feature.detach().cpu() 81 | 82 | if query_feature.ndim == 1: 83 | query_feature = query_feature.unsqueeze(0) 84 | 85 | # Compute the similarity of the input image to the precomputed features 86 | similarity = (query_feature @ self.features.T).squeeze() 87 | 88 | if similarity.ndim == 1: 89 | similarity = similarity.unsqueeze(0) 90 | 91 | # Get the indices of the 'num_examples' most similar images 92 | indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples] 93 | 94 | # Return with the most similar images last 95 | return [[self.dataset[i] for i in reversed(row)] for row in indices] 96 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/scripts/convert_mmc4_to_wds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import uuid 5 | import zipfile 6 | from PIL import Image 7 | import base64 8 | from io import BytesIO 9 | 10 | import braceexpand 11 | import webdataset as wds 12 | 13 | arg_parser = argparse.ArgumentParser() 14 | arg_parser.add_argument( 15 | "--output_dir", 16 | type=str, 17 | help="Pass in the directory where the output shards (as tar files) will be written to.", 18 | ) 19 | arg_parser.add_argument( 20 | "--zip_files", 21 | type=str, 22 | help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip", 23 | ) 24 | arg_parser.add_argument( 25 | "--image_dir", 26 | type=str, 27 | help="Pass in the directory where the images have been downloaded to.", 28 | ) 29 | 
arg_parser.add_argument( 30 | "--num_files_per_shard", 31 | type=int, 32 | default=1000, 33 | ) 34 | args = arg_parser.parse_args() 35 | 36 | 37 | def main(): 38 | os.makedirs(args.output_dir, exist_ok=True) 39 | 40 | doc_shards = list(braceexpand.braceexpand(args.zip_files)) 41 | 42 | with wds.ShardWriter(args.output_dir + "/%09d.tar") as sink: 43 | for idx in range(len(doc_shards)): # idx also names the per-shard image subdirectory (image_dir/str(idx), used below) 44 | # Open the ZIP archive and extract the JSON file 45 | with zipfile.ZipFile(doc_shards[idx], "r") as zip_file: 46 | # Assumes the JSON file is the first file in the archive 47 | json_filename = zip_file.namelist()[0] 48 | with zip_file.open(json_filename, "r") as json_file: 49 | for sample_data in json_file: # each line is parsed as one JSON document by json.loads below 50 | # get image names from json 51 | sample_data = json.loads(sample_data) 52 | image_info = sample_data["image_info"] 53 | image_names = [image["image_name"] for image in image_info] 54 | 55 | # Add each image to the tar file 56 | for img_idx, image_name in enumerate(image_names): 57 | try: 58 | # load image and re-encode it in memory as a base64 JPEG string 59 | img = Image.open( 60 | os.path.join(args.image_dir, str(idx), image_name) 61 | ).convert("RGB") 62 | buffered = BytesIO() 63 | img.save(buffered, format="JPEG") 64 | img_str = base64.b64encode(buffered.getvalue()) 65 | 66 | # store the base64 string on this image's metadata entry 67 | sample_data["image_info"][img_idx][ 68 | "image_base64" 69 | ] = img_str.decode("utf-8") 70 | except FileNotFoundError: 71 | print( 72 | f"Did not find {image_name} downloaded. This can happen if the url is now 404."
73 | ) 74 | except Exception as e: 75 | print(f"Error processing {image_name}: {e}") 76 | 77 | key_str = uuid.uuid4().hex 78 | sink.write({"__key__": key_str, "json": sample_data}) 79 | 80 | if (idx + 1) % args.num_files_per_shard == 0: # force a new output tar after every num_files_per_shard input zip files 81 | sink.next_stream() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/eval_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import argparse 3 | from typing import List 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from PIL import Image 6 | 7 | 8 | class BaseEvalModel(abc.ABC): 9 | """Base class encapsulating functionality needed to evaluate a model.""" 10 | 11 | def __init__(self, args: List[str]): 12 | """Initialize model. 13 | 14 | Args: 15 | args: arguments to model. These should be parsed, or if the model 16 | has no applicable arguments, an error should be thrown if `args` 17 | is non-empty. NOTE(review): annotated as List[str], but the BLIP-2 EvalModel indexes its argument by key (model_args["lm_path"]), so a mapping appears to be what is actually passed — confirm. 18 | """ 19 | 20 | def init_distributed(self): 21 | """Wrap model as DDP.""" 22 | self.model = DDP(self.model, device_ids=[self.device]) 23 | 24 | def set_device(self, device): 25 | """Set device for model.""" 26 | self.device = device 27 | self.model = self.model.to(device) 28 | 29 | def get_outputs( 30 | self, 31 | batch_text: List[str], 32 | batch_images: List[List[Image.Image]], 33 | min_generation_length: int, 34 | max_generation_length: int, 35 | num_beams: int, 36 | length_penalty: float, 37 | ) -> List[str]: 38 | """Get outputs for a batch of images and text. 39 | 40 | Args: 41 | batch_text: list of text strings, with the image placeholder token (e.g. "<image>") in place 42 | of any images to be included. 43 | batch_images: images to provide to model. Should be a list of lists, 44 | where each list contains the images for a single example. 45 | min_generation_length / max_generation_length: minimum / maximum generation length (the BLIP-2 model passes them as min_new_tokens / max_new_tokens). 46 | There are no defaults: every argument must be passed explicitly.
47 | num_beams: number of beams to use for beam search. 48 | length_penalty: length penalty for beam search. 49 | 50 | Returns: 51 | List of decoded output strings. 52 | """ 53 | 54 | def vqa_prompt(self, question, answer=None) -> str: 55 | """Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model. 56 | NOTE(review): the BLIP-2 EvalModel defines get_vqa_prompt / get_caption_prompt rather than vqa_prompt / caption_prompt; confirm which names the evaluation code calls. 57 | Returns: 58 | The prompt to use for VQA. 59 | """ 60 | 61 | def caption_prompt(self, caption=None) -> str: 62 | """Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model. 63 | 64 | Returns: 65 | The prompt to use for captioning. 66 | """ 67 | 68 | def get_rank_classifications( 69 | self, 70 | batch_text: List[str], 71 | batch_images: List[List[Image.Image]], 72 | all_class_names: List[str], 73 | use_cache: bool, 74 | normalize_length: bool, 75 | ): 76 | """ 77 | Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. 78 | Args: 79 | batch_text: list of text strings, with the image placeholder token (e.g. "<image>") in place 80 | of any images to be included. 81 | batch_images: images to provide to model. Should be a list of lists, 82 | where each list contains the images for a single example. 83 | all_class_names: list of all class names. 84 | use_cache: whether to cache the context to speed up evaluations. 85 | normalize_length: whether to normalize logprobs by the length of the 86 | class name 87 | Returns: 88 | (B, |all_class_names|) tensor containing the logprobs for each class name.
89 | """ 90 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --ntasks-per-node=2 4 | #SBATCH --gpus-per-task=1 5 | 6 | < 0 else 2 31 | return num_shots 32 | 33 | 34 | def sample_batch_demos_from_query_set(query_set, num_samples, batch_size): 35 | """ 36 | Sample random demonstrations from the query set. 37 | """ 38 | return [random.sample(query_set, num_samples) for _ in range(batch_size)] 39 | 40 | 41 | def get_query_set(train_dataset, query_set_size): 42 | """ 43 | Get a subset of the training dataset to use as the query set. 44 | """ 45 | query_set = np.random.choice(len(train_dataset), query_set_size, replace=False) 46 | return [train_dataset[i] for i in query_set] 47 | 48 | 49 | def prepare_eval_samples(test_dataset, num_samples, batch_size): 50 | """ 51 | Subset the test dataset and return a DataLoader. 52 | """ 53 | random_indices = np.random.choice(len(test_dataset), num_samples, replace=False) 54 | dataset = torch.utils.data.Subset(test_dataset, random_indices) 55 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 56 | loader = torch.utils.data.DataLoader( 57 | dataset, 58 | batch_size=batch_size, 59 | sampler=sampler, 60 | collate_fn=custom_collate_fn, 61 | ) 62 | return loader 63 | 64 | 65 | def get_indices_of_unique(x): 66 | """ 67 | Return the indices of x that correspond to unique elements. 68 | If value v is unique and two indices in x have value v, the first index is returned. 
69 | """ 70 | unique_elements = torch.unique(x) 71 | first_indices = [] 72 | for v in unique_elements: 73 | indices = torch.where(x == v)[0] 74 | first_indices.append(indices[0]) # Take the first index for each unique element 75 | return torch.tensor(first_indices) 76 | 77 | 78 | def unwrap_model(model): 79 | """ 80 | Unwrap a model from a DataParallel or DistributedDataParallel wrapper. 81 | """ 82 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): 83 | return model.module 84 | else: 85 | return model 86 | 87 | 88 | def get_predicted_classnames(logprobs, k, class_id_to_name): 89 | """ 90 | Args: 91 | - logprobs shape (B, Y) containing logprobs for each classname 92 | - k: number for top-k 93 | - class_id_to_name: dict mapping class index to classname 94 | 95 | Returns: 96 | - top-k predicted classnames shape (B, k) type str 97 | - top-k logprobs shape (B, k) type float 98 | """ 99 | # convert indices to classnames 100 | _, predictions = torch.topk(logprobs, k=k, dim=1) # shape (B, k) 101 | predicted_classnames = [ 102 | [class_id_to_name[ix] for ix in item] for item in predictions.tolist() 103 | ] 104 | predicted_logprobs = torch.gather(logprobs, 1, predictions) 105 | return predicted_classnames, predicted_logprobs 106 | 107 | 108 | def get_cast_dtype(precision: str): 109 | cast_dtype = None 110 | if precision == "bf16": 111 | cast_dtype = torch.bfloat16 112 | elif precision == "fp16": 113 | cast_dtype = torch.float16 114 | return cast_dtype 115 | 116 | 117 | def get_autocast(precision): 118 | if precision == "amp": 119 | return torch.cuda.amp.autocast 120 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 121 | # amp_bfloat16 is more stable than amp float16 for clip training 122 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 123 | else: 124 | return suppress 125 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/models/blip.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from PIL import Image 4 | import torch 5 | 6 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 7 | from open_flamingo.eval.eval_model import BaseEvalModel 8 | from open_flamingo.eval.utils import unwrap_model 9 | 10 | 11 | class EvalModel(BaseEvalModel): 12 | """BLIP-2 model evaluation. 13 | 14 | Attributes: 15 | model (nn.Module): Underlying Torch model. 16 | processor (Blip2Processor): image/text processor; its .tokenizer encodes prompts and decodes outputs. 17 | device: Index of GPU to use, or the string "cpu" (assigned by set_device(), not by __init__). 18 | """ 19 | 20 | def __init__(self, model_args): 21 | assert ( # only processor_path and lm_path are checked here; the device is supplied later through set_device() 22 | "processor_path" in model_args and "lm_path" in model_args 23 | ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified" 24 | 25 | self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) 26 | self.model = Blip2ForConditionalGeneration.from_pretrained( 27 | model_args["lm_path"] 28 | ) 29 | self.model.eval() 30 | self.processor.tokenizer.padding_side = "left" 31 | self.lm_name = model_args["lm_path"].split("/")[-1] 32 | 33 | def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: 34 | """Preprocess images and stack them. 35 | 36 | Args: 37 | batch: A list of lists of images, with exactly one image per example. 38 | 39 | Returns: 40 | A Tensor of shape 41 | (batch_size, channels, height, width).
42 | """ 43 | batch_images = None # accumulated by concatenating each example's pixel_values along dim 0 44 | assert all( 45 | len(example) == 1 for example in batch 46 | ), "BLIP-2 only supports one image per example" 47 | 48 | for example in batch: 49 | assert len(example) == 1, "BLIP-2 only supports one image per example" # redundant with the all() check above 50 | batch_images = torch.cat( 51 | [ 52 | batch_images, 53 | self.processor.image_processor(example, return_tensors="pt")[ 54 | "pixel_values" 55 | ], 56 | ] 57 | if batch_images is not None 58 | else [ 59 | self.processor.image_processor(example, return_tensors="pt")[ 60 | "pixel_values" 61 | ] 62 | ], 63 | dim=0, 64 | ) 65 | return batch_images 66 | 67 | def get_outputs( 68 | self, 69 | batch_text: List[str], 70 | batch_images: List[List[Image.Image]], 71 | min_generation_length: int, 72 | max_generation_length: int, 73 | num_beams: int, 74 | length_penalty: float, 75 | ) -> List[str]: 76 | encodings = self.processor.tokenizer( 77 | batch_text, 78 | padding="longest", 79 | truncation=True, 80 | return_tensors="pt", 81 | max_length=2000, # prompts longer than this are truncated 82 | ) 83 | input_ids = encodings["input_ids"] 84 | attention_mask = encodings["attention_mask"] 85 | 86 | with torch.inference_mode(): 87 | outputs = unwrap_model(self.model).generate( 88 | self._prepare_images(batch_images).to(self.device), # self.device must already have been set via set_device() 89 | input_ids.to(self.device), 90 | attention_mask=attention_mask.to(self.device), 91 | max_new_tokens=max_generation_length, 92 | min_new_tokens=min_generation_length, 93 | num_beams=num_beams, 94 | length_penalty=length_penalty, 95 | ) 96 | 97 | return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) 98 | 99 | def get_vqa_prompt(self, question, answer=None) -> str: 100 | return ( 101 | f"Question:{question} Short answer:{answer if answer is not None else ''}" 102 | ) 103 | 104 | def get_caption_prompt(self, caption=None) -> str: 105 | return f"A photo of {caption if caption is not None else ''}" 106 | 107 | def get_rank_classifications( 108 | self, 109 | batch_text: List[str], 110 | batch_images:
List[List[Image.Image]], 111 | all_class_names: List[str], 112 | use_cache: bool, 113 | normalize_length: bool, 114 | ): 115 | raise NotImplementedError( 116 | "BLIP-2 classification-based evaluation not implemented" 117 | ) 118 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/README.md: -------------------------------------------------------------------------------- 1 | # OpenFlamingo Evaluation Suite 2 | 3 | This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets. 4 | 5 | *This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).* 6 | 7 | ## Supported datasets 8 | 9 | |Dataset|Task|Metric|Evaluation method| 10 | |-------|----|------|-----------------| 11 | |[COCO](https://arxiv.org/abs/1405.0312)|Captioning|CIDEr|Generation| 12 | |[Flickr-30K](https://aclanthology.org/Q14-1006/)|Captioning|CIDEr|Generation| 13 | |[VQAv2](https://arxiv.org/abs/1612.00837v3)|VQA|VQA accuracy|Generation| 14 | |[OK-VQA](https://arxiv.org/abs/1906.00067)|VQA|VQA accuracy|Generation| 15 | |[TextVQA](https://arxiv.org/abs/1904.08920)|VQA|VQA accuracy|Generation| 16 | |[VizWiz](https://arxiv.org/abs/1802.08218)|VQA|VQA accuracy|Generation| 17 | |[Hateful Memes](https://arxiv.org/abs/2005.04790)|Classification|ROC AUC|Logprobs| 18 | |[ImageNet](https://arxiv.org/abs/1409.0575)|Classification|Top-1 accuracy|Logprobs| 19 | 20 | When evaluating a model using `num_shots` shots, we sample the exemplars from the training split. Performance is evaluated on a disjoint test split, subsampled to `--num_samples` examples (or using the full test split if `--num_samples=-1`).
21 | 22 | ## Sample scripts 23 | Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun`. We provide a sample Slurm evaluation script in `open_flamingo/open_flamingo/scripts/run_eval.sh`. 24 | 25 | We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16. 26 | 27 | To evaluate one of our pretrained checkpoints, we suggest first downloading a local copy of the weights, as follows: 28 | 29 | ``` 30 | # grab model checkpoint from huggingface hub 31 | from huggingface_hub import hf_hub_download 32 | HF_TOKEN="" 33 | 34 | checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt") 35 | checkpoint_path= hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", 36 | "checkpoint.pt", 37 | local_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b", 38 | cache_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b", 39 | local_dir_use_symlinks=False, 40 | token=HF_TOKEN) 41 | print(checkpoint_path) 42 | ## openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt 43 | ``` 44 | 45 | This should place the OpenFlamingo model at the expected location in the evaluation script. 46 | 47 | For TextVQA and VizWiz we expect annotations to be formatted differently than the original datasets. We provide the custom annotations in `open_flamingo/open_flamingo/eval/data/`. We have also uploaded all the annotation files in a [huggingface dataset](https://huggingface.co/datasets/openflamingo/eval_benchmark/tree/main) for easy access. 48 | 49 | # Evaluating using RICES (Retrieval-based In-Context Example Selection) 50 | 51 | We provide the option to evaluate using RICES, which is a method for selecting exemplars from the training set based on image similarity. 
This method was used in DeepMind's implementation for evaluating on ImageNet, but can be used for any dataset in our evaluation suite. 52 | 53 | To use RICES, you must first create features for a benchmark's training set. We provide a script for doing so in `open_flamingo/open_flamingo/scripts/cache_rices_features.py`. This script will extract image features for a given dataset using a given CLIP model checkpoint. For example, to extract features for the COCO training set, you can run: 54 | 55 | ```bash 56 | python cache_rices_features.py \ 57 | --vision_encoder_path ViT-L-14 \ 58 | --vision_encoder_pretrained openai \ 59 | --batch_size 128 \ 60 | --eval_coco \ 61 | --coco_train_image_dir_path /path/to/coco/train2014 \ 62 | --coco_val_image_dir_path /path/to/coco/val2014 \ 63 | --coco_karpathy_json_path /path/to/coco/dataset_coco.json \ 64 | --coco_annotations_json_path /path/to/coco/annotations/captions_train2014.json \ 65 | --output_dir /path/to/coco/features 66 | ``` 67 | 68 | This will create a directory at `/path/to/coco/features` containing a file named `coco.pkl` with the extracted features. You can then use this directory to evaluate using RICES by passing the `--rices` flag to the evaluation script, specifying the path to the features directory using the `--cached_demonstration_features` flag, and specifying the vision encoder to use for RICES using the `--rices_vision_encoder_path` and `--rices_vision_encoder_pretrained` flags. 69 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for setting up distributed training. 
3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py 4 | """ 5 | 6 | import os 7 | import torch 8 | 9 | try: 10 | import horovod.torch as hvd 11 | except ImportError: 12 | hvd = None 13 | 14 | 15 | def is_global_master(args): 16 | return args.rank == 0 17 | 18 | 19 | def is_local_master(args): 20 | return args.local_rank == 0 21 | 22 | 23 | def is_master(args, local=False): 24 | return is_local_master(args) if local else is_global_master(args) 25 | 26 | 27 | def is_using_horovod(): 28 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 29 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 30 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 31 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 32 | if all([var in os.environ for var in ompi_vars]) or all( 33 | [var in os.environ for var in pmi_vars] 34 | ): 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | def is_using_distributed(): 41 | if "WORLD_SIZE" in os.environ: 42 | return int(os.environ["WORLD_SIZE"]) > 1 43 | if "SLURM_NTASKS" in os.environ: 44 | return int(os.environ["SLURM_NTASKS"]) > 1 45 | return False 46 | 47 | 48 | def world_info_from_env(): 49 | local_rank = 0 50 | for v in ( 51 | "LOCAL_RANK", 52 | "MPI_LOCALRANKID", 53 | "SLURM_LOCALID", 54 | "OMPI_COMM_WORLD_LOCAL_RANK", 55 | ): 56 | if v in os.environ: 57 | local_rank = int(os.environ[v]) 58 | break 59 | global_rank = 0 60 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 61 | if v in os.environ: 62 | global_rank = int(os.environ[v]) 63 | break 64 | world_size = 1 65 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 66 | if v in os.environ: 67 | world_size = int(os.environ[v]) 68 | break 69 | 70 | return local_rank, global_rank, world_size 71 | 72 | 73 | def init_distributed_device(args): 74 | # Distributed training = training on more than one GPU. 
75 | # Works in both single and multi-node scenarios. 76 | args.distributed = False 77 | args.world_size = 1 78 | args.rank = 0 # global rank 79 | args.local_rank = 0 80 | if args.horovod: 81 | assert hvd is not None, "Horovod is not installed" 82 | hvd.init() 83 | args.local_rank = int(hvd.local_rank()) 84 | args.rank = hvd.rank() 85 | args.world_size = hvd.size() 86 | args.distributed = True 87 | os.environ["LOCAL_RANK"] = str(args.local_rank) 88 | os.environ["RANK"] = str(args.rank) 89 | os.environ["WORLD_SIZE"] = str(args.world_size) 90 | elif is_using_distributed(): 91 | if "SLURM_PROCID" in os.environ: 92 | # DDP via SLURM 93 | args.local_rank, args.rank, args.world_size = world_info_from_env() 94 | # SLURM var -> torch.distributed vars in case needed 95 | os.environ["LOCAL_RANK"] = str(args.local_rank) 96 | os.environ["RANK"] = str(args.rank) 97 | os.environ["WORLD_SIZE"] = str(args.world_size) 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url, 101 | world_size=args.world_size, 102 | rank=args.rank, 103 | ) 104 | else: 105 | # DDP via torchrun, torch.distributed.launch 106 | args.local_rank, _, _ = world_info_from_env() 107 | torch.distributed.init_process_group( 108 | backend=args.dist_backend, init_method=args.dist_url 109 | ) 110 | args.world_size = torch.distributed.get_world_size() 111 | args.rank = torch.distributed.get_rank() 112 | args.distributed = True 113 | else: 114 | # needed to run on single gpu 115 | torch.distributed.init_process_group( 116 | backend=args.dist_backend, 117 | init_method=args.dist_url, 118 | world_size=1, 119 | rank=0, 120 | ) 121 | 122 | if torch.cuda.is_available(): 123 | if args.distributed and not args.no_set_device_rank: 124 | device = "cuda:%d" % args.local_rank 125 | else: 126 | device = "cuda:0" 127 | torch.cuda.set_device(device) 128 | else: 129 | device = "cpu" 130 | args.device = device 131 | device = torch.device(device) 132 | return device 133 | 
-------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/scripts/fill_vqa_testdev_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission. 3 | Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set. 4 | Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction. 5 | """ 6 | import json 7 | import sys 8 | import os 9 | 10 | sys.path.append( 11 | os.path.join( 12 | os.path.dirname(os.path.abspath(__file__)), 13 | "..", 14 | ) 15 | ) 16 | from eval.vqa_metric import VQAEval 17 | 18 | postprocessor = VQAEval(None, None) 19 | 20 | 21 | def fill_vizwiz_test_json( 22 | input_path, 23 | output_path, 24 | vqa_test_questions_json_path, 25 | ): 26 | # read the input json and build a set with all question_ids 27 | with open(input_path, "r") as f: 28 | input_json = json.load(f) 29 | 30 | # postprocess answers 31 | question_id_to_answer = {} 32 | for q in input_json: 33 | resAns = q["answer"] 34 | resAns = resAns.replace("\n", " ") 35 | resAns = resAns.replace("\t", " ") 36 | resAns = resAns.strip() 37 | resAns = postprocessor.processPunctuation(resAns) 38 | resAns = postprocessor.processDigitArticle(resAns) 39 | question_id_to_answer[q["question_id"]] = resAns 40 | 41 | # read the vqa test json to get all the qustion_ids that need to be filled 42 | with open(vqa_test_questions_json_path, "r") as f: 43 | vqa_test_json = json.load(f) 44 | vqa_test_json = vqa_test_json["questions"] 45 | 46 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 47 | output_json = [] 48 | for q in vqa_test_json: 49 | output_json.append( 50 | { 51 | "image": q["image_id"], 52 | "answer": 
question_id_to_answer.get(q["question_id"], ""), 53 | } 54 | ) 55 | 56 | # write the json to the output path 57 | with open(output_path, "w") as f: 58 | json.dump(output_json, f) 59 | 60 | 61 | def fill_vqav2_test_json( 62 | input_path, 63 | output_path, 64 | vqa_test_questions_json_path, 65 | ): 66 | # read the input json and build a set with all question_ids 67 | with open(input_path, "r") as f: 68 | input_json = json.load(f) 69 | question_ids = set() 70 | for q in input_json: 71 | question_ids.add(q["question_id"]) 72 | 73 | # make a copy of the input json 74 | output_json = [] 75 | for q in input_json: 76 | resAns = q["answer"] 77 | resAns = resAns.replace("\n", " ") 78 | resAns = resAns.replace("\t", " ") 79 | resAns = resAns.strip() 80 | resAns = postprocessor.processPunctuation(resAns) 81 | resAns = postprocessor.processDigitArticle(resAns) 82 | q["answer"] = resAns 83 | output_json.append(q) 84 | 85 | # read the vqa test json to get all the qustion_ids that need to be filled 86 | with open(vqa_test_questions_json_path, "r") as f: 87 | vqa_test_json = json.load(f) 88 | vqa_test_json = vqa_test_json["questions"] 89 | 90 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 91 | for q in vqa_test_json: 92 | if q["question_id"] not in question_ids: 93 | output_json.append( 94 | { 95 | "question_id": q["question_id"], 96 | "answer": "", 97 | } 98 | ) 99 | 100 | # write the json to the output path 101 | with open(output_path, "w") as f: 102 | json.dump(output_json, f) 103 | 104 | 105 | if __name__ == "__main__": 106 | import argparse 107 | 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument( 110 | "--dataset", 111 | type=str, 112 | choices=["vqav2", "vizwiz"], 113 | ) 114 | parser.add_argument( 115 | "--input_path", 116 | type=str, 117 | help="Path to the json file with the subset of the vqa test-dev questions.", 118 | ) 119 | parser.add_argument( 120 | "--vqa_test_questions_json_path", 
121 | type=str, 122 | help="Path to the json file with all the vqa test questions.", 123 | ) 124 | parser.add_argument( 125 | "--output_path", 126 | type=str, 127 | help="Path to store the filled json.", 128 | ) 129 | args = parser.parse_args() 130 | 131 | if args.dataset == "vqav2": 132 | fill_vqav2_test_json( 133 | args.input_path, 134 | args.output_path, 135 | args.vqa_test_questions_json_path, 136 | ) 137 | else: 138 | fill_vizwiz_test_json( 139 | args.input_path, 140 | args.output_path, 141 | args.vqa_test_questions_json_path, 142 | ) 143 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bimdrg 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.11=h7f8727e_2 15 | - pip=23.2.1=py39h06a4308_0 16 | - python=3.9.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py39h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==2.0.0 26 | - accelerate==0.23.0 27 | - aiohttp==3.8.6 28 | - aiosignal==1.3.1 29 | - annotated-types==0.6.0 30 | - appdirs==1.4.4 31 | - astroid==3.0.0 32 | - asttokens==2.4.0 33 | - async-timeout==4.0.3 34 | - attrs==23.1.0 35 | - backcall==0.2.0 36 | - black==23.9.1 37 | - blessed==1.20.0 38 | - blis==0.7.11 39 | - braceexpand==0.1.7 40 | - catalogue==2.0.10 41 | - certifi==2023.7.22 42 | - charset-normalizer==3.3.0 43 | - click==8.1.7 44 | - cloudpathlib==0.16.0 45 | - cmake==3.27.6 46 | - confection==0.1.3 47 | - contourpy==1.1.1 48 | - 
cycler==0.12.0 49 | - cymem==2.0.8 50 | - datasets==2.14.5 51 | - decorator==5.1.1 52 | - diffusers==0.21.4 53 | - dill==0.3.7 54 | - docker-pycreds==0.4.0 55 | - einops==0.7.0 56 | - einops-exts==0.0.4 57 | #- en-core-web-lg==3.7.0 58 | - evaluate==0.4.1 59 | - exceptiongroup==1.1.3 60 | - executing==2.0.0 61 | - filelock==3.12.4 62 | - fonttools==4.43.0 63 | - frozenlist==1.4.0 64 | - fsspec==2023.6.0 65 | - ftfy==6.1.1 66 | - gitdb==4.0.10 67 | - gitpython==3.1.37 68 | - gpustat==1.1.1 69 | - huggingface-hub==0.17.3 70 | - idna==3.4 71 | - importlib-metadata==6.8.0 72 | - importlib-resources==6.1.0 73 | - inflection==0.5.1 74 | - iniconfig==2.0.0 75 | - ipdb==0.13.13 76 | - ipython==8.16.1 77 | - isort==5.12.0 78 | - jedi==0.19.1 79 | - jinja2==3.1.2 80 | - joblib==1.3.2 81 | - kiwisolver==1.4.5 82 | - langcodes==3.3.0 83 | - lightning-utilities==0.9.0 84 | - lit==17.0.1 85 | - markupsafe==2.1.3 86 | - matplotlib==3.8.0 87 | - matplotlib-inline==0.1.6 88 | - mccabe==0.7.0 89 | - mpmath==1.3.0 90 | - multidict==6.0.4 91 | - multiprocess==0.70.15 92 | - murmurhash==1.0.10 93 | - mypy==1.5.1 94 | - mypy-extensions==1.0.0 95 | - networkx==3.1 96 | - nltk==3.8.1 97 | - numpy==1.26.0 98 | - nvidia-cublas-cu11==11.10.3.66 99 | - nvidia-cuda-cupti-cu11==11.7.101 100 | - nvidia-cuda-nvrtc-cu11==11.7.99 101 | - nvidia-cuda-runtime-cu11==11.7.99 102 | - nvidia-cudnn-cu11==8.5.0.96 103 | - nvidia-cufft-cu11==10.9.0.58 104 | - nvidia-curand-cu11==10.2.10.91 105 | - nvidia-cusolver-cu11==11.4.0.1 106 | - nvidia-cusparse-cu11==11.7.4.91 107 | - nvidia-ml-py==12.535.108 108 | - nvidia-nccl-cu11==2.14.3 109 | - nvidia-nvtx-cu11==11.7.91 110 | - open-clip-torch==2.20.0 111 | - opencv-python==4.8.1.78 112 | - packaging==23.2 113 | - pandas==2.1.1 114 | - parso==0.8.3 115 | - pathspec==0.11.2 116 | - pathtools==0.1.2 117 | - pexpect==4.8.0 118 | - pickleshare==0.7.5 119 | - pillow==10.0.1 120 | - platformdirs==3.11.0 121 | - pluggy==1.3.0 122 | - preshed==3.0.9 123 | - 
prompt-toolkit==3.0.39 124 | - protobuf==3.20.3 125 | - psutil==5.9.5 126 | - ptyprocess==0.7.0 127 | - pure-eval==0.2.2 128 | - pyarrow==13.0.0 129 | - pycocoevalcap==1.2 130 | - pycocotools==2.0.7 131 | - pydantic==2.4.2 132 | - pydantic-core==2.10.1 133 | - pygments==2.16.1 134 | - pylint==3.0.0 135 | - pyparsing==3.1.1 136 | - pytest==7.4.2 137 | - python-dateutil==2.8.2 138 | - pytz==2023.3.post1 139 | - pyyaml==6.0.1 140 | - regex==2023.8.8 141 | - requests==2.31.0 142 | - responses==0.18.0 143 | - rouge-score==0.1.2 144 | - safetensors==0.3.3 145 | - scikit-learn==1.3.1 146 | - scipy==1.11.3 147 | - sentencepiece==0.1.98 148 | - sentry-sdk==1.31.0 149 | - setproctitle==1.3.2 150 | - six==1.16.0 151 | - smart-open==6.4.0 152 | - smmap==5.0.1 153 | - srsly==2.4.8 154 | - stack-data==0.6.3 155 | - sympy==1.12 156 | - thinc==8.2.1 157 | - threadpoolctl==3.2.0 158 | - timm==0.9.7 159 | - tokenizers==0.13.3 160 | - tomli==2.0.1 161 | - tomlkit==0.12.1 162 | - torch==2.0.1 163 | - torchmetrics==1.2.0 164 | - torchvision==0.15.2 165 | - tqdm==4.66.1 166 | - traitlets==5.10.1 167 | - transformers==4.33.3 168 | - triton==2.0.0 169 | - typer==0.9.0 170 | - typing-extensions==4.8.0 171 | - tzdata==2023.3 172 | - urllib3==2.0.6 173 | - wandb==0.15.11 174 | - wasabi==1.1.2 175 | - wcwidth==0.2.8 176 | - weasel==0.3.3 177 | - webdataset==0.2.57 178 | - xxhash==3.4.1 179 | - yarl==1.9.2 180 | - zipp==3.17.0 181 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/src/factory.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import open_clip 3 | 4 | from .flamingo import Flamingo 5 | #from open_flamingo.src.flamingo import Flamingo #added by HSY 10/6/23 6 | #from custom_files.custom_flamingo import Flamingo #added by HSY 10/6/23 7 | from .flamingo_lm import FlamingoLMMixin 8 | #from 
open_flamingo.src.flamingo_lm import FlamingoLMMixin #added by HSY 10/6/23 9 | from .utils import extend_instance 10 | #from open_flamingo.src.utils import extend_instance #added by HSY 10/6/23 11 | 12 | import ipdb # added by HSY 10/7/23 13 | 14 | def create_model_and_transforms( 15 | clip_vision_encoder_path: str, 16 | clip_vision_encoder_pretrained: str, 17 | lang_encoder_path: str, 18 | tokenizer_path: str, 19 | cross_attn_every_n_layers: int = 1, 20 | use_local_files: bool = False, 21 | decoder_layers_attr_name: str = None, 22 | freeze_lm_embeddings: bool = False, 23 | **flamingo_kwargs, 24 | ): 25 | import ipdb 26 | ipdb.set_trace() 27 | """ 28 | Initialize a Flamingo model from a pretrained vision encoder and language encoder. 29 | Appends special tokens to the tokenizer and freezes backbones. 30 | 31 | Args: 32 | clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") 33 | clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") 34 | lang_encoder_path (str): path to pretrained language encoder 35 | tokenizer_path (str): path to pretrained tokenizer 36 | cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. 37 | use_local_files (bool, optional): whether to use local files. Defaults to False. 38 | decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. 
39 | Returns: 40 | Flamingo: Flamingo model from pretrained vision and language encoders 41 | Image processor: Pipeline to preprocess input images 42 | Tokenizer: A tokenizer for the language model 43 | """ 44 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 45 | clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained 46 | ) 47 | # set the vision encoder to output the visual features 48 | vision_encoder.visual.output_tokens = True 49 | 50 | text_tokenizer = AutoTokenizer.from_pretrained( 51 | tokenizer_path, 52 | local_files_only=use_local_files, 53 | trust_remote_code=True, 54 | ) 55 | # add Flamingo special tokens to the tokenizer 56 | text_tokenizer.add_special_tokens( 57 | {"additional_special_tokens": ["<|endofchunk|>", ""]} 58 | ) 59 | if text_tokenizer.pad_token is None: 60 | # Issue: GPT models don't have a pad token, which we use to 61 | # modify labels for the loss. 62 | text_tokenizer.add_special_tokens({"pad_token": ""}) 63 | 64 | lang_encoder = AutoModelForCausalLM.from_pretrained( 65 | lang_encoder_path, 66 | local_files_only=use_local_files, 67 | trust_remote_code=True, 68 | ) 69 | 70 | ipdb.set_trace() 71 | # hacks for MPT-1B, which doesn't have a get_input_embeddings method 72 | if "mpt-1b-redpajama-200b" in lang_encoder_path: 73 | 74 | class EmbeddingFnMixin: 75 | def get_input_embeddings(self): 76 | return self.transformer.wte 77 | 78 | def set_input_embeddings(self, new_embeddings): 79 | self.transformer.wte = new_embeddings 80 | 81 | extend_instance(lang_encoder, EmbeddingFnMixin) 82 | 83 | # convert LM to FlamingoLM 84 | extend_instance(lang_encoder, FlamingoLMMixin) 85 | 86 | if decoder_layers_attr_name is None: 87 | decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) 88 | lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) 89 | lang_encoder.resize_token_embeddings(len(text_tokenizer)) 90 | 91 | model = Flamingo( 92 | vision_encoder, 93 | lang_encoder, 94 | 
text_tokenizer.encode("<|endofchunk|>")[-1], 95 | text_tokenizer.encode("")[-1], 96 | vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][ 97 | "width" 98 | ], 99 | cross_attn_every_n_layers=cross_attn_every_n_layers, 100 | **flamingo_kwargs, 101 | ) 102 | 103 | # Freeze all parameters 104 | model.requires_grad_(False) 105 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 106 | 107 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings 108 | model.perceiver.requires_grad_(True) 109 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) 110 | if not freeze_lm_embeddings: 111 | model.lang_encoder.get_input_embeddings().requires_grad_(True) 112 | # TODO: investigate also training the output embeddings when untied 113 | 114 | print( 115 | f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" 116 | ) 117 | 118 | return model, image_processor, text_tokenizer 119 | 120 | 121 | def _infer_decoder_layers_attr_name(model): 122 | for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES: 123 | if k.lower() in model.__class__.__name__.lower(): 124 | return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k] 125 | 126 | raise ValueError( 127 | f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually." 
128 | ) 129 | 130 | 131 | __KNOWN_DECODER_LAYERS_ATTR_NAMES = { 132 | "opt": "model.decoder.layers", 133 | "gptj": "transformer.h", 134 | "gpt-j": "transformer.h", 135 | "pythia": "gpt_neox.layers", 136 | "llama": "model.layers", 137 | "gptneoxforcausallm": "gpt_neox.layers", 138 | "mpt": "transformer.blocks", 139 | "mosaicgpt": "transformer.blocks", 140 | } 141 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/eval_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets import ImageFolder 7 | 8 | from open_flamingo.eval.classification_utils import IMAGENET_CLASSNAMES 9 | 10 | 11 | class CaptionDataset(Dataset): 12 | def __init__( 13 | self, 14 | image_train_dir_path, 15 | annotations_path, 16 | is_train, 17 | dataset_name, 18 | image_val_dir_path=None, 19 | ): 20 | self.image_train_dir_path = image_train_dir_path 21 | self.image_val_dir_path = image_val_dir_path 22 | self.annotations = [] 23 | self.is_train = is_train 24 | self.dataset_name = dataset_name 25 | 26 | full_annotations = json.load(open(annotations_path))["images"] 27 | 28 | for i in range(len(full_annotations)): 29 | if self.is_train and full_annotations[i]["split"] != "train": 30 | continue 31 | elif not self.is_train and full_annotations[i]["split"] != "test": 32 | continue 33 | 34 | self.annotations.append(full_annotations[i]) 35 | 36 | def __len__(self): 37 | return len(self.annotations) 38 | 39 | def __getitem__(self, idx): 40 | if self.dataset_name == "coco": 41 | image = Image.open( 42 | os.path.join( 43 | self.image_train_dir_path, self.annotations[idx]["filename"] 44 | ) 45 | if self.annotations[idx]["filepath"] == "train2014" 46 | else os.path.join( 47 | self.image_val_dir_path, self.annotations[idx]["filename"] 48 | ) 49 | ) 50 | elif self.dataset_name 
== "flickr": 51 | image = Image.open( 52 | os.path.join( 53 | self.image_train_dir_path, self.annotations[idx]["filename"] 54 | ) 55 | ) 56 | image.load() 57 | caption = self.annotations[idx]["sentences"][0]["raw"] 58 | return { 59 | "image": image, 60 | "caption": caption, 61 | "image_id": self.annotations[idx]["cocoid"] 62 | if self.dataset_name == "coco" 63 | else self.annotations[idx]["filename"].split(".")[0], 64 | } 65 | 66 | 67 | class VQADataset(Dataset): 68 | def __init__( 69 | self, image_dir_path, question_path, annotations_path, is_train, dataset_name 70 | ): 71 | self.questions = json.load(open(question_path, "r"))["questions"] 72 | if annotations_path is not None: 73 | self.answers = json.load(open(annotations_path, "r"))["annotations"] 74 | else: 75 | self.answers = None 76 | self.image_dir_path = image_dir_path 77 | self.is_train = is_train 78 | self.dataset_name = dataset_name 79 | if self.dataset_name in {"vqav2", "ok_vqa"}: 80 | self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1] 81 | assert self.img_coco_split in {"train2014", "val2014", "test2015"} 82 | 83 | def __len__(self): 84 | return len(self.questions) 85 | 86 | def get_img_path(self, question): 87 | if self.dataset_name in {"vqav2", "ok_vqa"}: 88 | return os.path.join( 89 | self.image_dir_path, 90 | f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg" 91 | if self.is_train 92 | else f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg", 93 | ) 94 | elif self.dataset_name == "vizwiz": 95 | return os.path.join(self.image_dir_path, question["image_id"]) 96 | elif self.dataset_name == "textvqa": 97 | return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg") 98 | else: 99 | raise Exception(f"Unknown VQA dataset {self.dataset_name}") 100 | 101 | def __getitem__(self, idx): 102 | question = self.questions[idx] 103 | img_path = self.get_img_path(question) 104 | image = Image.open(img_path) 105 | image.load() 106 | results = { 107 | "image": image, 
108 | "question": question["question"], 109 | "question_id": question["question_id"], 110 | } 111 | if self.answers is not None: 112 | answers = self.answers[idx] 113 | results["answers"] = [a["answer"] for a in answers["answers"]] 114 | return results 115 | 116 | 117 | class ImageNetDataset(ImageFolder): 118 | """Class to represent the ImageNet1k dataset.""" 119 | 120 | def __init__(self, root, **kwargs): 121 | super().__init__(root=root, **kwargs) 122 | self.class_id_to_name = dict( 123 | zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES) 124 | ) 125 | 126 | def __getitem__(self, idx): 127 | sample, target = super().__getitem__(idx) 128 | target_label = self.class_id_to_name[target] 129 | return { 130 | "id": idx, 131 | "image": sample, 132 | "class_id": target, # numeric ID of the ImageNet class 133 | "class_name": target_label, # human-readable name of ImageNet class 134 | } 135 | 136 | 137 | class HatefulMemesDataset(Dataset): 138 | def __init__(self, image_dir_path, annotations_path): 139 | self.image_dir_path = image_dir_path 140 | with open(annotations_path, "r") as f: 141 | self.annotations = [json.loads(line) for line in f] 142 | 143 | def __len__(self): 144 | return len(self.annotations) 145 | 146 | def __getitem__(self, idx): 147 | annotation = self.annotations[idx] 148 | img_path = os.path.join(self.image_dir_path, annotation["img"].split("/")[-1]) 149 | image = Image.open(img_path) 150 | image.load() 151 | return { 152 | "id": annotation["id"], 153 | "image": image, 154 | "ocr": annotation["text"], 155 | "class_name": "yes" if annotation["label"] == 1 else "no", 156 | "class_id": annotation["label"], 157 | } 158 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/ok_vqa_utils.py: -------------------------------------------------------------------------------- 1 | # Those are manual mapping that are not caught by our stemming rules or would 2 | # would be done incorrectly by 
our automatic stemming rule. In details, 3 | # the keys of the _MANUAL_MATCHES dict contains the original word and the value 4 | # contains the transformation of the word expected by the OKVQA stemming rule. 5 | # These manual rules were found by checking the `raw_answers` and the `answers` 6 | # fields of the released OKVQA dataset and checking all things that were not 7 | # properly mapped by our automatic rules. In particular some of the mapping 8 | # are sometimes constant, e.g. christmas -> christmas which was incorrectly 9 | # singularized by our inflection.singularize. 10 | import re 11 | import nltk 12 | from nltk.corpus.reader import VERB 13 | import inflection 14 | 15 | _MANUAL_MATCHES = { 16 | "police": "police", 17 | "las": "las", 18 | "vegas": "vegas", 19 | "yes": "yes", 20 | "jeans": "jean", 21 | "hell's": "hell", 22 | "domino's": "domino", 23 | "morning": "morn", 24 | "clothes": "cloth", 25 | "are": "are", 26 | "riding": "ride", 27 | "leaves": "leaf", 28 | "dangerous": "danger", 29 | "clothing": "cloth", 30 | "texting": "text", 31 | "kiting": "kite", 32 | "firefighters": "firefight", 33 | "ties": "tie", 34 | "married": "married", 35 | "teething": "teeth", 36 | "gloves": "glove", 37 | "tennis": "tennis", 38 | "dining": "dine", 39 | "directions": "direct", 40 | "waves": "wave", 41 | "christmas": "christmas", 42 | "drives": "drive", 43 | "pudding": "pud", 44 | "coding": "code", 45 | "plating": "plate", 46 | "quantas": "quanta", 47 | "hornes": "horn", 48 | "graves": "grave", 49 | "mating": "mate", 50 | "paned": "pane", 51 | "alertness": "alert", 52 | "sunbathing": "sunbath", 53 | "tenning": "ten", 54 | "wetness": "wet", 55 | "urinating": "urine", 56 | "sickness": "sick", 57 | "braves": "brave", 58 | "firefighting": "firefight", 59 | "lenses": "lens", 60 | "reflections": "reflect", 61 | "backpackers": "backpack", 62 | "eatting": "eat", 63 | "designers": "design", 64 | "curiousity": "curious", 65 | "playfulness": "play", 66 | "blindness": "blind", 67 | 
"hawke": "hawk", 68 | "tomatoe": "tomato", 69 | "rodeoing": "rodeo", 70 | "brightness": "bright", 71 | "circuses": "circus", 72 | "skateboarders": "skateboard", 73 | "staring": "stare", 74 | "electronics": "electron", 75 | "electicity": "elect", 76 | "mountainous": "mountain", 77 | "socializing": "social", 78 | "hamburgers": "hamburg", 79 | "caves": "cave", 80 | "transitions": "transit", 81 | "wading": "wade", 82 | "creame": "cream", 83 | "toileting": "toilet", 84 | "sautee": "saute", 85 | "buildings": "build", 86 | "belongings": "belong", 87 | "stockings": "stock", 88 | "walle": "wall", 89 | "cumulis": "cumuli", 90 | "travelers": "travel", 91 | "conducter": "conduct", 92 | "browsing": "brows", 93 | "pooping": "poop", 94 | "haircutting": "haircut", 95 | "toppings": "top", 96 | "hearding": "heard", 97 | "sunblocker": "sunblock", 98 | "bases": "base", 99 | "markings": "mark", 100 | "mopeds": "mope", 101 | "kindergartener": "kindergarten", 102 | "pies": "pie", 103 | "scrapbooking": "scrapbook", 104 | "couponing": "coupon", 105 | "meetings": "meet", 106 | "elevators": "elev", 107 | "lowes": "low", 108 | "men's": "men", 109 | "childrens": "children", 110 | "shelves": "shelve", 111 | "paintings": "paint", 112 | "raines": "rain", 113 | "paring": "pare", 114 | "expressions": "express", 115 | "routes": "rout", 116 | "pease": "peas", 117 | "vastness": "vast", 118 | "awning": "awn", 119 | "boy's": "boy", 120 | "drunkenness": "drunken", 121 | "teasing": "teas", 122 | "conferences": "confer", 123 | "ripeness": "ripe", 124 | "suspenders": "suspend", 125 | "earnings": "earn", 126 | "reporters": "report", 127 | "kid's": "kid", 128 | "containers": "contain", 129 | "corgie": "corgi", 130 | "porche": "porch", 131 | "microwaves": "microwave", 132 | "batter's": "batter", 133 | "sadness": "sad", 134 | "apartments": "apart", 135 | "oxygenize": "oxygen", 136 | "striping": "stripe", 137 | "purring": "pure", 138 | "professionals": "profession", 139 | "piping": "pipe", 140 | "farmer's": 
"farmer", 141 | "potatoe": "potato", 142 | "emirates": "emir", 143 | "womens": "women", 144 | "veteran's": "veteran", 145 | "wilderness": "wilder", 146 | "propellers": "propel", 147 | "alpes": "alp", 148 | "charioteering": "chariot", 149 | "swining": "swine", 150 | "illness": "ill", 151 | "crepte": "crept", 152 | "adhesives": "adhesive", 153 | "regent's": "regent", 154 | "decorations": "decor", 155 | "rabbies": "rabbi", 156 | "overseas": "oversea", 157 | "travellers": "travel", 158 | "casings": "case", 159 | "smugness": "smug", 160 | "doves": "dove", 161 | "nationals": "nation", 162 | "mustange": "mustang", 163 | "ringe": "ring", 164 | "gondoliere": "gondolier", 165 | "vacationing": "vacate", 166 | "reminders": "remind", 167 | "baldness": "bald", 168 | "settings": "set", 169 | "glaced": "glace", 170 | "coniferous": "conifer", 171 | "revelations": "revel", 172 | "personals": "person", 173 | "daughter's": "daughter", 174 | "badness": "bad", 175 | "projections": "project", 176 | "polarizing": "polar", 177 | "vandalizers": "vandal", 178 | "minerals": "miner", 179 | "protesters": "protest", 180 | "controllers": "control", 181 | "weddings": "wed", 182 | "sometimes": "sometime", 183 | "earing": "ear", 184 | } 185 | 186 | 187 | class OKVQAStemmer: 188 | """Stemmer to match OKVQA v1.1 procedure.""" 189 | 190 | def __init__(self): 191 | self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer() 192 | 193 | def stem(self, input_string): 194 | """Apply stemming.""" 195 | word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string)) 196 | stemmed_words = [] 197 | for w, p in word_and_pos: 198 | if w in _MANUAL_MATCHES: 199 | w = _MANUAL_MATCHES[w] 200 | elif w.endswith("ing"): 201 | w = self._wordnet_lemmatizer.lemmatize(w, VERB) 202 | elif p.startswith("NNS") or p.startswith("NNPS"): 203 | w = inflection.singularize(w) 204 | stemmed_words.append(w) 205 | return " ".join(stemmed_words) 206 | 207 | 208 | stemmer = OKVQAStemmer() 209 | 210 | 211 | def 
postprocess_ok_vqa_generation(predictions) -> str: 212 | prediction = re.split("Question|Answer|Short", predictions, 1)[0] 213 | prediction = re.split(", ", prediction, 1)[0] 214 | prediction_stem = stemmer.stem(prediction) 215 | return prediction_stem 216 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/src/flamingo_lm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .helpers import GatedCrossAttentionBlock 3 | from .utils import getattr_recursive, setattr_recursive 4 | 5 | 6 | class FlamingoLayer(nn.Module): 7 | """ 8 | FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer. 9 | """ 10 | 11 | def __init__( 12 | self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False 13 | ): 14 | super().__init__() 15 | self.gated_cross_attn_layer = gated_cross_attn_layer 16 | self.decoder_layer = decoder_layer 17 | self.vis_x = None 18 | self.media_locations = None 19 | if self.gated_cross_attn_layer is not None: 20 | self.gated_cross_attn_layer._use_gradient_checkpointing = ( 21 | gradient_checkpointing 22 | ) 23 | self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing 24 | 25 | def is_conditioned(self) -> bool: 26 | """Check whether the layer is conditioned.""" 27 | return self.vis_x is not None and self.media_locations is not None 28 | 29 | # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) 30 | def condition_vis_x(self, vis_x): 31 | self.vis_x = vis_x 32 | 33 | def condition_media_locations(self, media_locations): 34 | self.media_locations = media_locations 35 | 36 | def condition_use_cached_media(self, use_cached_media): 37 | self.use_cached_media = use_cached_media 38 | 39 | def forward( 40 | self, 41 | lang_x, 42 | attention_mask=None, 43 | **decoder_layer_kwargs, 44 | ): 45 | # Cross attention 46 | if 
self.gated_cross_attn_layer is not None: 47 | if self.vis_x is None: 48 | raise ValueError("vis_x must be conditioned before forward pass") 49 | 50 | if self.media_locations is None: 51 | raise ValueError( 52 | "media_locations must be conditioned before forward pass" 53 | ) 54 | 55 | lang_x = self.gated_cross_attn_layer( 56 | lang_x, 57 | self.vis_x, 58 | media_locations=self.media_locations, 59 | use_cached_media=self.use_cached_media, 60 | ) 61 | 62 | # Normal decoder layer 63 | lang_x = self.decoder_layer( 64 | lang_x, attention_mask=attention_mask, **decoder_layer_kwargs 65 | ) 66 | return lang_x 67 | 68 | 69 | class FlamingoLMMixin(nn.Module): 70 | """ 71 | Mixin to add cross-attention layers to a language model. 72 | """ 73 | 74 | def set_decoder_layers_attr_name(self, decoder_layers_attr_name): 75 | self.decoder_layers_attr_name = decoder_layers_attr_name 76 | 77 | def _get_decoder_layers(self): 78 | return getattr_recursive(self, self.decoder_layers_attr_name) 79 | 80 | def _set_decoder_layers(self, value): 81 | setattr_recursive(self, self.decoder_layers_attr_name, value) 82 | 83 | def init_flamingo( 84 | self, 85 | media_token_id, 86 | lang_hidden_size, 87 | vis_hidden_size, 88 | cross_attn_every_n_layers, 89 | gradient_checkpointing, 90 | ): 91 | """ 92 | Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. 
93 | """ 94 | self.old_decoder_blocks = self._get_decoder_layers() 95 | self.gated_cross_attn_layers = nn.ModuleList( 96 | [ 97 | GatedCrossAttentionBlock( 98 | dim=lang_hidden_size, dim_visual=vis_hidden_size 99 | ) 100 | if (layer_idx + 1) % cross_attn_every_n_layers == 0 101 | else None 102 | for layer_idx, _ in enumerate(self._get_decoder_layers()) 103 | ] 104 | ) 105 | self.init_flamingo_layers(gradient_checkpointing) 106 | self.media_token_id = media_token_id 107 | self.initialized_flamingo = True 108 | self._use_cached_vision_x = False 109 | 110 | def init_flamingo_layers(self, gradient_checkpointing): 111 | """ 112 | Re initializes the FlamingoLayers. 113 | Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks 114 | """ 115 | self._set_decoder_layers( 116 | nn.ModuleList( 117 | [ 118 | FlamingoLayer( 119 | gated_cross_attn_layer, decoder_layer, gradient_checkpointing 120 | ) 121 | for gated_cross_attn_layer, decoder_layer in zip( 122 | self.gated_cross_attn_layers, self.old_decoder_blocks 123 | ) 124 | ] 125 | ) 126 | ) 127 | 128 | def forward(self, input_ids, attention_mask, **kwargs): 129 | """Condition the Flamingo layers on the media locations before forward()""" 130 | if not self.initialized_flamingo: 131 | raise ValueError( 132 | "Flamingo layers are not initialized. Please call `init_flamingo` first." 133 | ) 134 | 135 | media_locations = input_ids == self.media_token_id 136 | 137 | # if there are media already cached and we're generating and there are no media tokens in the input, 138 | # we'll assume that ALL input tokens should attend to the last previous media that is cached. 139 | # this is especially important for HF generate() compatibility, since generate() calls forward() 140 | # repeatedly one token at a time (with no media tokens). 
141 | # without this check, the model would not attend to any images when generating (after the first token) 142 | use_cached_media_locations = ( 143 | self._use_cached_vision_x 144 | and self.is_conditioned() 145 | and not media_locations.any() 146 | ) 147 | 148 | for layer in self._get_decoder_layers(): 149 | if not use_cached_media_locations: 150 | layer.condition_media_locations(media_locations) 151 | layer.condition_use_cached_media(use_cached_media_locations) 152 | 153 | # package arguments for the other parent's forward. since we don't know the order of the arguments, 154 | # make them all kwargs 155 | kwargs["input_ids"] = input_ids 156 | kwargs["attention_mask"] = attention_mask 157 | return super().forward(**kwargs) # Call the other parent's forward method 158 | 159 | def is_conditioned(self) -> bool: 160 | """Check whether all decoder layers are already conditioned.""" 161 | return all(l.is_conditioned() for l in self._get_decoder_layers()) 162 | 163 | def clear_conditioned_layers(self): 164 | for layer in self._get_decoder_layers(): 165 | layer.condition_vis_x(None) 166 | layer.condition_media_locations(None) 167 | layer.condition_use_cached_media(None) 168 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/custom_files/custom_factory.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import open_clip 3 | 4 | #from .flamingo import Flamingo 5 | #from open_flamingo.src.flamingo import Flamingo #added by HSY 10/6/23 6 | from custom_files.custom_flamingo import Flamingo #added by HSY 10/6/23 7 | 8 | #from .flamingo_lm import FlamingoLMMixin 9 | #from open_flamingo.src.flamingo_lm import FlamingoLMMixin #added by HSY 10/6/23 10 | from custom_files.custom_flamingo_lm import FlamingoLMMixin #added by HSY 10/7/23 11 | 12 | #from .utils import extend_instance 13 | #from 
# open_flamingo.src.utils import extend_instance #added by HSY 10/6/23
from custom_files.custom_utils import extend_instance #added by HSY 10/7/23

from custom_files.custom_modeling_gpt_neox import GPTNeoXForCausalLM #added by HSY 10/7/23
from custom_files.custom_modeling_mpt import MptForCausalLM #added by HSY 10/7/23
import ipdb # added by HSY 10/7/23  NOTE(review): debugging import, unused in this module


def create_model_and_transforms(
    args,
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        args: run configuration. This function reads ``args.baseline_no_image``
            and ``args.full_tuning`` and passes ``args`` on to ``Flamingo``.
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder. Must contain
            "mpt-7b" or "RedPajama-INCITE-Instruct-3B-v1"; these select the
            customised MPT / GPT-NeoX model classes.
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
        use_local_files (bool, optional): whether to use local files. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None,
            in which case it is inferred from the language model's class name.
        freeze_lm_embeddings (bool, optional): keep the LM input embeddings frozen. Defaults to False.
        **flamingo_kwargs: extra keyword arguments forwarded to ``Flamingo``.
    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model
    Raises:
        ValueError: if ``lang_encoder_path`` matches no supported language model,
            or if the decoder-layer attribute name cannot be inferred.
    """
    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
    )

    # add <image> token ---------------------------------
    # add Flamingo special tokens to the tokenizer.
    # NOTE(review): in this listing the media/pad token literals appeared as ""
    # (markup-like "<...>" text seems to have been stripped). An empty special
    # token cannot mark image positions, and encode("")[-1] below would not be
    # a media token id. Restored to the upstream OpenFlamingo "<image>"/"<PAD>".
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
    )
    # --------------------------------------------------------------------------
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    if "mpt-7b" in lang_encoder_path:
        lang_encoder = MptForCausalLM.from_pretrained(
            lang_encoder_path,
            local_files_only=use_local_files,
            trust_remote_code=True,
        )
    elif "RedPajama-INCITE-Instruct-3B-v1" in lang_encoder_path:
        lang_encoder = GPTNeoXForCausalLM.from_pretrained(
            lang_encoder_path,
            local_files_only=use_local_files,
            trust_remote_code=True,
        )
    else:
        # Previously lang_encoder was left unbound here, which surfaced as an
        # UnboundLocalError further down. FlamingoLMMixin.forward passes
        # mask_map / mask_image_map to the wrapped model, so presumably only the
        # customised model classes above can be used (a generic
        # AutoModelForCausalLM fallback is not attempted).
        raise ValueError(
            f"Unsupported language encoder path '{lang_encoder_path}': expected it "
            "to contain 'mpt-7b' or 'RedPajama-INCITE-Instruct-3B-v1'."
        )

    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
    # NOTE(review): no loader above accepts an mpt-1b path, so this branch is
    # currently unreachable; it is kept for when such a loader is added.
    if "mpt-1b-redpajama-200b" in lang_encoder_path:

        class EmbeddingFnMixin:
            def get_input_embeddings(self):
                return self.transformer.wte

            def set_input_embeddings(self, new_embeddings):
                self.transformer.wte = new_embeddings

        extend_instance(lang_encoder, EmbeddingFnMixin)

    # convert LM to FlamingoLM
    extend_instance(lang_encoder, FlamingoLMMixin)
    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
        args,
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
            "width"
        ],
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    if not args.baseline_no_image:
        # Unfreeze perceiver and gated_cross_attn_layers
        model.perceiver.requires_grad_(True)
        model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
        if args.full_tuning:
            model.lang_encoder.old_decoder_blocks.requires_grad_(True)
    else:
        # Baseline without images: train the decoder blocks themselves.
        model.lang_encoder.old_decoder_blocks.requires_grad_(True)

    # Unfreeze LM input embeddings (both modes)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
        # TODO: investigate also training the output embeddings when untied

    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )

    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    """Return the decoder-layer attribute path for a known model class.

    Matches the lower-cased keys of ``__KNOWN_DECODER_LAYERS_ATTR_NAMES``
    as substrings of the model's lower-cased class name.

    Raises:
        ValueError: if no known key matches.
    """
    class_name = model.__class__.__name__
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in class_name.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers "
        f"(could not infer it for {class_name}). Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}


# --------------------------------------------------------------------------------
# /open_flamingo-main/open_flamingo/train/custom_files/custom_flamingo_lm.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
#from .helpers import GatedCrossAttentionBlock
from custom_files.custom_helpers import GatedCrossAttentionBlock

#from .utils import getattr_recursive, setattr_recursive
from custom_files.custom_utils import getattr_recursive, setattr_recursive


class FlamingoLayer(nn.Module):
    """
    FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.

    Unlike the stock version, forward() also accepts ``mask_map`` (passed to
    the decoder layer) and ``mask_image_map`` (passed to the cross-attention
    block as ``custom_text_to_image_mask``).
    """

    def __init__(
        self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.media_locations = None
        # Initialised here so forward() cannot fail with AttributeError when it
        # runs before condition_use_cached_media() has been called.
        self.use_cached_media = False
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None and self.media_locations is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x):
        self.vis_x = vis_x

    def condition_media_locations(self, media_locations):
        self.media_locations = media_locations

    def condition_use_cached_media(self, use_cached_media):
        self.use_cached_media = use_cached_media

    def forward(
        self,
        lang_x,
        attention_mask=None,
        mask_map=None,
        mask_image_map=None,
        **decoder_layer_kwargs,
    ):
        """Run the optional gated cross-attention, then the wrapped decoder layer.

        Raises:
            ValueError: if this layer has a cross-attention block but
                ``vis_x`` or ``media_locations`` has not been conditioned.
        """
        # Cross attention
        if self.gated_cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")

            if self.media_locations is None:
                raise ValueError(
                    "media_locations must be conditioned before forward pass"
                )

            lang_x = self.gated_cross_attn_layer(
                lang_x,
                self.vis_x,
                media_locations=self.media_locations,
                use_cached_media=self.use_cached_media,
                custom_text_to_image_mask=mask_image_map,
            )

        # Normal decoder layer
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, mask_map=mask_map, **decoder_layer_kwargs
        )
        return lang_x


class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        args,
        media_token_id,
        eoc_token_id,
        lang_hidden_size,
        vis_hidden_size,
        cross_attn_every_n_layers,
        gradient_checkpointing,
    ):
        """
        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.

        When ``args.baseline_no_image`` is set, no cross-attention blocks are
        created (every slot is ``None``). Otherwise every
        ``cross_attn_every_n_layers``-th decoder layer (1-based) gets one.
        """
        self.old_decoder_blocks = self._get_decoder_layers()

        if args.baseline_no_image:
            self.gated_cross_attn_layers = nn.ModuleList(
                [None for _ in self._get_decoder_layers()]
            )
        else:
            self.gated_cross_attn_layers = nn.ModuleList(
                [
                    GatedCrossAttentionBlock(
                        dim=lang_hidden_size, dim_visual=vis_hidden_size
                    )
                    if (layer_idx + 1) % cross_attn_every_n_layers == 0
                    else None
                    for layer_idx, _ in enumerate(self._get_decoder_layers())
                ]
            )
        self.init_flamingo_layers(gradient_checkpointing)
        self.media_token_id = media_token_id
        self.eoc_token_id = eoc_token_id
        self.initialized_flamingo = True
        self._use_cached_vision_x = False

    def init_flamingo_layers(self, gradient_checkpointing):
        """
        Re-initializes the FlamingoLayers.
        Propagates any changes made to self.gated_cross_attn_layers or self.old_decoder_blocks
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    FlamingoLayer(
                        gated_cross_attn_layer, decoder_layer, gradient_checkpointing
                    )
                    for gated_cross_attn_layer, decoder_layer in zip(
                        self.gated_cross_attn_layers, self.old_decoder_blocks
                    )
                ]
            )
        )

    #added by Hee Suk Yoon 10/7/23
    def _encode_special_tokens(self):
        """Copy the special token ids onto every FlamingoLayer.

        ``media_end_token_id`` is never assigned in ``init_flamingo``, so it is
        read with a ``None`` default instead of raising AttributeError.
        """
        media_end_token_id = getattr(self, "media_end_token_id", None)
        for flamingo_layer in self._get_decoder_layers():
            flamingo_layer.media_token_id = self.media_token_id
            flamingo_layer.eoc_token_id = self.eoc_token_id
            flamingo_layer.media_end_token_id = media_end_token_id

    def forward(self, input_ids, attention_mask, mask_map, mask_image_map, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()

        ``mask_map`` and ``mask_image_map`` are forwarded unchanged to the
        wrapped model's forward as keyword arguments.

        Raises:
            ValueError: if ``init_flamingo`` has not been called.
        """
        # getattr: the flag only exists once init_flamingo() has run, so a plain
        # attribute read would raise AttributeError instead of this ValueError.
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )
        media_locations = input_ids == self.media_token_id

        # if there are media already cached and we're generating and there are no media tokens in the input,
        # we'll assume that ALL input tokens should attend to the last previous media that is cached.
        # this is especially important for HF generate() compatibility, since generate() calls forward()
        # repeatedly one token at a time (with no media tokens).
        # without this check, the model would not attend to any images when generating (after the first token)
        use_cached_media_locations = (
            self._use_cached_vision_x
            and self.is_conditioned()
            and not media_locations.any()
        )

        for layer in self._get_decoder_layers():
            if not use_cached_media_locations:
                layer.condition_media_locations(media_locations)
            layer.condition_use_cached_media(use_cached_media_locations)

        # package arguments for the other parent's forward. since we don't know the order of the arguments,
        # make them all kwargs
        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        kwargs["mask_map"] = mask_map
        kwargs["mask_image_map"] = mask_image_map
        return super().forward(**kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset the conditioning state (vis_x, media locations, cache flag) of every layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_use_cached_media(None)


# --------------------------------------------------------------------------------
# /open_flamingo-main/open_flamingo/train/data_utils.py:
# --------------------------------------------------------------------------------
"""
Util functions for initializing webdataset objects
"""

import ast
import json
import logging
import os
import random
import sys
from dataclasses import dataclass
from multiprocessing import Value

import braceexpand
import numpy as np
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
| from webdataset.filters import _shuffle 21 | from webdataset.tariterators import ( 22 | base_plus_ext, 23 | tar_file_expander, 24 | url_opener, 25 | valid_sample, 26 | ) 27 | 28 | try: 29 | import horovod.torch as hvd 30 | except ImportError: 31 | hvd = None 32 | 33 | 34 | class SharedEpoch: 35 | def __init__(self, epoch: int = 0): 36 | self.shared_epoch = Value("i", epoch) 37 | 38 | def set_value(self, epoch): 39 | self.shared_epoch.value = epoch 40 | 41 | def get_value(self): 42 | return self.shared_epoch.value 43 | 44 | 45 | @dataclass 46 | class DataInfo: 47 | dataloader: DataLoader 48 | sampler: DistributedSampler = None 49 | shared_epoch: SharedEpoch = None 50 | 51 | def set_epoch(self, epoch): 52 | if self.shared_epoch is not None: 53 | self.shared_epoch.set_value(epoch) 54 | if self.sampler is not None and isinstance(self.sampler, DistributedSampler): 55 | self.sampler.set_epoch(epoch) 56 | 57 | 58 | def get_dataset_size(shards): 59 | shards_list = list(braceexpand.braceexpand(shards)) 60 | dir_path = os.path.dirname(shards[0]) 61 | sizes_filename = os.path.join(dir_path, "sizes.json") 62 | len_filename = os.path.join(dir_path, "__len__") 63 | if os.path.exists(sizes_filename): 64 | sizes = json.load(open(sizes_filename, "r")) 65 | total_size = sum( 66 | [ 67 | int(sizes[os.path.basename(shard)]) 68 | if os.path.basename(shard) in sizes 69 | else 0 70 | for shard in shards_list 71 | ] 72 | ) 73 | elif os.path.exists(len_filename): 74 | # FIXME this used to be eval(open(...)) but that seemed rather unsafe 75 | total_size = ast.literal_eval(open(len_filename, "r").read()) 76 | else: 77 | total_size = None # num samples undefined 78 | # some common dataset sizes (at time of authors last download) 79 | # CC3M (train): 2905954 80 | # CC12M: 10968539 81 | # LAION-400M: 407332084 82 | # LAION-2B (english): 2170337258 83 | num_shards = len(shards_list) 84 | return total_size, num_shards 85 | 86 | 87 | def count_samples(dataloader): 88 | os.environ["WDS_EPOCH"] = 
"0" 89 | n_elements, n_batches = 0, 0 90 | for images, texts in dataloader: 91 | n_batches += 1 92 | n_elements += len(images) 93 | assert len(images) == len(texts) 94 | return n_elements, n_batches 95 | 96 | 97 | def log_and_continue(exn): 98 | """Call in an exception handler to ignore any exception, issue a warning, and continue.""" 99 | logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") 100 | return True 101 | 102 | 103 | def group_by_keys_nothrow( 104 | data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None 105 | ): 106 | """Return function over iterator that groups key, value pairs into samples. 107 | 108 | :param keys: function that splits the key into key and extension (base_plus_ext) 109 | :param lcase: convert suffixes to lower case (Default value = True) 110 | """ 111 | current_sample = None 112 | for filesample in data: 113 | assert isinstance(filesample, dict) 114 | fname, value = filesample["fname"], filesample["data"] 115 | prefix, suffix = keys(fname) 116 | if prefix is None: 117 | continue 118 | if lcase: 119 | suffix = suffix.lower() 120 | # FIXME webdataset version throws if suffix in current_sample, but we have a potential for 121 | # this happening in the current LAION400m dataset if a tar ends with same prefix as the next 122 | # begins, rare, but can happen since prefix aren't unique across tar files in that dataset 123 | if ( 124 | current_sample is None 125 | or prefix != current_sample["__key__"] 126 | or suffix in current_sample 127 | ): 128 | if valid_sample(current_sample): 129 | yield current_sample 130 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) 131 | if suffixes is None or suffix in suffixes: 132 | current_sample[suffix] = value 133 | if valid_sample(current_sample): 134 | yield current_sample 135 | 136 | 137 | def tarfile_to_samples_nothrow(src, handler=log_and_continue): 138 | # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw 139 | streams 
= url_opener(src, handler=handler) 140 | files = tar_file_expander(streams, handler=handler) 141 | samples = group_by_keys_nothrow(files, handler=handler) 142 | return samples 143 | 144 | 145 | def pytorch_worker_seed(increment=0): 146 | """get dataloader worker seed from pytorch""" 147 | worker_info = get_worker_info() 148 | if worker_info is not None: 149 | # favour using the seed already created for pytorch dataloader workers if it exists 150 | seed = worker_info.seed 151 | if increment: 152 | # space out seed increments so they can't overlap across workers in different iterations 153 | seed += increment * max(1, worker_info.num_workers) 154 | return seed 155 | # fallback to wds rank based seed 156 | return wds.utils.pytorch_worker_seed() 157 | 158 | 159 | class detshuffle2(wds.PipelineStage): 160 | def __init__( 161 | self, 162 | bufsize=1000, 163 | initial=100, 164 | seed=0, 165 | epoch=-1, 166 | ): 167 | self.bufsize = bufsize 168 | self.initial = initial 169 | self.seed = seed 170 | self.epoch = epoch 171 | 172 | def run(self, src): 173 | if isinstance(self.epoch, SharedEpoch): 174 | epoch = self.epoch.get_value() 175 | else: 176 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) 177 | # situation as different workers may wrap at different times (or not at all). 
178 | self.epoch += 1 179 | epoch = self.epoch 180 | rng = random.Random() 181 | if self.seed < 0: 182 | # If seed is negative, we use the worker's seed, this will be different across all nodes/workers 183 | seed = pytorch_worker_seed(epoch) 184 | else: 185 | # This seed to be deterministic AND the same across all nodes/workers in each epoch 186 | seed = self.seed + epoch 187 | rng.seed(seed) 188 | return _shuffle(src, self.bufsize, self.initial, rng) 189 | 190 | 191 | class ResampledShards2(IterableDataset): 192 | """An iterable dataset yielding a list of urls.""" 193 | 194 | def __init__( 195 | self, 196 | urls, 197 | nshards=sys.maxsize, 198 | worker_seed=None, 199 | deterministic=False, 200 | epoch=-1, 201 | ): 202 | """Sample shards from the shard list with replacement. 203 | :param urls: a list of URLs as a Python list or brace notation string 204 | """ 205 | super().__init__() 206 | urls = wds.shardlists.expand_urls(urls) 207 | self.urls = urls 208 | assert isinstance(self.urls[0], str) 209 | self.nshards = nshards 210 | self.rng = random.Random() 211 | self.worker_seed = worker_seed 212 | self.deterministic = deterministic 213 | self.epoch = epoch 214 | 215 | def __iter__(self): 216 | """Return an iterator over the shards.""" 217 | if isinstance(self.epoch, SharedEpoch): 218 | epoch = self.epoch.get_value() 219 | else: 220 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) 221 | # situation as different workers may wrap at different times (or not at all). 
222 | self.epoch += 1 223 | epoch = self.epoch 224 | 225 | if self.deterministic: 226 | # reset seed w/ epoch if deterministic 227 | if self.worker_seed is None: 228 | # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id 229 | seed = pytorch_worker_seed(epoch) 230 | else: 231 | seed = self.worker_seed() + epoch 232 | self.rng.seed(seed) 233 | for _ in range(self.nshards): 234 | yield dict(url=self.rng.choice(self.urls)) 235 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/src/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | 25 | class PerceiverAttention(nn.Module): 26 | def __init__(self, *, dim, dim_head=64, heads=8): 27 | super().__init__() 28 | self.scale = dim_head**-0.5 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm_media = nn.LayerNorm(dim) 33 | self.norm_latents = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, T, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, T, n2, D) 46 | """ 47 | x = self.norm_media(x) 48 | latents = self.norm_latents(latents) 49 | 50 | h = self.heads 51 | 52 | q = 
self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 56 | q = q * self.scale 57 | 58 | # attention 59 | sim = einsum("... i d, ... j d -> ... i j", q, k) 60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 61 | attn = sim.softmax(dim=-1) 62 | 63 | out = einsum("... i j, ... j d -> ... i d", attn, v) 64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 65 | return self.to_out(out) 66 | 67 | 68 | class PerceiverResampler(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | dim, 73 | depth=6, 74 | dim_head=64, 75 | heads=8, 76 | num_latents=64, 77 | max_num_media=None, 78 | max_num_frames=None, 79 | ff_mult=4, 80 | ): 81 | super().__init__() 82 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 83 | self.frame_embs = ( 84 | nn.Parameter(torch.randn(max_num_frames, dim)) 85 | if exists(max_num_frames) 86 | else None 87 | ) 88 | self.media_time_embs = ( 89 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 90 | if exists(max_num_media) 91 | else None 92 | ) 93 | 94 | self.layers = nn.ModuleList([]) 95 | for _ in range(depth): 96 | self.layers.append( 97 | nn.ModuleList( 98 | [ 99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 100 | FeedForward(dim=dim, mult=ff_mult), 101 | ] 102 | ) 103 | ) 104 | 105 | self.norm = nn.LayerNorm(dim) 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (torch.Tensor): image features 111 | shape (b, T, F, v, D) 112 | Returns: 113 | shape (b, T, n, D) where n is self.num_latents 114 | """ 115 | b, T, F, v = x.shape[:4] 116 | 117 | # frame and media time embeddings 118 | if exists(self.frame_embs): 119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 120 | x = x + frame_embs 121 | x = rearrange( 122 | x, "b T F v d -> b T (F v) d" 123 | ) # flatten the frame and spatial dimensions 124 | if exists(self.media_time_embs): 125 | x = 
x + self.media_time_embs[:T] 126 | 127 | # blocks 128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 129 | for attn, ff in self.layers: 130 | latents = attn(x, latents) + latents 131 | latents = ff(latents) + latents 132 | return self.norm(latents) 133 | 134 | 135 | # gated cross attention 136 | class MaskedCrossAttention(nn.Module): 137 | def __init__( 138 | self, 139 | *, 140 | dim, 141 | dim_visual, 142 | dim_head=64, 143 | heads=8, 144 | only_attend_immediate_media=True, 145 | ): 146 | super().__init__() 147 | self.scale = dim_head**-0.5 148 | self.heads = heads 149 | inner_dim = dim_head * heads 150 | 151 | self.norm = nn.LayerNorm(dim) 152 | 153 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 154 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 155 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 156 | 157 | # whether for text to only attend to immediate preceding image, or all previous images 158 | self.only_attend_immediate_media = only_attend_immediate_media 159 | 160 | def forward(self, x, media, media_locations=None, use_cached_media=False): 161 | """ 162 | Args: 163 | x (torch.Tensor): text features 164 | shape (B, T_txt, D_txt) 165 | media (torch.Tensor): image features 166 | shape (B, T_img, n, D_img) where n is the dim of the latents 167 | media_locations: boolean mask identifying the media tokens in x 168 | shape (B, T_txt) 169 | use_cached_media: bool 170 | If true, treat all of x as if they occur after the last media 171 | registered in media_locations. 
T_txt does not need to exactly 172 | equal media_locations.shape[1] in this case 173 | """ 174 | 175 | if not use_cached_media: 176 | assert ( 177 | media_locations.shape[1] == x.shape[1] 178 | ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}" 179 | 180 | T_txt = x.shape[1] 181 | _, T_img, n = media.shape[:3] 182 | h = self.heads 183 | 184 | x = self.norm(x) 185 | 186 | q = self.to_q(x) 187 | media = rearrange(media, "b t n d -> b (t n) d") 188 | 189 | k, v = self.to_kv(media).chunk(2, dim=-1) 190 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 191 | 192 | q = q * self.scale 193 | 194 | sim = einsum("... i d, ... j d -> ... i j", q, k) 195 | 196 | if exists(media_locations): 197 | media_time = torch.arange(T_img, device=x.device) + 1 198 | 199 | if use_cached_media: 200 | # text time is set to the last cached media location 201 | text_time = repeat( 202 | torch.count_nonzero(media_locations, dim=1), 203 | "b -> b i", 204 | i=T_txt, 205 | ) 206 | else: 207 | # at each boolean of True, increment the time counter (relative to media time) 208 | text_time = media_locations.cumsum(dim=-1) 209 | 210 | # text time must equal media time if only attending to most immediate image 211 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 212 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 213 | 214 | text_to_media_mask = mask_op( 215 | rearrange(text_time, "b i -> b 1 i 1"), 216 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 217 | ) 218 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 219 | 220 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 221 | attn = sim.softmax(dim=-1) 222 | 223 | if exists(media_locations) and self.only_attend_immediate_media: 224 | # any text without a preceding media needs to have attention zeroed out 225 | text_without_media_mask = text_time == 0 226 | text_without_media_mask = rearrange( 227 | 
text_without_media_mask, "b i -> b 1 i 1" 228 | ) 229 | attn = attn.masked_fill(text_without_media_mask, 0.0) 230 | 231 | out = einsum("... i j, ... j d -> ... i d", attn, v) 232 | out = rearrange(out, "b h n d -> b n (h d)") 233 | return self.to_out(out) 234 | 235 | 236 | class GatedCrossAttentionBlock(nn.Module): 237 | def __init__( 238 | self, 239 | *, 240 | dim, 241 | dim_visual, 242 | dim_head=64, 243 | heads=8, 244 | ff_mult=4, 245 | only_attend_immediate_media=True, 246 | ): 247 | super().__init__() 248 | self.attn = MaskedCrossAttention( 249 | dim=dim, 250 | dim_visual=dim_visual, 251 | dim_head=dim_head, 252 | heads=heads, 253 | only_attend_immediate_media=only_attend_immediate_media, 254 | ) 255 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 256 | 257 | self.ff = FeedForward(dim, mult=ff_mult) 258 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 259 | 260 | def forward( 261 | self, 262 | x, 263 | media, 264 | media_locations=None, 265 | use_cached_media=False, 266 | ): 267 | x = ( 268 | self.attn( 269 | x, 270 | media, 271 | media_locations=media_locations, 272 | use_cached_media=use_cached_media, 273 | ) 274 | * self.attn_gate.tanh() 275 | + x 276 | ) 277 | x = self.ff(x) * self.ff_gate.tanh() + x 278 | 279 | return x 280 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/custom_files/custom_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, 
bias=False), 22 | ) 23 | 24 | import ipdb 25 | 26 | class PerceiverAttention(nn.Module): 27 | def __init__(self, *, dim, dim_head=64, heads=8): 28 | super().__init__() 29 | self.scale = dim_head**-0.5 30 | self.heads = heads 31 | inner_dim = dim_head * heads 32 | 33 | self.norm_media = nn.LayerNorm(dim) 34 | self.norm_latents = nn.LayerNorm(dim) 35 | 36 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 37 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 38 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 39 | 40 | def forward(self, x, latents): 41 | """ 42 | Args: 43 | x (torch.Tensor): image features 44 | shape (b, T, n1, D) 45 | latent (torch.Tensor): latent features 46 | shape (b, T, n2, D) 47 | """ 48 | x = self.norm_media(x) 49 | latents = self.norm_latents(latents) 50 | 51 | h = self.heads 52 | 53 | q = self.to_q(latents) 54 | kv_input = torch.cat((x, latents), dim=-2) 55 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 56 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 57 | q = q * self.scale 58 | 59 | # attention 60 | sim = einsum("... i d, ... j d -> ... i j", q, k) 61 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 62 | attn = sim.softmax(dim=-1) 63 | 64 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 65 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 66 | return self.to_out(out) 67 | 68 | 69 | class PerceiverResampler(nn.Module): 70 | def __init__( 71 | self, 72 | *, 73 | dim, 74 | depth=6, 75 | dim_head=64, 76 | heads=8, 77 | num_latents=64, 78 | max_num_media=None, 79 | max_num_frames=None, 80 | ff_mult=4, 81 | ): 82 | super().__init__() 83 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 84 | self.frame_embs = ( 85 | nn.Parameter(torch.randn(max_num_frames, dim)) 86 | if exists(max_num_frames) 87 | else None 88 | ) 89 | self.media_time_embs = ( 90 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 91 | if exists(max_num_media) 92 | else None 93 | ) 94 | 95 | self.layers = nn.ModuleList([]) 96 | for _ in range(depth): 97 | self.layers.append( 98 | nn.ModuleList( 99 | [ 100 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 101 | FeedForward(dim=dim, mult=ff_mult), 102 | ] 103 | ) 104 | ) 105 | 106 | self.norm = nn.LayerNorm(dim) 107 | 108 | def forward(self, x): 109 | """ 110 | Args: 111 | x (torch.Tensor): image features 112 | shape (b, T, F, v, D) 113 | Returns: 114 | shape (b, T, n, D) where n is self.num_latents 115 | """ 116 | b, T, F, v = x.shape[:4] 117 | 118 | # frame and media time embeddings 119 | if exists(self.frame_embs): 120 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 121 | x = x + frame_embs 122 | x = rearrange( 123 | x, "b T F v d -> b T (F v) d" 124 | ) # flatten the frame and spatial dimensions 125 | if exists(self.media_time_embs): 126 | x = x + self.media_time_embs[:T] 127 | 128 | # blocks 129 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 130 | for attn, ff in self.layers: 131 | latents = attn(x, latents) + latents 132 | latents = ff(latents) + latents 133 | return self.norm(latents) 134 | 135 | 136 | # gated cross attention 137 | class MaskedCrossAttention(nn.Module): 138 | def __init__( 139 | self, 140 | *, 141 | dim, 142 | dim_visual, 143 | 
dim_head=64, 144 | heads=8, 145 | only_attend_immediate_media=False, 146 | ): 147 | super().__init__() 148 | self.scale = dim_head**-0.5 149 | self.heads = heads 150 | inner_dim = dim_head * heads 151 | 152 | self.norm = nn.LayerNorm(dim) 153 | 154 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 155 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 156 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 157 | 158 | # whether for text to only attend to immediate preceding image, or all previous images 159 | self.only_attend_immediate_media = only_attend_immediate_media 160 | 161 | def forward(self, x, media, media_locations=None, use_cached_media=False, 162 | image_map=None, custom_text_to_image_mask=None): 163 | """ 164 | Args: 165 | x (torch.Tensor): text features 166 | shape (B, T_txt, D_txt) 167 | media (torch.Tensor): image features 168 | shape (B, T_img, n, D_img) where n is the dim of the latents 169 | media_locations: boolean mask identifying the media tokens in x 170 | shape (B, T_txt) 171 | use_cached_media: bool 172 | If true, treat all of x as if they occur after the last media 173 | registered in media_locations. T_txt does not need to exactly 174 | equal media_locations.shape[1] in this case 175 | """ 176 | 177 | if not use_cached_media: 178 | assert ( 179 | media_locations.shape[1] == x.shape[1] 180 | ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}" 181 | 182 | T_txt = x.shape[1] # input length 183 | _, T_img, n = media.shape[:3] 184 | h = self.heads 185 | 186 | x = self.norm(x) 187 | 188 | q = self.to_q(x) 189 | media = rearrange(media, "b t n d -> b (t n) d") 190 | 191 | k, v = self.to_kv(media).chunk(2, dim=-1) 192 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 193 | 194 | q = q * self.scale 195 | 196 | sim = einsum("... i d, ... j d -> ... 
i j", q, k) 197 | 198 | if exists(media_locations) and custom_text_to_image_mask is None: 199 | media_time = torch.arange(T_img, device=x.device) + 1 200 | 201 | if use_cached_media: 202 | # text time is set to the last cached media location 203 | text_time = repeat( 204 | torch.count_nonzero(media_locations, dim=1), 205 | "b -> b i", 206 | i=T_txt, 207 | ) 208 | else: 209 | # at each boolean of True, increment the time counter (relative to media time) 210 | text_time = media_locations.cumsum(dim=-1) 211 | 212 | # text time must equal media time if only attending to most immediate image 213 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 214 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 215 | 216 | text_to_media_mask = mask_op( 217 | rearrange(text_time, "b i -> b 1 i 1"), 218 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 219 | ) 220 | ipdb.set_trace() 221 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 222 | else: 223 | sim = sim.masked_fill(~custom_text_to_image_mask, -torch.finfo(sim.dtype).max) 224 | 225 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 226 | attn = sim.softmax(dim=-1) 227 | 228 | if exists(media_locations) and self.only_attend_immediate_media: 229 | # any text without a preceding media needs to have attention zeroed out 230 | text_without_media_mask = text_time == 0 231 | text_without_media_mask = rearrange( 232 | text_without_media_mask, "b i -> b 1 i 1" 233 | ) 234 | attn = attn.masked_fill(text_without_media_mask, 0.0) 235 | 236 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 237 | out = rearrange(out, "b h n d -> b n (h d)") 238 | return self.to_out(out) 239 | 240 | 241 | class GatedCrossAttentionBlock(nn.Module): 242 | def __init__( 243 | self, 244 | *, 245 | dim, 246 | dim_visual, 247 | dim_head=64, 248 | heads=8, 249 | ff_mult=4, 250 | only_attend_immediate_media=False, 251 | ): 252 | super().__init__() 253 | self.attn = MaskedCrossAttention( 254 | dim=dim, 255 | dim_visual=dim_visual, 256 | dim_head=dim_head, 257 | heads=heads, 258 | only_attend_immediate_media=only_attend_immediate_media, 259 | ) 260 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 261 | 262 | self.ff = FeedForward(dim, mult=ff_mult) 263 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 264 | 265 | def forward( 266 | self, 267 | x, 268 | media, 269 | media_locations=None, 270 | use_cached_media=False, 271 | custom_text_to_image_mask=None, 272 | ): 273 | x = ( 274 | self.attn( 275 | x, 276 | media, 277 | media_locations=media_locations, 278 | use_cached_media=use_cached_media, 279 | custom_text_to_image_mask=custom_text_to_image_mask, 280 | ) 281 | * self.attn_gate.tanh() 282 | + x 283 | ) 284 | x = self.ff(x) * self.ff_gate.tanh() + x 285 | 286 | return x 287 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/scripts/cache_rices_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cache CLIP features for all images in training split in preparation for RICES 3 | """ 4 | import argparse 5 | import sys 6 | import os 7 | 8 | sys.path.append( 9 | os.path.join( 10 | os.path.dirname(os.path.abspath(__file__)), 11 | "..", 12 | ) 13 | ) 14 | from eval.rices import RICES 15 | from eval.eval_datasets import ( 16 | CaptionDataset, 17 | VQADataset, 18 | ImageNetDataset, 19 | HatefulMemesDataset, 20 | ) 21 | import os 22 | import torch 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--output_dir", 27 | type=str, 
28 | required=True, 29 | help="Directory to save the cached features.", 30 | ) 31 | parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) 32 | parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) 33 | parser.add_argument("--batch_size", default=256) 34 | 35 | # Per-dataset flags 36 | parser.add_argument( 37 | "--eval_coco", 38 | action="store_true", 39 | default=False, 40 | help="Whether to cache COCO.", 41 | ) 42 | parser.add_argument( 43 | "--eval_vqav2", 44 | action="store_true", 45 | default=False, 46 | help="Whether to cache VQAV2.", 47 | ) 48 | parser.add_argument( 49 | "--eval_ok_vqa", 50 | action="store_true", 51 | default=False, 52 | help="Whether to cache OK-VQA.", 53 | ) 54 | parser.add_argument( 55 | "--eval_vizwiz", 56 | action="store_true", 57 | default=False, 58 | help="Whether to cache VizWiz.", 59 | ) 60 | parser.add_argument( 61 | "--eval_textvqa", 62 | action="store_true", 63 | default=False, 64 | help="Whether to cache TextVQA.", 65 | ) 66 | parser.add_argument( 67 | "--eval_imagenet", 68 | action="store_true", 69 | default=False, 70 | help="Whether to cache ImageNet.", 71 | ) 72 | parser.add_argument( 73 | "--eval_flickr30", 74 | action="store_true", 75 | default=False, 76 | help="Whether to cache Flickr30.", 77 | ) 78 | parser.add_argument( 79 | "--eval_hateful_memes", 80 | action="store_true", 81 | default=False, 82 | help="Whether to cache Hateful Memes.", 83 | ) 84 | 85 | # Dataset arguments 86 | 87 | ## Flickr30 Dataset 88 | parser.add_argument( 89 | "--flickr_image_dir_path", 90 | type=str, 91 | help="Path to the flickr30/flickr30k_images directory.", 92 | default=None, 93 | ) 94 | parser.add_argument( 95 | "--flickr_karpathy_json_path", 96 | type=str, 97 | help="Path to the dataset_flickr30k.json file.", 98 | default=None, 99 | ) 100 | parser.add_argument( 101 | "--flickr_annotations_json_path", 102 | type=str, 103 | help="Path to the dataset_flickr30k_coco_style.json file.", 104 | ) 105 | 
## COCO Dataset 106 | parser.add_argument( 107 | "--coco_train_image_dir_path", 108 | type=str, 109 | default=None, 110 | ) 111 | parser.add_argument( 112 | "--coco_val_image_dir_path", 113 | type=str, 114 | default=None, 115 | ) 116 | parser.add_argument( 117 | "--coco_karpathy_json_path", 118 | type=str, 119 | default=None, 120 | ) 121 | parser.add_argument( 122 | "--coco_annotations_json_path", 123 | type=str, 124 | default=None, 125 | ) 126 | 127 | ## VQAV2 Dataset 128 | parser.add_argument( 129 | "--vqav2_train_image_dir_path", 130 | type=str, 131 | default=None, 132 | ) 133 | parser.add_argument( 134 | "--vqav2_train_questions_json_path", 135 | type=str, 136 | default=None, 137 | ) 138 | parser.add_argument( 139 | "--vqav2_train_annotations_json_path", 140 | type=str, 141 | default=None, 142 | ) 143 | 144 | ## OK-VQA Dataset 145 | parser.add_argument( 146 | "--ok_vqa_train_image_dir_path", 147 | type=str, 148 | help="Path to the vqav2/train2014 directory.", 149 | default=None, 150 | ) 151 | parser.add_argument( 152 | "--ok_vqa_train_questions_json_path", 153 | type=str, 154 | help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.", 155 | default=None, 156 | ) 157 | parser.add_argument( 158 | "--ok_vqa_train_annotations_json_path", 159 | type=str, 160 | help="Path to the v2_mscoco_train2014_annotations.json file.", 161 | default=None, 162 | ) 163 | 164 | ## VizWiz Dataset 165 | parser.add_argument( 166 | "--vizwiz_train_image_dir_path", 167 | type=str, 168 | help="Path to the vizwiz train images directory.", 169 | default=None, 170 | ) 171 | parser.add_argument( 172 | "--vizwiz_train_questions_json_path", 173 | type=str, 174 | help="Path to the vizwiz questions json file.", 175 | default=None, 176 | ) 177 | parser.add_argument( 178 | "--vizwiz_train_annotations_json_path", 179 | type=str, 180 | help="Path to the vizwiz annotations json file.", 181 | default=None, 182 | ) 183 | 184 | # TextVQA Dataset 185 | parser.add_argument( 186 | 
"--textvqa_image_dir_path", 187 | type=str, 188 | help="Path to the textvqa images directory.", 189 | default=None, 190 | ) 191 | parser.add_argument( 192 | "--textvqa_train_questions_json_path", 193 | type=str, 194 | help="Path to the textvqa questions json file.", 195 | default=None, 196 | ) 197 | parser.add_argument( 198 | "--textvqa_train_annotations_json_path", 199 | type=str, 200 | help="Path to the textvqa annotations json file.", 201 | default=None, 202 | ) 203 | 204 | 205 | ## Imagenet dataset 206 | parser.add_argument("--imagenet_root", type=str, default="/tmp") 207 | 208 | ## Hateful Memes dataset 209 | parser.add_argument( 210 | "--hateful_memes_image_dir_path", 211 | type=str, 212 | default=None, 213 | ) 214 | parser.add_argument( 215 | "--hateful_memes_train_annotations_json_path", 216 | type=str, 217 | default=None, 218 | ) 219 | 220 | 221 | def main(): 222 | args, leftovers = parser.parse_known_args() 223 | device_id = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" 224 | if args.eval_flickr30: 225 | print("Caching Flickr30k...") 226 | train_dataset = CaptionDataset( 227 | image_train_dir_path=args.flickr_image_dir_path, 228 | image_val_dir_path=None, 229 | annotations_path=args.flickr_karpathy_json_path, 230 | is_train=True, 231 | dataset_name="flickr", 232 | ) 233 | rices_dataset = RICES( 234 | train_dataset, 235 | device_id, 236 | args.batch_size, 237 | vision_encoder_path=args.vision_encoder_path, 238 | vision_encoder_pretrained=args.vision_encoder_pretrained, 239 | ) 240 | torch.save( 241 | rices_dataset.features, 242 | os.path.join(args.output_dir, "flickr30.pkl"), 243 | ) 244 | 245 | if args.eval_coco: 246 | print("Caching COCO...") 247 | train_dataset = CaptionDataset( 248 | image_train_dir_path=args.coco_train_image_dir_path, 249 | image_val_dir_path=args.coco_val_image_dir_path, 250 | annotations_path=args.coco_karpathy_json_path, 251 | is_train=True, 252 | dataset_name="coco", 253 | ) 254 | rices_dataset = RICES( 255 | 
train_dataset, 256 | device_id, 257 | args.batch_size, 258 | vision_encoder_path=args.vision_encoder_path, 259 | vision_encoder_pretrained=args.vision_encoder_pretrained, 260 | ) 261 | torch.save( 262 | rices_dataset.features, 263 | os.path.join(args.output_dir, "coco.pkl"), 264 | ) 265 | 266 | if args.eval_ok_vqa: 267 | print("Caching OK-VQA...") 268 | train_dataset = VQADataset( 269 | image_dir_path=args.ok_vqa_train_image_dir_path, 270 | question_path=args.ok_vqa_train_questions_json_path, 271 | annotations_path=args.ok_vqa_train_annotations_json_path, 272 | is_train=True, 273 | dataset_name="ok_vqa", 274 | ) 275 | rices_dataset = RICES( 276 | train_dataset, 277 | device_id, 278 | args.batch_size, 279 | vision_encoder_path=args.vision_encoder_path, 280 | vision_encoder_pretrained=args.vision_encoder_pretrained, 281 | ) 282 | torch.save( 283 | rices_dataset.features, 284 | os.path.join(args.output_dir, "ok_vqa.pkl"), 285 | ) 286 | 287 | if args.eval_vizwiz: 288 | print("Caching VizWiz...") 289 | train_dataset = VQADataset( 290 | image_dir_path=args.vizwiz_train_image_dir_path, 291 | question_path=args.vizwiz_train_questions_json_path, 292 | annotations_path=args.vizwiz_train_annotations_json_path, 293 | is_train=True, 294 | dataset_name="vizwiz", 295 | ) 296 | rices_dataset = RICES( 297 | train_dataset, 298 | device_id, 299 | args.batch_size, 300 | vision_encoder_path=args.vision_encoder_path, 301 | vision_encoder_pretrained=args.vision_encoder_pretrained, 302 | ) 303 | torch.save( 304 | rices_dataset.features, 305 | os.path.join(args.output_dir, "vizwiz.pkl"), 306 | ) 307 | 308 | if args.eval_vqav2: 309 | print("Caching VQAv2...") 310 | train_dataset = VQADataset( 311 | image_dir_path=args.vqav2_train_image_dir_path, 312 | question_path=args.vqav2_train_questions_json_path, 313 | annotations_path=args.vqav2_train_annotations_json_path, 314 | is_train=True, 315 | dataset_name="vqav2", 316 | ) 317 | rices_dataset = RICES( 318 | train_dataset, 319 | device_id, 320 
| args.batch_size, 321 | vision_encoder_path=args.vision_encoder_path, 322 | vision_encoder_pretrained=args.vision_encoder_pretrained, 323 | ) 324 | torch.save( 325 | rices_dataset.features, 326 | os.path.join(args.output_dir, "vqav2.pkl"), 327 | ) 328 | 329 | if args.eval_textvqa: 330 | print("Caching TextVQA...") 331 | train_dataset = VQADataset( 332 | image_dir_path=args.textvqa_image_dir_path, 333 | question_path=args.textvqa_train_questions_json_path, 334 | annotations_path=args.textvqa_train_annotations_json_path, 335 | is_train=True, 336 | dataset_name="textvqa", 337 | ) 338 | rices_dataset = RICES( 339 | train_dataset, 340 | device_id, 341 | args.batch_size, 342 | vision_encoder_path=args.vision_encoder_path, 343 | vision_encoder_pretrained=args.vision_encoder_pretrained, 344 | ) 345 | torch.save( 346 | rices_dataset.features, 347 | os.path.join(args.output_dir, "textvqa.pkl"), 348 | ) 349 | 350 | if args.eval_hateful_memes: 351 | print("Caching Hateful Memes...") 352 | train_dataset = HatefulMemesDataset( 353 | image_dir_path=args.hateful_memes_image_dir_path, 354 | annotations_path=args.hateful_memes_train_annotations_json_path, 355 | ) 356 | rices_dataset = RICES( 357 | train_dataset, 358 | device_id, 359 | args.batch_size, 360 | vision_encoder_path=args.vision_encoder_path, 361 | vision_encoder_pretrained=args.vision_encoder_pretrained, 362 | ) 363 | torch.save( 364 | rices_dataset.features, 365 | os.path.join(args.output_dir, "hateful_memes.pkl"), 366 | ) 367 | 368 | 369 | if __name__ == "__main__": 370 | main() 371 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess and load datasets for training. 
3 | """ 4 | 5 | import functools 6 | import io 7 | import json 8 | import math 9 | import re 10 | import random 11 | import numpy as np 12 | import torch 13 | import torchvision 14 | import webdataset as wds 15 | from PIL import Image 16 | import base64 17 | from scipy.optimize import linear_sum_assignment 18 | 19 | from data_utils import * 20 | import ipdb 21 | 22 | Image.MAX_IMAGE_PIXELS = 1000000000 23 | N_CHANNELS = 3 24 | MIN_KB = 10 25 | _SHARD_SHUFFLE_SIZE = 2000 26 | _SHARD_SHUFFLE_INITIAL = 500 27 | _SAMPLE_SHUFFLE_SIZE = 5000 28 | _SAMPLE_SHUFFLE_INITIAL = 1000 29 | 30 | try: 31 | import horovod.torch as hvd 32 | except ImportError: 33 | hvd = None 34 | 35 | 36 | def preprocess_image(sample, image_processor): 37 | """ 38 | Convert images to tensors for training. 39 | Augmentations: random horizontal flip. 40 | Normalization handled by wds. 41 | """ 42 | image = [image_processor(s).unsqueeze(0) for s in sample] 43 | image = torch.cat(image, dim=0) 44 | image = torchvision.transforms.RandomHorizontalFlip(p=0.5)(image) 45 | return image 46 | 47 | 48 | def filter_no_caption_or_no_image(sample): 49 | """ 50 | Filter out LAION samples with no caption or no image. 51 | """ 52 | return ("txt" in sample) and ( 53 | "png" in sample or "jpg" in sample or "jpeg" in sample 54 | ) 55 | 56 | 57 | def preprocess_laion_text(sample, tokenizer, max_tokens=32): 58 | """ 59 | Preprocess text for LAION. 60 | Captions are truncated to 32 tokens by default. 61 | """ 62 | tokenizer.padding_side = "right" 63 | sample = [ 64 | (f"{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample 65 | ] 66 | text = tokenizer( 67 | sample, 68 | max_length=max_tokens, 69 | padding="longest", 70 | truncation="only_first", 71 | return_tensors="pt", 72 | ) 73 | return text["input_ids"], text["attention_mask"] 74 | 75 | def preprocess_mmdialogue_text(sample, tokenizer, max_tokens=256): 76 | """ 77 | Preprocess text for LAION. 78 | Captions are truncated to 32 tokens by default. 
79 | """ 80 | tokenizer.padding_side = "right" 81 | sample = [s.strip() for s in sample] 82 | text = tokenizer( 83 | sample, 84 | max_length=max_tokens, 85 | padding="longest", 86 | truncation="only_first", 87 | return_tensors="pt", 88 | ) 89 | 90 | # list of number of image tokens in each sample 91 | image_token_id = tokenizer.additional_special_tokens_ids[ 92 | tokenizer.additional_special_tokens.index("") 93 | ] 94 | num_image_tokens = torch.count_nonzero(text["input_ids"] == image_token_id, dim=1) 95 | 96 | return text["input_ids"], text["attention_mask"], num_image_tokens 97 | 98 | 99 | def get_mmdialogue(args, image_processor, tokenizer, epoch=0, floor=False, split='test'): # added by HSY 10/3/23 100 | import ipdb 101 | import pickle 102 | import os 103 | from torch.utils.data import DataLoader, Dataset 104 | import cv2 105 | from torchvision.transforms import functional as TF 106 | import numpy as np 107 | from dataclasses import dataclass 108 | from torch.utils.data.distributed import DistributedSampler 109 | 110 | 111 | # open pickle dataset ----- 112 | def open_pickle(data_path, split): 113 | #open data pickle file at data_path 114 | with open(os.path.join(data_path,split + '.pkl'), 'rb') as f: 115 | data = pickle.load(f) 116 | return data 117 | 118 | # get data 119 | data = open_pickle(args.data_path, split) 120 | num_samples = len(data['text']) 121 | 122 | # create a shared epoch store to sync epoch to dataloader worker proc 123 | shared_epoch = SharedEpoch(epoch=epoch) 124 | 125 | # create two preprocess functions that take in the passed in image_processor and tokenizer 126 | preprocess_image_fn = functools.partial(preprocess_image, image_processor=image_processor) 127 | preprocess_text_fn = functools.partial(preprocess_mmdialogue_text, tokenizer=tokenizer) 128 | 129 | class MMDataset(Dataset): 130 | def __init__(self, data, split): 131 | self.data = data 132 | self.split = split 133 | 134 | # Precompute image paths 135 | self.image_paths = [] 136 | 
for image_ids in self.data['image_id']: 137 | paths = [os.path.join(args.image_path, self.split, image_id) for image_id in image_ids] 138 | self.image_paths.append(paths) 139 | 140 | def __len__(self): 141 | return len(self.data['text']) 142 | 143 | 144 | def __getitem__(self, idx): 145 | # Load images using OpenCV 146 | images = [cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) for path in self.image_paths[idx]] 147 | # Convert to PIL Images for torchvision compatibility 148 | images = [TF.to_pil_image(img) for img in images] 149 | text = self.data['text'][idx] 150 | return (text, images) 151 | 152 | 153 | @dataclass 154 | class DataCollator_custom: 155 | def __init__(self, args, preprocess_image_fn, preprocess_text_fn, tokenizer, split, mask_image_description): #edited by HSY 10/21/23 156 | #self.data = data 157 | self.preprocess_image_fn = preprocess_image_fn 158 | self.preprocess_text_fn = preprocess_text_fn 159 | self.media_token_id = tokenizer.encode("")[-1] 160 | self.media_end_token_id = tokenizer.encode("")[-1] 161 | #self.eos_token_id = tokenizer.encode("<|endofchunk|>")[-1] 162 | self.eos_token_id = tokenizer.encode("")[-1] 163 | self.split = split 164 | 165 | # make mask_map for langauge model 166 | # image_map and image_feature_length is for making the mask for the cross attention for gated attention layer. 
167 | def create_mask(self,input_ids, media_token_id, media_end_token_id, eos_token_id, image_map=None, image_feature_length=64): 168 | # Create a mask filled with True values 169 | mask = torch.ones(input_ids.size(0), input_ids.size(1), input_ids.size(1), dtype=torch.bool) 170 | max_images = max(image_map) 171 | mask_image = torch.zeros(input_ids.size(0), input_ids.size(1), image_feature_length*max_images, dtype=torch.bool) 172 | 173 | 174 | # Iterate over each sequence in the batch 175 | for i, sequence in enumerate(input_ids): 176 | mask_temp = torch.ones(input_ids.size(1), dtype=torch.bool) 177 | mask_image_temp = torch.zeros(input_ids.size(1), image_feature_length*max_images, dtype=torch.bool) 178 | # Find the positions of the media tokens 179 | start_positions = (sequence == media_token_id).nonzero(as_tuple=True)[0].tolist() 180 | end_positions = (sequence == media_end_token_id).nonzero(as_tuple=True)[0].tolist() 181 | eos_start_positions = (sequence == eos_token_id).nonzero(as_tuple=True)[0].tolist() 182 | if eos_start_positions: 183 | eos_start_positions = eos_start_positions[0] 184 | else: 185 | eos_start_positions = None 186 | 187 | # image mask creation---- 188 | for ii, end_position in enumerate(end_positions): 189 | if eos_start_positions: 190 | mask_image_temp[end_position+1:eos_start_positions,image_feature_length*(ii):image_feature_length*(ii+1)] = True 191 | else: 192 | mask_image_temp[end_position+1:,image_feature_length*(ii):image_feature_length*(ii+1)] = True 193 | #------------------------ 194 | 195 | start_positions_temp = [] 196 | end_positions_temp = [] 197 | while start_positions: 198 | start = start_positions.pop(0) 199 | start_positions_temp.append(start+1) 200 | # If no corresponding end token is found, mask till the end 201 | if not end_positions: 202 | if eos_start_positions: 203 | mask_temp[start+1:eos_start_positions] = False 204 | end_positions_temp.append(eos_start_positions-1) 205 | else: 206 | mask_temp[start+1:] = False 207 | 
end_positions_temp.append(input_ids.size(1)-1) 208 | break 209 | else: 210 | end = end_positions.pop(0) 211 | end_positions_temp.append(end-1) 212 | mask_temp[start+1:end] = False 213 | mask_temp = mask_temp.unsqueeze(0).repeat(mask_temp.size(0),1) 214 | for j in range(len(start_positions_temp)): 215 | mask_temp[start_positions_temp[j]:end_positions_temp[j]+1, :] = True 216 | 217 | mask[i, :, :] = mask_temp 218 | mask_image[i, :, :] = mask_image_temp 219 | 220 | mask_image = mask_image.unsqueeze(1) 221 | #mask is for normal layers in the language model, and mask_image is for the gated layer in the language model 222 | 223 | # for batch with no image. 224 | if max_images == 0: 225 | mask_image = torch.zeros(input_ids.size(0), 1, input_ids.size(1), image_feature_length, dtype=torch.bool) 226 | 227 | return mask, mask_image 228 | 229 | # this right now has bug. fix before use 230 | def create_mask_parallel(self,input_ids, media_token_id, media_end_token_id, eos_token_id): 231 | # Create a mask filled with True values 232 | mask = torch.ones_like(input_ids, dtype=torch.bool) 233 | 234 | # Find the positions of media tokens and eos tokens 235 | start_positions = (input_ids == media_token_id).nonzero(as_tuple=True) 236 | end_positions = (input_ids == media_end_token_id).nonzero(as_tuple=True) 237 | eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True) 238 | 239 | # Initialize markers for all sequences in the batch 240 | last_start = -1 * torch.ones((input_ids.size(0),), dtype=torch.long) 241 | last_end = -1 * torch.ones((input_ids.size(0),), dtype=torch.long) 242 | last_eos = input_ids.size(1) * torch.ones((input_ids.size(0),), dtype=torch.long) 243 | 244 | # Update markers with the positions of the tokens 245 | if eos_positions[0].numel() > 0: 246 | last_eos[eos_positions[0]] = eos_positions[1] 247 | 248 | for batch_idx, seq_idx in zip(*start_positions): 249 | last_start[batch_idx] = max(last_start[batch_idx], seq_idx) 250 | 251 | for batch_idx, seq_idx in 
zip(*end_positions): 252 | last_end[batch_idx] = max(last_end[batch_idx], seq_idx) 253 | 254 | # Set mask values based on the positions of the tokens 255 | for i in range(input_ids.size(0)): 256 | if last_start[i] > last_end[i]: 257 | mask[i, last_start[i]:last_eos[i]] = False 258 | else: 259 | mask[i, last_start[i]:last_end[i]+1] = False 260 | 261 | return mask 262 | 263 | def __call__(self, pre_batch): 264 | texts = [sample[0] for sample in pre_batch] 265 | texts_processed = self.preprocess_text_fn(texts) 266 | 267 | 268 | imgs_all = [sample[1] for sample in pre_batch] 269 | imgs_pruned = [img for i in range(len(imgs_all)) for img in imgs_all[i][:texts_processed[-1].tolist()[i]]] #pruned based on truncated text where the token is gone 270 | try: 271 | imgs_processed = self.preprocess_image_fn(imgs_pruned) 272 | except: 273 | imgs_processed = torch.empty((0, 3, 224, 224)) 274 | mask_map, mask_image_map = self.create_mask(texts_processed[0], self.media_token_id, self.media_end_token_id, self.eos_token_id, texts_processed[2].tolist(), 64) 275 | return texts_processed, imgs_processed, mask_map, mask_image_map 276 | 277 | # Configure DataLoader 278 | data = MMDataset(data, split=split) 279 | sampler = DistributedSampler(data, num_replicas=args.world_size, rank=args.rank, shuffle=False, drop_last=True) 280 | dataloader = DataLoader(data, batch_size=args.batch_size_mmdialogue, shuffle=False, sampler=sampler, num_workers=args.workers, collate_fn= DataCollator_custom(args, preprocess_image_fn, preprocess_text_fn, split=split, tokenizer=tokenizer, mask_image_description=args.mask_image_description), drop_last=True, persistent_workers=True) 281 | # add meta-data to dataloader instance for convenience 282 | dataloader.num_samples = num_samples 283 | return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) 284 | 285 | def get_dataset_fn(dataset_type): 286 | """ 287 | Helper function to get the dataset function based on the dataset type 288 | """ 289 | if dataset_type 
== "mmdialogue": 290 | return get_mmdialogue 291 | # elif dataset_type == "mmdialogue_single_gpu": 292 | # return get_mmdialogue_single_gpu 293 | # elif dataset_type == "mmdialogue_single_gpu_split": 294 | # return get_mmdialogue_single_gpu_split 295 | 296 | else: 297 | raise ValueError(f"Unsupported dataset type: {dataset_type}") 298 | 299 | 300 | def get_data(args, image_processor, tokenizer, dataset_type, epoch=0, split=None): 301 | """ 302 | Interface for getting the webdatasets 303 | """ 304 | return get_dataset_fn(dataset_type)( 305 | args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer, split = split 306 | ) 307 | 308 | 309 | 310 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/eval/models/open_flamingo.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from PIL import Image 4 | import torch 5 | from einops import repeat 6 | 7 | from open_flamingo.eval.eval_model import BaseEvalModel 8 | from open_flamingo.src.factory import create_model_and_transforms 9 | from open_flamingo.eval.utils import unwrap_model, get_autocast, get_cast_dtype 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | 12 | 13 | class EvalModel(BaseEvalModel): 14 | """OpenFlamingo model evaluation. 15 | 16 | Attributes: 17 | model (nn.Module): Underlying Torch model. 18 | tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. 
19 | device: Index of GPU to use, or the string "CPU" 20 | """ 21 | 22 | def __init__(self, model_args): 23 | assert ( 24 | "vision_encoder_path" in model_args 25 | and "lm_path" in model_args 26 | and "checkpoint_path" in model_args 27 | and "lm_tokenizer_path" in model_args 28 | and "cross_attn_every_n_layers" in model_args 29 | and "vision_encoder_pretrained" in model_args 30 | and "precision" in model_args 31 | ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained, and precision arguments to be specified" 32 | 33 | self.device = ( 34 | model_args["device"] 35 | if ("device" in model_args and model_args["device"] >= 0) 36 | else "cpu" 37 | ) 38 | 39 | ( 40 | self.model, 41 | self.image_processor, 42 | self.tokenizer, 43 | ) = create_model_and_transforms( 44 | model_args["vision_encoder_path"], 45 | model_args["vision_encoder_pretrained"], 46 | model_args["lm_path"], 47 | model_args["lm_tokenizer_path"], 48 | cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), 49 | ) 50 | checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) 51 | if "model_state_dict" in checkpoint: 52 | checkpoint = checkpoint["model_state_dict"] 53 | checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} 54 | self.model.load_state_dict(checkpoint, strict=False) 55 | self.model.to(self.device) 56 | self.model.eval() 57 | self.tokenizer.padding_side = "left" 58 | 59 | self.lm_name = model_args["lm_path"].split("/")[-1] 60 | 61 | # autocast 62 | self.autocast = get_autocast(model_args["precision"]) 63 | self.cast_dtype = get_cast_dtype(model_args["precision"]) 64 | 65 | def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: 66 | """ 67 | Convert images to tensors, reshape them, and stack them. 68 | Args: 69 | batch: A list of lists of images. 
70 | Returns: 71 | preprocessed images (tensors) or None 72 | shape (B, T_img, F, C, H, W) 73 | None if no images in batch 74 | """ 75 | images_per_example = max(len(x) for x in batch) 76 | batch_images = None 77 | for iexample, example in enumerate(batch): 78 | for iimage, image in enumerate(example): 79 | preprocessed = self.image_processor(image) 80 | if batch_images is None: 81 | batch_images = torch.zeros( 82 | (len(batch), images_per_example, 1) + preprocessed.shape, 83 | dtype=preprocessed.dtype, 84 | ) 85 | batch_images[iexample, iimage, 0] = preprocessed 86 | if batch_images is not None: 87 | batch_images = batch_images.to( 88 | self.device, dtype=self.cast_dtype, non_blocking=True 89 | ) 90 | return batch_images 91 | 92 | def _prepare_text( 93 | self, 94 | batch: List[List[str]], 95 | padding="longest", 96 | truncation=True, 97 | max_length=2000, 98 | ): 99 | """ 100 | Tokenize the text and stack them. 101 | Args: 102 | batch: A list of lists of strings. 103 | Returns: 104 | input_ids (tensor) 105 | shape (B, T_txt) 106 | attention_mask (tensor) 107 | shape (B, T_txt) 108 | """ 109 | encodings = self.tokenizer( 110 | batch, 111 | padding=padding, 112 | truncation=truncation, 113 | return_tensors="pt", 114 | max_length=max_length, 115 | ) 116 | input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] 117 | input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True) 118 | attention_mask = attention_mask.to( 119 | self.device, dtype=self.cast_dtype, non_blocking=True 120 | ) 121 | return input_ids, attention_mask.bool() 122 | 123 | def get_outputs( 124 | self, 125 | batch_text: List[str], 126 | batch_images: List[List[Image.Image]], 127 | min_generation_length: int, 128 | max_generation_length: int, 129 | num_beams: int, 130 | length_penalty: float, 131 | ) -> List[str]: 132 | """ 133 | Get generation outputs. 
134 | """ 135 | batch_images = self._prepare_images(batch_images) 136 | input_ids, attention_mask = self._prepare_text(batch_text) 137 | 138 | with torch.inference_mode(): 139 | with self.autocast(): 140 | outputs = unwrap_model(self.model).generate( 141 | batch_images, 142 | input_ids, 143 | attention_mask, 144 | min_new_tokens=min_generation_length, 145 | max_new_tokens=max_generation_length, 146 | num_beams=num_beams, 147 | length_penalty=length_penalty, 148 | ) 149 | 150 | # Extract only the newly generated tokens; slicing at len(input_ids[0]) assumes every prompt ends at the same column, which the left padding set in __init__ provides 151 | outputs = outputs[:, len(input_ids[0]) :] 152 | 153 | return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 154 | 155 | def get_rank_classifications( 156 | self, 157 | batch_text: List[str], 158 | batch_images: List[List[Image.Image]], 159 | all_class_names: List[str], 160 | use_cache: bool, 161 | normalize_length: bool, 162 | ): 163 | """ 164 | Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. 165 | """ 166 | batch_images = self._prepare_images(batch_images) 167 | ctx_input_ids, ctx_attention_mask = self._prepare_text(batch_text) 168 | 169 | # Cache the context 170 | if use_cache: 171 | # NOTE(review): the full context (including its last token) is run here; the prediction for the first classname token comes from the last column of precomputed_logits, which is concatenated back in below 172 | self.cache_media( 173 | input_ids=ctx_input_ids, 174 | vision_x=batch_images, 175 | ) 176 | precomputed = self.__call__( 177 | vision_x=None, 178 | lang_x=ctx_input_ids, 179 | attention_mask=ctx_attention_mask, 180 | clear_conditioned_layers=False, 181 | use_cache=True, 182 | ) 183 | precomputed_logits = precomputed.logits 184 | precomputed_pkvs = precomputed.past_key_values 185 | else: 186 | precomputed_pkvs = None 187 | 188 | # Loop through class names and get log-likelihoods 189 | # Note: if all classnames are one token, this code is redundant, since we could 190 | # get all logits after one pass. However, if there are multi-token classnames, 191 | # we need to loop through each classname separately. 
192 | overall_probs = [] 193 | for class_name in all_class_names: 194 | # Tokenize only the class name 195 | classname_tokens = self.tokenizer( 196 | class_name, add_special_tokens=False, return_tensors="pt" 197 | )["input_ids"].to(self.device) 198 | assert classname_tokens.ndim == 2 199 | classname_tokens = repeat( 200 | classname_tokens, "b s -> (repeat b) s", repeat=len(batch_text) 201 | ) 202 | num_tokens_in_classname = classname_tokens.shape[1] 203 | 204 | # Concatenate the class name tokens 205 | if not use_cache: 206 | _lang_x = torch.cat([ctx_input_ids, classname_tokens], dim=1) 207 | _attention_mask = torch.cat( 208 | [ 209 | ctx_attention_mask, 210 | torch.ones_like(classname_tokens).bool(), 211 | ], 212 | dim=1, 213 | ) 214 | _vision_x = batch_images 215 | else: 216 | _lang_x = classname_tokens 217 | _attention_mask = None 218 | _vision_x = None 219 | 220 | # Call forward to get the logits 221 | outputs = self.__call__( 222 | vision_x=_vision_x, 223 | lang_x=_lang_x, 224 | attention_mask=_attention_mask, 225 | clear_conditioned_layers=(not use_cache), 226 | past_key_values=precomputed_pkvs, 227 | ) 228 | 229 | # Get the logits of the classname 230 | # logits shape is either (B, num_tokens_in_classname, vocab_len) with use_cache 231 | # or (B, len(_lang_x), vocab_len) without use_cache 232 | # remember that the logits at index t on dim 1 correspond to predictions for the t+1st token 233 | logits = outputs.logits 234 | if use_cache: 235 | logits = torch.cat([precomputed_logits, logits], dim=1) 236 | 237 | logprobs = torch.log_softmax(logits, dim=-1) 238 | gen_probs = logprobs[ 239 | :, -num_tokens_in_classname - 1 : -1, : 240 | ] # (B, num_tokens_in_classname, vocab_len) 241 | gen_probs = torch.gather( 242 | gen_probs, 2, classname_tokens[:, :, None] 243 | ).squeeze(-1) 244 | 245 | # Aggregate over tokens in the classname 246 | if normalize_length: 247 | class_prob = torch.mean(gen_probs, dim=1) 248 | else: 249 | class_prob = torch.sum(gen_probs, dim=1) 
250 | overall_probs.append(class_prob) # (B,) 251 | 252 | self.uncache_media() 253 | overall_probs = torch.vstack(overall_probs).T.cpu() # shape (B, num_classes) 254 | return overall_probs 255 | 256 | def __call__( 257 | self, 258 | lang_x: torch.Tensor, 259 | vision_x: torch.Tensor, 260 | attention_mask: torch.Tensor, 261 | past_key_values: torch.Tensor = None, 262 | clear_conditioned_layers: bool = False, 263 | use_cache: bool = False, 264 | ): 265 | """ 266 | Calls the forward function of the model. 267 | Special logic to handle the case if past_key_values is not None: 268 | then lang_x is assumed to contain the tokens to be generated 269 | *excluding* the tokens already in past_key_values. 270 | We then repeatedly call forward, updating the past_key_values. 271 | """ 272 | # standard forward pass 273 | if past_key_values is None: 274 | with torch.inference_mode(): 275 | with self.autocast(): 276 | outputs = self.model( 277 | vision_x=vision_x, 278 | lang_x=lang_x, 279 | attention_mask=attention_mask, 280 | clear_conditioned_layers=clear_conditioned_layers, 281 | past_key_values=past_key_values, 282 | use_cache=use_cache, 283 | ) 284 | return outputs 285 | 286 | # loop to handle updating past_key_values 287 | logits = [] 288 | for token_idx in range(lang_x.shape[1]): 289 | _lang_x = lang_x[:, token_idx].reshape((-1, 1)) 290 | if attention_mask is not None: 291 | _attention_mask = attention_mask[:, token_idx].reshape((-1, 1)) # NOTE(review): only the current column of the mask is passed alongside past_key_values; many HF causal LMs expect a mask covering past + current positions — confirm for this LM 292 | else: 293 | _attention_mask = None 294 | 295 | with torch.inference_mode(): 296 | with self.autocast(): 297 | outputs = self.model( 298 | vision_x=vision_x, 299 | lang_x=_lang_x, 300 | attention_mask=_attention_mask, 301 | clear_conditioned_layers=False, 302 | past_key_values=past_key_values, 303 | use_cache=True, 304 | ) 305 | 306 | past_key_values = outputs.past_key_values 307 | logits.append(outputs.logits) 308 | 309 | logits = torch.cat(logits, dim=1) 310 | return CausalLMOutputWithPast( 311 | logits=logits, 312 | 
past_key_values=past_key_values, 313 | ) 314 | 315 | def encode_vision_x(self, image_tensor: torch.Tensor): 316 | unwrap_model(self.model)._encode_vision_x(image_tensor.to(self.device)) 317 | 318 | def uncache_media(self): 319 | unwrap_model(self.model).uncache_media() 320 | 321 | def cache_media(self, input_ids, vision_x): 322 | unwrap_model(self.model).cache_media(input_ids=input_ids, vision_x=vision_x) 323 | 324 | def get_vqa_prompt(self, question, answer=None) -> str: # NOTE(review): none of these prompt builders emits a media token, although tag-like literals elsewhere in this dump look stripped (e.g. the tokenizer("") lookups in train_utils) — confirm the media tag was not lost 325 | return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" 326 | 327 | def get_caption_prompt(self, caption=None) -> str: 328 | return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" 329 | 330 | def get_imagenet_prompt(self, label=None) -> str: 331 | return f"Output:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}" 332 | 333 | def get_hateful_memes_prompt(self, text, label=None) -> str: 334 | return f"is an image with: '{text}' written on it. Is it hateful? 
Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}" 335 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/train_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import suppress 3 | import torch 4 | from tqdm import tqdm 5 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 6 | from torch.distributed.fsdp import ( 7 | FullStateDictConfig, 8 | StateDictType, 9 | ) 10 | from torch.distributed.fsdp.api import FullOptimStateDictConfig 11 | import os 12 | import wandb 13 | from einops import rearrange 14 | import torchvision 15 | 16 | def get_cast_dtype(precision: str): 17 | cast_dtype = None 18 | if precision == "bf16": 19 | cast_dtype = torch.bfloat16 20 | elif precision == "fp16": 21 | cast_dtype = torch.float16 22 | return cast_dtype 23 | 24 | 25 | def get_mp_policy_dtype(precision: str): 26 | if "bfloat16" in precision or "bf16" in precision: 27 | return torch.bfloat16 28 | elif precision == "fp16": 29 | return torch.float16 30 | else: 31 | return torch.float32 32 | 33 | 34 | def get_autocast(precision, cache_enabled=True): 35 | if precision == "amp": 36 | return torch.cuda.amp.autocast(cache_enabled=cache_enabled) 37 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 38 | # amp_bfloat16 is more stable than amp float16 for clip training 39 | return lambda: torch.cuda.amp.autocast( 40 | dtype=torch.bfloat16, cache_enabled=cache_enabled 41 | ) 42 | else: 43 | return suppress 44 | 45 | def train_one_epoch( 46 | args, 47 | model, 48 | epoch, 49 | mmdialogue_loader, 50 | mmdialogue_val_loader, 51 | tokenizer, 52 | optimizer, 53 | lr_scheduler, 54 | device_id, 55 | total_training_steps, 56 | num_batches_per_epoch, 57 | num_batches_per_epoch_val, 58 | wandb, 59 | csv_logger, 60 | val_csv_logger 61 | ): 62 | autocast = get_autocast( 63 | args.precision, 
cache_enabled=(not args.fsdp) 64 | ) # if fsdp, disable cache to save memory 65 | cast_dtype = get_cast_dtype(args.precision) 66 | 67 | # setup model 68 | media_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 69 | media_end_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 70 | endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1] 71 | 72 | if args.citation_module: 73 | cite_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 74 | cite_end_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 75 | 76 | model.train() 77 | 78 | # setup logging 79 | step_time_m = AverageMeter() 80 | data_time_m = AverageMeter() 81 | end = time.time() 82 | 83 | lowest_val = 1000 84 | # loop through dataloader 85 | for num_steps, (texts, images, mask_map, mask_image_map) in tqdm( 86 | enumerate(mmdialogue_loader), 87 | disable=args.rank != 0, 88 | total=num_batches_per_epoch, 89 | initial=0, 90 | ): 91 | data_time_m.update(time.time() - end) 92 | 93 | #### MMDialogue FORWARD PASS #### 94 | images = images.to(device_id, dtype=cast_dtype, non_blocking=True) 95 | images = rearrange(images, "(b t f) c h w -> b t f c h w", t=1, f=1) 96 | 97 | input_ids = texts[0].to(device_id, dtype=cast_dtype, non_blocking=True) 98 | attention_mask = texts[1].to( 99 | device_id, dtype=cast_dtype, non_blocking=True 100 | ) 101 | image_map = texts[2].tolist() 102 | 103 | # set up labels; language model is expected to handle shifting 104 | labels = input_ids.clone() 105 | labels[labels == tokenizer.pad_token_id] = -100 106 | labels = labels.to(device_id) 107 | 108 | mask_map = mask_map.to(device_id) 109 | mask_image_map = mask_image_map.to(device_id) 110 | 111 | # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager 112 | with autocast(): 113 | loss = model( 114 | vision_x=images, 115 | lang_x=input_ids, 116 | attention_mask=attention_mask, 117 | labels=labels, 118 | 
image_map=image_map, 119 | mask_map=mask_map, 120 | mask_image_map=mask_image_map, 121 | )[0] 122 | 123 | loss.backward() 124 | 125 | 126 | if (not args.freeze_lm_embeddings) and ( 127 | not args.fsdp or args.fsdp_use_orig_params 128 | ): 129 | # Mask gradients for input embeddings s.t. we only update the added tokens and <|endofchunk|> 130 | if args.fsdp: 131 | embed_grad = model.lang_encoder.get_input_embeddings().weight.grad 132 | else: 133 | embed_grad = ( 134 | model.module.lang_encoder.get_input_embeddings().weight.grad 135 | ) 136 | zero_mask = torch.zeros_like(embed_grad) 137 | zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id]) 138 | zero_mask[endofchunk_token_id] = torch.ones_like( 139 | zero_mask[endofchunk_token_id] 140 | ) 141 | 142 | zero_mask[media_end_token_id] = torch.ones_like(zero_mask[media_end_token_id]) 143 | 144 | if args.citation_module: 145 | zero_mask[cite_token_id] = torch.ones_like(zero_mask[cite_token_id]) 146 | zero_mask[cite_end_token_id] = torch.ones_like(zero_mask[cite_end_token_id]) 147 | 148 | if args.fsdp: 149 | model.lang_encoder.get_input_embeddings().weight.grad = ( 150 | embed_grad * zero_mask 151 | ) 152 | else: 153 | model.module.lang_encoder.get_input_embeddings().weight.grad = ( 154 | embed_grad * zero_mask 155 | ) 156 | 157 | # clip gradient norm 158 | if args.fsdp: 159 | """ 160 | The way we clip gradients with FSDP is different than the non-FSDP case, 161 | because during FSDP, gradient norms are computed over certain submodules, 162 | rather than the entire model. 163 | At least for OPT-125M, this didn't seem to make a difference in performance. 
164 | """ 165 | model.clip_grad_norm_(1.0) 166 | else: 167 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 168 | 169 | # step optimizer and log 170 | if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or ( 171 | num_steps == num_batches_per_epoch - 1 172 | ): 173 | optimizer.step() 174 | lr_scheduler.step() 175 | optimizer.zero_grad(set_to_none=True) 176 | 177 | # step time and reset end outside of rank 0 178 | step_time_m.update(time.time() - end) 179 | end = time.time() 180 | 181 | # rank 0 logging 182 | if args.rank == 0: 183 | global_step = num_steps + epoch * num_batches_per_epoch 184 | loss = loss.item() 185 | 186 | # log to csv 187 | csv_logger.log(epoch=epoch,step=global_step,loss=loss) 188 | 189 | if ((num_steps + 1) % 500 == 0): 190 | model.eval() 191 | mmdialogue_val_loader.set_epoch((epoch*num_batches_per_epoch)+num_steps) 192 | mmdialogue_val_loader_ = mmdialogue_val_loader.dataloader 193 | with torch.no_grad(): 194 | val_loss = valid_one_epoch_with_loss_return( 195 | args=args, 196 | model=model, 197 | epoch=epoch, 198 | step_=(epoch*num_batches_per_epoch)+num_steps, 199 | tokenizer=tokenizer, 200 | mmdialogue_loader=mmdialogue_val_loader_, 201 | device_id=device_id, 202 | num_batches_per_epoch=num_batches_per_epoch_val, 203 | csv_logger=val_csv_logger 204 | ) 205 | if val_loss <= lowest_val: 206 | lowest_val = val_loss 207 | save_checkpoint_for_resume(model, optimizer, lr_scheduler, epoch, num_steps, args) 208 | model.train() 209 | 210 | 211 | class AverageMeter(object): 212 | """Computes and stores the average and current value""" 213 | 214 | def __init__(self): 215 | self.reset() 216 | 217 | def reset(self): 218 | self.val = 0 219 | self.avg = 0 220 | self.sum = 0 221 | self.count = 0 222 | 223 | def update(self, val, n=1): 224 | self.val = val 225 | self.sum += val * n 226 | self.count += n 227 | self.avg = self.sum / self.count 228 | 229 | 230 | def filter_state_dict_to_trainable(model, state_dict): 231 | """ 232 | Remove 
non-trainable parameters from model state dict. 233 | Exception: Embeddings will not be removed, even if frozen. 234 | This is because we need the new <|endofchunk|> tokens to 235 | be consistent across initializations. 236 | """ 237 | for ( 238 | name, 239 | p, 240 | ) in model.named_parameters(): # won't work for fsdp + use_orig_params=False 241 | if "fsdp" in name: 242 | continue 243 | if "embed" in name or isinstance(p, torch.nn.Embedding): 244 | continue 245 | if not p.requires_grad: 246 | name = name.replace("._checkpoint_wrapped_module", "") 247 | if name in state_dict: 248 | del state_dict[name] 249 | else: 250 | print(f"WARNING: filtering but {name} not in state_dict") 251 | 252 | # also remove the keys in state_dict generated from 253 | # lang_encoder.old_decoder_blocks and lang_encoder.gated_cross_attn_layers 254 | # because these are already saved in lang_encoder.model... 255 | to_delete = [ 256 | n 257 | for n in state_dict.keys() 258 | if ("lang_encoder.old_decoder_blocks" in n) 259 | or ("lang_encoder.gated_cross_attn_layers" in n) 260 | or ("vision_encoder" in n) 261 | ] 262 | for name in to_delete: 263 | del state_dict[name] 264 | return state_dict 265 | 266 | def save_checkpoint_for_resume(model, optimizer, lr_scheduler, epoch, steps, args): 267 | """ 268 | Save training checkpoint with model, optimizer, and lr_scheduler state. 
269 | """ 270 | if args.fsdp: 271 | FSDP.set_state_dict_type( 272 | model, 273 | StateDictType.FULL_STATE_DICT, 274 | FullStateDictConfig(rank0_only=True, offload_to_cpu=True), 275 | FullOptimStateDictConfig(rank0_only=True), 276 | ) 277 | model_state = model.state_dict() 278 | optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group) 279 | 280 | else: 281 | model_state = model.state_dict() 282 | optim_state = optimizer.state_dict() 283 | 284 | if args.rank == 0: 285 | if not (args.fsdp and not args.fsdp_use_orig_params): 286 | model_state = filter_state_dict_to_trainable(model, model_state) 287 | 288 | if not os.path.exists(args.run_name): 289 | os.makedirs(args.run_name) 290 | 291 | checkpoint_dict = { 292 | "epoch": epoch, 293 | "steps": steps, 294 | "model_state_dict": model_state, 295 | "optimizer_state_dict": optim_state, 296 | "lr_scheduler_state_dict": lr_scheduler.state_dict(), 297 | } 298 | 299 | print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt") 300 | torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt") 301 | 302 | def valid_one_epoch_with_loss_return( 303 | args, 304 | model, 305 | epoch, 306 | step_, 307 | mmdialogue_loader, 308 | tokenizer, 309 | device_id, 310 | num_batches_per_epoch, 311 | csv_logger, 312 | ): 313 | autocast = get_autocast( 314 | args.precision, cache_enabled=(not args.fsdp) 315 | ) # if fsdp, disable cache to save memory 316 | cast_dtype = get_cast_dtype(args.precision) 317 | 318 | # setup model 319 | media_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 320 | media_end_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 321 | endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1] 322 | 323 | # setup logging 324 | step_time_m = AverageMeter() 325 | data_time_m = AverageMeter() 326 | end = time.time() 327 | 328 | loss_total = 0 329 | # loop through dataloader 330 | for num_steps, (texts, images, mask_map, 
mask_image_map) in tqdm( 331 | enumerate(mmdialogue_loader), 332 | disable=args.rank != 0, 333 | total=num_batches_per_epoch, 334 | initial=0, 335 | ): 336 | data_time_m.update(time.time() - end) 337 | 338 | #### MMDialogue FORWARD PASS #### 339 | images = images.to(device_id, dtype=cast_dtype, non_blocking=True) 340 | images = rearrange(images, "(b t f) c h w -> b t f c h w", t=1, f=1) 341 | 342 | input_ids = texts[0].to(device_id, dtype=cast_dtype, non_blocking=True) 343 | attention_mask = texts[1].to( 344 | device_id, dtype=cast_dtype, non_blocking=True 345 | ) 346 | image_map = texts[2].tolist() 347 | 348 | # set up labels; language model is expected to handle shifting 349 | labels = input_ids.clone() 350 | labels[labels == tokenizer.pad_token_id] = -100 351 | labels[labels == tokenizer.eos_token] = -100 352 | labels = labels.to(device_id) 353 | 354 | mask_map = mask_map.to(device_id) 355 | mask_image_map = mask_image_map.to(device_id) 356 | 357 | # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager 358 | with autocast(): 359 | loss = model( 360 | vision_x=images, 361 | lang_x=input_ids, 362 | attention_mask=attention_mask, 363 | labels=labels, 364 | image_map=image_map, 365 | mask_map=mask_map, 366 | mask_image_map=mask_image_map, 367 | )[0] 368 | 369 | 370 | loss_total+=loss.item() 371 | 372 | if args.rank == 0: 373 | csv_logger.log(epoch=epoch,step=step_,loss=loss_total/num_batches_per_epoch) 374 | 375 | return loss_total/num_batches_per_epoch 376 | 377 | def preprocess_image(sample, image_processor): 378 | """ 379 | Convert images to tensors for training. 380 | Augmentations: random horizontal flip. 381 | Normalization handled by wds. 
382 | """ 383 | image = [image_processor(s).unsqueeze(0) for s in sample] 384 | image = torch.cat(image, dim=0) 385 | image = torchvision.transforms.RandomHorizontalFlip(p=0.5)(image) 386 | return image 387 | 388 | 389 | import logging 390 | def get_logger(args): 391 | logger = logging.getLogger("main") 392 | logger.setLevel(logging.INFO) 393 | formatter = logging.Formatter('%(message)s') 394 | 395 | file_handler = logging.FileHandler(os.path.join(args.test_save_path, f"{args.save_run_name}.txt")) 396 | file_handler.setFormatter(formatter) 397 | logger.addHandler(file_handler) 398 | return logger -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/src/flamingo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | from .helpers import PerceiverResampler 5 | #from open_flamingo.src.helpers import PerceiverResampler #added by HSY 10/6/23 6 | from torch.distributed.fsdp.wrap import ( 7 | enable_wrap, 8 | wrap, 9 | ) 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | from torch.distributed.fsdp import ( 12 | FullyShardedDataParallel as FSDP, 13 | ) 14 | 15 | from .utils import apply_with_stopping_condition 16 | #from open_flamingo.src.utils import apply_with_stopping_condition #added by HSY 10/6/23 17 | 18 | import ipdb #added by HSY 10/6/23 19 | 20 | 21 | class Flamingo(nn.Module): 22 | def __init__( 23 | self, 24 | vision_encoder: nn.Module, 25 | lang_encoder: nn.Module, 26 | eoc_token_id: int, 27 | media_token_id: int, 28 | vis_dim: int, 29 | cross_attn_every_n_layers: int = 1, 30 | gradient_checkpointing: bool = False, 31 | ): 32 | """ 33 | Args: 34 | vision_encoder (nn.Module): HF CLIPModel 35 | lang_encoder (nn.Module): HF causal language model 36 | eoc_token_id (int): Token id for <|endofchunk|> 37 | media_token_id (int): Token id for 38 | vis_dim (int): Dimension of 
the visual features. 39 | Visual features are projected to match this shape along the last dimension. 40 | cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. 41 | """ 42 | super().__init__() 43 | self.eoc_token_id = eoc_token_id 44 | self.media_token_id = media_token_id 45 | self.vis_dim = vis_dim 46 | if hasattr(lang_encoder.config, "d_model"): 47 | self.lang_dim = lang_encoder.config.d_model # mpt uses d_model 48 | else: 49 | self.lang_dim = lang_encoder.config.hidden_size 50 | 51 | self.vision_encoder = vision_encoder.visual 52 | self.perceiver = PerceiverResampler(dim=self.vis_dim) 53 | self.lang_encoder = lang_encoder 54 | self.lang_encoder.init_flamingo( 55 | media_token_id=media_token_id, 56 | lang_hidden_size=self.lang_dim, 57 | vis_hidden_size=self.vis_dim, 58 | cross_attn_every_n_layers=cross_attn_every_n_layers, 59 | gradient_checkpointing=gradient_checkpointing, 60 | ) 61 | self._use_gradient_checkpointing = gradient_checkpointing 62 | self.perceiver._use_gradient_checkpointing = gradient_checkpointing 63 | 64 | def forward( 65 | self, 66 | vision_x: torch.Tensor, 67 | lang_x: torch.Tensor, 68 | attention_mask: torch.Tensor = None, 69 | labels: torch.Tensor = None, 70 | clear_conditioned_layers: bool = True, 71 | past_key_values=None, 72 | use_cache: bool = False, 73 | image_map: list = None, #added by HSY 10/6/23 74 | ): 75 | """ 76 | Forward pass of Flamingo. 77 | 78 | Args: 79 | vision_x (torch.Tensor): Vision input 80 | shape (B, T_img, F, C, H, W) with F=1 81 | lang_x (torch.Tensor): Language input ids 82 | shape (B, T_txt) 83 | attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. 84 | labels (torch.Tensor, optional): Labels. Defaults to None. 85 | clear_conditioned_layers: if True, clear the conditioned layers 86 | once the foward pass is completed. Set this to false if the 87 | same set of images will be reused in another subsequent 88 | forward pass. 
89 | past_key_values: pre-computed values to pass to language model. 90 | See past_key_values documentation in Hugging Face 91 | CausalLM models. 92 | use_cache: whether to use cached key values. See use_cache 93 | documentation in Hugging Face CausalLM models. 94 | """ 95 | assert ( 96 | self.lang_encoder.initialized_flamingo 97 | ), "Flamingo layers are not initialized. Please call `init_flamingo` first." 98 | 99 | assert ( 100 | self.lang_encoder._use_cached_vision_x or vision_x is not None 101 | ), "Must provide either vision_x or have precached media using cache_media()." 102 | 103 | ipdb.set_trace() 104 | if self.lang_encoder._use_cached_vision_x: 105 | # Case: use cached; vision_x should be cached and other 106 | # vision-related inputs should not be provided. 107 | assert ( 108 | vision_x is None 109 | ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first." 110 | assert self.lang_encoder.is_conditioned() 111 | 112 | else: 113 | # Case: do not use caching (i.e. this is a standard forward pass); 114 | self._encode_vision_x(vision_x=vision_x) 115 | self._condition_media_locations(input_ids=lang_x) 116 | 117 | ipdb.set_trace() 118 | output = self.lang_encoder( 119 | input_ids=lang_x, 120 | attention_mask=attention_mask, 121 | labels=labels, 122 | past_key_values=past_key_values, 123 | use_cache=use_cache, 124 | ) 125 | 126 | if clear_conditioned_layers: 127 | self.lang_encoder.clear_conditioned_layers() 128 | 129 | return output 130 | 131 | def generate( 132 | self, 133 | vision_x: torch.Tensor, 134 | lang_x: torch.Tensor, 135 | attention_mask: torch.Tensor = None, 136 | **kwargs, 137 | ): 138 | """ 139 | Generate text conditioned on vision and language inputs. 
140 | 141 | Args: 142 | vision_x (torch.Tensor): Vision input 143 | shape (B, T_img, F, C, H, W) 144 | images in the same chunk are collated along T_img, and frames are collated along F 145 | currently only F=1 is supported (single-frame videos) 146 | lang_x (torch.Tensor): Language input 147 | shape (B, T_txt) 148 | **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs: 149 | max_length (int, optional): Maximum length of the output. Defaults to None. 150 | attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. 151 | num_beams (int, optional): Number of beams. Defaults to 1. 152 | max_new_tokens (int, optional): Maximum new tokens. Defaults to None. 153 | temperature (float, optional): Temperature. Defaults to 1.0. 154 | top_k (int, optional): Top k. Defaults to 50. 155 | top_p (float, optional): Top p. Defaults to 1.0. 156 | no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. 157 | length_penalty (float, optional): Length penalty. Defaults to 1.0. 158 | num_return_sequences (int, optional): Number of return sequences. Defaults to 1. 159 | do_sample (bool, optional): Do sample. Defaults to False. 160 | early_stopping (bool, optional): Early stopping. Defaults to False. 
161 | Returns: 162 | torch.Tensor: lang_x with generated tokens appended to it 163 | """ 164 | num_beams = kwargs.pop("num_beams", 1) 165 | if num_beams > 1: 166 | vision_x = vision_x.repeat_interleave(num_beams, dim=0) 167 | 168 | self.lang_encoder._use_cached_vision_x = True 169 | self._encode_vision_x(vision_x=vision_x) 170 | 171 | eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id) 172 | output = self.lang_encoder.generate( 173 | input_ids=lang_x, 174 | attention_mask=attention_mask, 175 | eos_token_id=eos_token_id, 176 | num_beams=num_beams, 177 | **kwargs, 178 | ) 179 | 180 | self.lang_encoder.clear_conditioned_layers() 181 | self.lang_encoder._use_cached_vision_x = False 182 | return output 183 | 184 | def _encode_vision_x(self, vision_x: torch.Tensor): 185 | """ 186 | Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 187 | Args: 188 | vision_x (torch.Tensor): Vision input 189 | shape (B, T_img, F, C, H, W) 190 | Images in the same chunk are collated along T_img, and frames are collated along F 191 | Currently only F=1 is supported (single-frame videos) 192 | 193 | rearrange code based on https://github.com/dhansmair/flamingo-mini 194 | """ 195 | 196 | assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" 197 | b, T, F = vision_x.shape[:3] 198 | assert F == 1, "Only single frame supported" 199 | 200 | vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") 201 | with torch.no_grad(): 202 | vision_x = self.vision_encoder(vision_x)[1] 203 | vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) 204 | vision_x = self.perceiver(vision_x) 205 | 206 | # added by HSY 10/7/23 shape of vision_x here : (# images, 1, 64, 1024) 207 | for layer in self.lang_encoder._get_decoder_layers(): 208 | layer.condition_vis_x(vision_x) 209 | 210 | def wrap_fsdp(self, wrapper_kwargs, device_id): 211 | """ 212 | Manually wraps submodules for FSDP and move other 
parameters to device_id. 213 | 214 | Why manually wrap? 215 | - all parameters within the FSDP wrapper must have the same requires_grad. 216 | We have a mix of frozen and unfrozen parameters. 217 | - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors 218 | See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344 219 | 220 | The rough wrapping structure is: 221 | - FlamingoModel 222 | - FSDP(FSDP(vision_encoder)) 223 | - FSDP(FSDP(perceiver)) 224 | - lang_encoder 225 | - FSDP(FSDP(input_embeddings)) 226 | - FlamingoLayers 227 | - FSDP(FSDP(gated_cross_attn_layer)) 228 | - FSDP(FSDP(decoder_layer)) 229 | - FSDP(FSDP(output_embeddings)) 230 | - other parameters 231 | 232 | Known issues: 233 | - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied, 234 | train with DDP or set the --freeze_lm_embeddings flag to true. 235 | - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound. 236 | Although the training curves look okay, we found that downstream performance dramatically 237 | degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M). 238 | 239 | FAQs about our FSDP wrapping strategy: 240 | Why double wrap? 241 | As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook 242 | only free gathered parameters if the module is NOT FSDP root. 243 | 244 | Why unfreeze the decoder_layers? 245 | See https://github.com/pytorch/pytorch/issues/95805 246 | As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param 247 | requires_grad=True. We need the postback to fire to avoid OOM. 248 | To effectively freeze the decoder layers, we exclude them from the optimizer. 249 | 250 | What is assumed to be frozen v. unfrozen? 
251 | We assume that the model is being trained under normal Flamingo settings 252 | with these lines being called in factory.py: 253 | ``` 254 | # Freeze all parameters 255 | model.requires_grad_(False) 256 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 257 | 258 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings 259 | model.perceiver.requires_grad_(True) 260 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) 261 | [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True) 262 | ``` 263 | """ 264 | # unfreeze the decoder layers 265 | for block in self.lang_encoder.old_decoder_blocks: 266 | block.requires_grad_(True) 267 | 268 | # wrap in FSDP 269 | with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs): 270 | self.perceiver = wrap(wrap(self.perceiver)) 271 | self.lang_encoder.old_decoder_blocks = nn.ModuleList( 272 | wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks 273 | ) 274 | self.lang_encoder.gated_cross_attn_layers = nn.ModuleList( 275 | wrap(wrap(layer)) if layer is not None else None 276 | for layer in self.lang_encoder.gated_cross_attn_layers 277 | ) 278 | self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing) 279 | self.lang_encoder.set_input_embeddings( 280 | wrap(wrap(self.lang_encoder.get_input_embeddings())) 281 | ) 282 | self.lang_encoder.set_output_embeddings( 283 | wrap(wrap(self.lang_encoder.get_output_embeddings())) 284 | ) 285 | self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen 286 | 287 | # manually move non-FSDP managed parameters to device_id 288 | # these are all in lang_encoder 289 | apply_with_stopping_condition( 290 | module=self.lang_encoder, 291 | apply_fn=lambda m: m.to(device_id), 292 | apply_condition=lambda m: len(list(m.children())) == 0, 293 | stopping_condition=lambda m: isinstance(m, FSDP), 294 | ) 295 | 296 | # exclude the original decoder layers from the optimizer 297 | for block in 
self.lang_encoder.old_decoder_blocks: 298 | for p in block.parameters(): 299 | p.exclude_from_optimizer = True 300 | 301 | # set up clip_grad_norm_ function 302 | def clip_grad_norm_(max_norm): 303 | self.perceiver.clip_grad_norm_(max_norm) 304 | for layer in self.lang_encoder.gated_cross_attn_layers: 305 | if layer is not None: 306 | layer.clip_grad_norm_(max_norm) 307 | self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm) 308 | 309 | self.clip_grad_norm_ = clip_grad_norm_ 310 | 311 | def _condition_media_locations(self, input_ids: torch.Tensor): 312 | """ 313 | Compute the media token locations from lang_x and condition the language model on these. 314 | Args: 315 | input_ids (torch.Tensor): Language input 316 | shape (B, T_txt) 317 | """ 318 | media_locations = input_ids == self.media_token_id 319 | 320 | for layer in self.lang_encoder._get_decoder_layers(): 321 | layer.condition_media_locations(media_locations) 322 | 323 | def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor): 324 | """ 325 | Pre-cache a prompt/sequence of images / text for log-likelihood evaluations. 326 | All subsequent calls to forward() will generate attending to the LAST 327 | image in vision_x. 328 | This is not meant to be used to cache things for generate(). 329 | Args: 330 | input_ids (torch.Tensor): Language input 331 | shape (B, T_txt) 332 | vision_x (torch.Tensor): Vision input 333 | shape (B, T_img, F, C, H, W) 334 | Images in the same chunk are collated along T_img, and frames are collated along F 335 | Currently only F=1 is supported (single-frame videos) 336 | """ 337 | self._encode_vision_x(vision_x=vision_x) 338 | self._condition_media_locations(input_ids=input_ids) 339 | self.lang_encoder._use_cached_vision_x = True 340 | 341 | def uncache_media(self): 342 | """ 343 | Clear all conditioning. 
344 | """ 345 | self.lang_encoder.clear_conditioned_layers() 346 | self.lang_encoder._use_cached_vision_x = False 347 | -------------------------------------------------------------------------------- /open_flamingo-main/open_flamingo/train/custom_files/custom_flamingo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | from custom_files.custom_helpers import PerceiverResampler 5 | 6 | from torch.distributed.fsdp.wrap import ( 7 | enable_wrap, 8 | wrap, 9 | ) 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | from torch.distributed.fsdp import ( 12 | FullyShardedDataParallel as FSDP, 13 | ) 14 | 15 | from custom_files.custom_utils import apply_with_stopping_condition 16 | 17 | class Flamingo(nn.Module): 18 | def __init__( 19 | self, 20 | args, 21 | vision_encoder: nn.Module, 22 | lang_encoder: nn.Module, 23 | eoc_token_id: int, 24 | media_token_id: int, 25 | vis_dim: int, 26 | cross_attn_every_n_layers: int = 1, 27 | gradient_checkpointing: bool = False, 28 | ): 29 | """ 30 | Args: 31 | vision_encoder (nn.Module): HF CLIPModel 32 | lang_encoder (nn.Module): HF causal language model 33 | eoc_token_id (int): Token id for <|endofchunk|> 34 | media_token_id (int): Token id for 35 | vis_dim (int): Dimension of the visual features. 36 | Visual features are projected to match this shape along the last dimension. 37 | cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. 
38 | """ 39 | super().__init__() 40 | self.eoc_token_id = eoc_token_id 41 | self.media_token_id = media_token_id 42 | self.vis_dim = vis_dim 43 | if hasattr(lang_encoder.config, "d_model"): 44 | self.lang_dim = lang_encoder.config.d_model # mpt uses d_model 45 | else: 46 | self.lang_dim = lang_encoder.config.hidden_size 47 | 48 | self.vision_encoder = vision_encoder.visual 49 | self.perceiver = PerceiverResampler(dim=self.vis_dim) 50 | self.lang_encoder = lang_encoder 51 | self.lang_encoder.init_flamingo( 52 | args, 53 | media_token_id=media_token_id, 54 | eoc_token_id=eoc_token_id, 55 | lang_hidden_size=self.lang_dim, 56 | vis_hidden_size=self.vis_dim, 57 | cross_attn_every_n_layers=cross_attn_every_n_layers, 58 | gradient_checkpointing=gradient_checkpointing, 59 | ) 60 | self._use_gradient_checkpointing = gradient_checkpointing 61 | self.perceiver._use_gradient_checkpointing = gradient_checkpointing 62 | 63 | def forward( 64 | self, 65 | vision_x: torch.Tensor, 66 | lang_x: torch.Tensor, 67 | attention_mask: torch.Tensor = None, 68 | labels: torch.Tensor = None, 69 | clear_conditioned_layers: bool = True, 70 | past_key_values=None, 71 | use_cache: bool = False, 72 | image_map: list = None, 73 | mask_map: torch.Tensor = None, 74 | mask_image_map: torch.Tensor = None, 75 | return_dict: bool = False, 76 | ): 77 | """ 78 | Forward pass of Flamingo. 79 | 80 | Args: 81 | vision_x (torch.Tensor): Vision input 82 | shape (B, T_img, F, C, H, W) with F=1 83 | lang_x (torch.Tensor): Language input ids 84 | shape (B, T_txt) 85 | attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. 86 | labels (torch.Tensor, optional): Labels. Defaults to None. 87 | clear_conditioned_layers: if True, clear the conditioned layers 88 | once the foward pass is completed. Set this to false if the 89 | same set of images will be reused in another subsequent 90 | forward pass. 91 | past_key_values: pre-computed values to pass to language model. 
92 | See past_key_values documentation in Hugging Face 93 | CausalLM models. 94 | use_cache: whether to use cached key values. See use_cache 95 | documentation in Hugging Face CausalLM models. 96 | """ 97 | if sum(image_map) != vision_x.shape[0]: 98 | import ipdb; ipdb.set_trace() 99 | 100 | assert sum(image_map) == vision_x.shape[0], "image_map is not correct" 101 | #---------------------- 102 | 103 | assert ( 104 | self.lang_encoder.initialized_flamingo 105 | ), "Flamingo layers are not initialized. Please call `init_flamingo` first." 106 | 107 | assert ( 108 | self.lang_encoder._use_cached_vision_x or vision_x is not None 109 | ), "Must provide either vision_x or have precached media using cache_media()." 110 | 111 | if self.lang_encoder._use_cached_vision_x: 112 | # Case: use cached; vision_x should be cached and other 113 | # vision-related inputs should not be provided. 114 | assert ( 115 | vision_x is None 116 | ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first." 117 | assert self.lang_encoder.is_conditioned() 118 | 119 | else: 120 | # Case: do not use caching (i.e. this is a standard forward pass); 121 | self._encode_vision_x(vision_x=vision_x, image_map=image_map) 122 | self._condition_media_locations(input_ids=lang_x) 123 | 124 | #---------------------- 125 | output = self.lang_encoder( 126 | input_ids=lang_x, 127 | attention_mask=attention_mask, 128 | mask_map=mask_map, 129 | mask_image_map = mask_image_map, 130 | labels=labels, 131 | past_key_values=past_key_values, 132 | use_cache=use_cache, 133 | return_dict=return_dict, 134 | ) 135 | 136 | if clear_conditioned_layers: 137 | self.lang_encoder.clear_conditioned_layers() 138 | 139 | return output 140 | 141 | def generate( 142 | self, 143 | vision_x: torch.Tensor, 144 | lang_x: torch.Tensor, 145 | attention_mask: torch.Tensor = None, 146 | **kwargs, 147 | ): 148 | """ 149 | Generate text conditioned on vision and language inputs. 
150 | 151 | Args: 152 | vision_x (torch.Tensor): Vision input 153 | shape (B, T_img, F, C, H, W) 154 | images in the same chunk are collated along T_img, and frames are collated along F 155 | currently only F=1 is supported (single-frame videos) 156 | lang_x (torch.Tensor): Language input 157 | shape (B, T_txt) 158 | **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs: 159 | max_length (int, optional): Maximum length of the output. Defaults to None. 160 | attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. 161 | num_beams (int, optional): Number of beams. Defaults to 1. 162 | max_new_tokens (int, optional): Maximum new tokens. Defaults to None. 163 | temperature (float, optional): Temperature. Defaults to 1.0. 164 | top_k (int, optional): Top k. Defaults to 50. 165 | top_p (float, optional): Top p. Defaults to 1.0. 166 | no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. 167 | length_penalty (float, optional): Length penalty. Defaults to 1.0. 168 | num_return_sequences (int, optional): Number of return sequences. Defaults to 1. 169 | do_sample (bool, optional): Do sample. Defaults to False. 170 | early_stopping (bool, optional): Early stopping. Defaults to False. 
171 | Returns: 172 | torch.Tensor: lang_x with generated tokens appended to it 173 | """ 174 | num_beams = kwargs.pop("num_beams", 1) 175 | if num_beams > 1: 176 | vision_x = vision_x.repeat_interleave(num_beams, dim=0) 177 | 178 | self.lang_encoder._use_cached_vision_x = True 179 | self._encode_vision_x(vision_x=vision_x) 180 | 181 | eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id) 182 | output = self.lang_encoder.generate( 183 | input_ids=lang_x, 184 | attention_mask=attention_mask, 185 | eos_token_id=eos_token_id, 186 | num_beams=num_beams, 187 | **kwargs, 188 | ) 189 | 190 | self.lang_encoder.clear_conditioned_layers() 191 | self.lang_encoder._use_cached_vision_x = False 192 | return output 193 | 194 | def _encode_vision_x(self, vision_x: torch.Tensor, 195 | image_map=None): 196 | """ 197 | Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 198 | Args: 199 | vision_x (torch.Tensor): Vision input 200 | shape (B, T_img, F, C, H, W) 201 | Images in the same chunk are collated along T_img, and frames are collated along F 202 | Currently only F=1 is supported (single-frame videos) 203 | 204 | rearrange code based on https://github.com/dhansmair/flamingo-mini 205 | """ 206 | 207 | assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" 208 | b, T, F = vision_x.shape[:3] 209 | assert F == 1, "Only single frame supported" 210 | 211 | vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") 212 | with torch.no_grad(): 213 | if sum(image_map) != 0: 214 | vision_x = self.vision_encoder(vision_x)[1] 215 | else: 216 | #this is the case where there is no image in the batch. 
217 | temp_vision_x_device = vision_x.device 218 | vision_x = None 219 | 220 | 221 | 222 | 223 | #---------------------------- 224 | if vision_x is not None: 225 | vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) 226 | vision_x = self.perceiver(vision_x) 227 | #---------------------------- 228 | 229 | 230 | # preprocess the vision_x for our need. 231 | if image_map is None: 232 | # for debugging purpose. 233 | import ipdb; ipdb.set_trace() 234 | else: 235 | def divide_into_batches(images, num_images_per_batch): 236 | # Compute the maximum number of images in a batch 237 | max_images = max(num_images_per_batch) 238 | 239 | # Divide the images into batches 240 | batched_images_list = [] 241 | start_idx = 0 242 | for num in num_images_per_batch: 243 | end_idx = start_idx + num 244 | batch = images[start_idx:end_idx].squeeze(1) 245 | 246 | # Pad the batch if necessary 247 | if num < max_images: 248 | padding_size = max_images - num 249 | pad_tensor = torch.zeros((padding_size, 64, 1024), dtype=images.dtype, device=images.device) 250 | batch = torch.cat([batch, pad_tensor], dim=0) 251 | 252 | batched_images_list.append(batch) 253 | start_idx = end_idx 254 | 255 | # Combine all the batches into a single tensor 256 | result = torch.stack(batched_images_list, dim=0) 257 | 258 | return result 259 | if vision_x is not None: 260 | vision_x = divide_into_batches(vision_x, image_map) 261 | else: 262 | vision_x = torch.zeros(len(image_map), 1, 64, 1024, dtype=torch.float32, device=temp_vision_x_device) 263 | 264 | for layer in self.lang_encoder._get_decoder_layers(): 265 | layer.condition_vis_x(vision_x) 266 | 267 | def wrap_fsdp(self, wrapper_kwargs, device_id): 268 | """ 269 | Manually wraps submodules for FSDP and move other parameters to device_id. 270 | 271 | Why manually wrap? 272 | - all parameters within the FSDP wrapper must have the same requires_grad. 273 | We have a mix of frozen and unfrozen parameters. 
274 | - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors 275 | See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344 276 | 277 | The rough wrapping structure is: 278 | - FlamingoModel 279 | - FSDP(FSDP(vision_encoder)) 280 | - FSDP(FSDP(perceiver)) 281 | - lang_encoder 282 | - FSDP(FSDP(input_embeddings)) 283 | - FlamingoLayers 284 | - FSDP(FSDP(gated_cross_attn_layer)) 285 | - FSDP(FSDP(decoder_layer)) 286 | - FSDP(FSDP(output_embeddings)) 287 | - other parameters 288 | 289 | Known issues: 290 | - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied, 291 | train with DDP or set the --freeze_lm_embeddings flag to true. 292 | - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound. 293 | Although the training curves look okay, we found that downstream performance dramatically 294 | degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M). 295 | 296 | FAQs about our FSDP wrapping strategy: 297 | Why double wrap? 298 | As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook 299 | only free gathered parameters if the module is NOT FSDP root. 300 | 301 | Why unfreeze the decoder_layers? 302 | See https://github.com/pytorch/pytorch/issues/95805 303 | As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param 304 | requires_grad=True. We need the postback to fire to avoid OOM. 305 | To effectively freeze the decoder layers, we exclude them from the optimizer. 306 | 307 | What is assumed to be frozen v. unfrozen? 
308 | We assume that the model is being trained under normal Flamingo settings 309 | with these lines being called in factory.py: 310 | ``` 311 | # Freeze all parameters 312 | model.requires_grad_(False) 313 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 314 | 315 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings 316 | model.perceiver.requires_grad_(True) 317 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) 318 | [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True) 319 | ``` 320 | """ 321 | # unfreeze the decoder layers 322 | for block in self.lang_encoder.old_decoder_blocks: 323 | block.requires_grad_(True) 324 | 325 | # wrap in FSDP 326 | with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs): 327 | self.perceiver = wrap(wrap(self.perceiver)) 328 | self.lang_encoder.old_decoder_blocks = nn.ModuleList( 329 | wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks 330 | ) 331 | self.lang_encoder.gated_cross_attn_layers = nn.ModuleList( 332 | wrap(wrap(layer)) if layer is not None else None 333 | for layer in self.lang_encoder.gated_cross_attn_layers 334 | ) 335 | self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing) 336 | self.lang_encoder.set_input_embeddings( 337 | wrap(wrap(self.lang_encoder.get_input_embeddings())) 338 | ) 339 | self.lang_encoder.set_output_embeddings( 340 | wrap(wrap(self.lang_encoder.get_output_embeddings())) 341 | ) 342 | self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen 343 | 344 | # manually move non-FSDP managed parameters to device_id 345 | # these are all in lang_encoder 346 | apply_with_stopping_condition( 347 | module=self.lang_encoder, 348 | apply_fn=lambda m: m.to(device_id), 349 | apply_condition=lambda m: len(list(m.children())) == 0, 350 | stopping_condition=lambda m: isinstance(m, FSDP), 351 | ) 352 | 353 | # exclude the original decoder layers from the optimizer 354 | for block in 
self.lang_encoder.old_decoder_blocks: 355 | for p in block.parameters(): 356 | p.exclude_from_optimizer = True 357 | 358 | # set up clip_grad_norm_ function 359 | def clip_grad_norm_(max_norm): 360 | self.perceiver.clip_grad_norm_(max_norm) 361 | for layer in self.lang_encoder.gated_cross_attn_layers: 362 | if layer is not None: 363 | layer.clip_grad_norm_(max_norm) 364 | self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm) 365 | 366 | self.clip_grad_norm_ = clip_grad_norm_ 367 | 368 | def _condition_media_locations(self, input_ids: torch.Tensor): 369 | """ 370 | Compute the media token locations from lang_x and condition the language model on these. 371 | Args: 372 | input_ids (torch.Tensor): Language input 373 | shape (B, T_txt) 374 | """ 375 | media_locations = input_ids == self.media_token_id 376 | 377 | for layer in self.lang_encoder._get_decoder_layers(): 378 | layer.condition_media_locations(media_locations) 379 | 380 | def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor): 381 | """ 382 | Pre-cache a prompt/sequence of images / text for log-likelihood evaluations. 383 | All subsequent calls to forward() will generate attending to the LAST 384 | image in vision_x. 385 | This is not meant to be used to cache things for generate(). 386 | Args: 387 | input_ids (torch.Tensor): Language input 388 | shape (B, T_txt) 389 | vision_x (torch.Tensor): Vision input 390 | shape (B, T_img, F, C, H, W) 391 | Images in the same chunk are collated along T_img, and frames are collated along F 392 | Currently only F=1 is supported (single-frame videos) 393 | """ 394 | self._encode_vision_x(vision_x=vision_x) 395 | self._condition_media_locations(input_ids=input_ids) 396 | self.lang_encoder._use_cached_vision_x = True 397 | 398 | def uncache_media(self): 399 | """ 400 | Clear all conditioning. 
401 | """ 402 | self.lang_encoder.clear_conditioned_layers() 403 | self.lang_encoder._use_cached_vision_x = False 404 | --------------------------------------------------------------------------------