├── .gitignore ├── README.md ├── cirr_test_submission.py ├── config.py ├── data ├── cirr_dataset.py ├── files │ ├── laion_combined_info.json │ ├── laion_llm_info.json │ └── laion_template_info.json ├── fiq_dataset.py ├── laion_dataset_combined.py ├── laion_dataset_llm.py └── laion_dataset_template.py ├── main.py ├── model ├── BLIP │ ├── BLIP.gif │ ├── CODEOWNERS │ ├── CODE_OF_CONDUCT.md │ ├── LICENSE.txt │ ├── README.md │ ├── SECURITY.md │ ├── cog.yaml │ ├── configs │ │ ├── bert_config.json │ │ ├── caption_coco.yaml │ │ ├── med_config.json │ │ ├── nlvr.yaml │ │ ├── nocaps.yaml │ │ ├── pretrain.yaml │ │ ├── retrieval_coco.yaml │ │ ├── retrieval_flickr.yaml │ │ ├── retrieval_msrvtt.yaml │ │ └── vqa.yaml │ ├── data │ │ ├── __init__.py │ │ ├── coco_karpathy_dataset.py │ │ ├── flickr30k_dataset.py │ │ ├── nlvr_dataset.py │ │ ├── nocaps_dataset.py │ │ ├── pretrain_dataset.py │ │ ├── utils.py │ │ ├── video_dataset.py │ │ └── vqa_dataset.py │ ├── demo.ipynb │ ├── eval_nocaps.py │ ├── eval_retrieval_video.py │ ├── models │ │ ├── __init__.py │ │ ├── blip.py │ │ ├── blip_itm.py │ │ ├── blip_nlvr.py │ │ ├── blip_pretrain.py │ │ ├── blip_retrieval.py │ │ ├── blip_vqa.py │ │ ├── med.py │ │ ├── nlvr_encoder.py │ │ └── vit.py │ ├── predict.py │ ├── pretrain.py │ ├── requirements.txt │ ├── train_caption.py │ ├── train_nlvr.py │ ├── train_retrieval.py │ ├── train_vqa.py │ ├── transform │ │ └── randaugment.py │ └── utils.py ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py └── model.py ├── requirements.txt ├── trainer.py ├── transform.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-shot Composed Text-Image Retrieval 2 | 3 | This repository contains the official 
PyTorch implementation of TransAgg: [https://arxiv.org/abs/2306.07272](https://arxiv.org/abs/2306.07272) 4 | 5 | ## Environment 6 | Create the environment for running our code as follows: 7 | 8 | ``` 9 | conda create --name transagg python=3.9.16 10 | conda activate transagg 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Datasets 15 | 16 | **Laion-CIR-Template, Laion-CIR-LLM and Laion-CIR-Combined**: please refer to this [link](https://drive.google.com/drive/folders/1EGpylkOMj9tduUjAhTLtaX5UqjPMyN3X?usp=sharing) 17 | 18 | **FashionIQ**: Please refer to the [FashionIQ repo](https://github.com/XiaoxiaoGuo/fashion-iq) to get the datasets. 19 | 20 | **CIRR**: Please refer to the [CIRR repo](https://github.com/Cuberick-Orion/CIRR#download-cirr-dataset) for instructions. 21 | 22 | ## Model Zoo 23 | 24 | ### Pretrained Model 25 | 26 | **clip-Vit-B/32**: please refer to this [link](https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt) 27 | 28 | **clip-Vit-L/14**: please refer to this [link](https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt) 29 | 30 | **blip**: please refer to this [link](https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth) 31 | 32 | ### Checkpoints 33 | [https://drive.google.com/drive/folders/1EGpylkOMj9tduUjAhTLtaX5UqjPMyN3X?usp=sharing](https://drive.google.com/drive/folders/1EGpylkOMj9tduUjAhTLtaX5UqjPMyN3X?usp=sharing) 34 | 35 | 36 | 37 | ## Train 38 | **Note:** you can modify the relevant parameters in the `config.py` file. 39 | ``` 40 | CUDA_VISIBLE_DEVICES=0 python main.py 41 | ``` 42 | 43 | ## Test CIRR Dataset 44 | **Note:** you can modify the relevant parameters in the `config.py` file. 45 | ``` 46 | python cirr_test_submission.py 47 | ``` 48 | 49 | ## Citation 50 | If you use this code for your research or project, please cite: 51 | 52
| @article{liu2023zeroshot, 53 | title={Zero-shot Composed Text-Image Retrieval}, 54 | author={Yikun Liu and Jiangchao Yao and Ya Zhang and Yanfeng Wang and Weidi Xie}, 55 | year={2023}, 56 | journal={arXiv preprint arXiv:2306.07272}, 57 | } 58 | 59 | ## Star History 60 | 61 | [![Star History Chart](https://api.star-history.com/svg?repos=Code-kunkun/ZS-CIR&type=Date)](https://star-history.com/#Code-kunkun/ZS-CIR&Date) 62 | 63 | 64 | ## Acknowledgements 65 | Many thanks to the code bases from [CLIP4CIR](https://github.com/ABaldrati/CLIP4Cir), [CLIP](https://github.com/openai/CLIP), [BLIP](https://github.com/salesforce/BLIP) -------------------------------------------------------------------------------- /cirr_test_submission.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | from typing import List, Tuple 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from config import Config 10 | from model.model import TransAgg 11 | from utils import get_preprocess, extract_index_features 12 | from data.cirr_dataset import CIRRDataset 13 | 14 | 15 | def generate_cirr_test_submissions(file_name, model, preprocess, device): 16 | classic_test_dataset = CIRRDataset('test1', 'classic', preprocess) 17 | index_features, index_names, _ = extract_index_features(classic_test_dataset, model, return_local=False) 18 | relative_test_dataset = CIRRDataset('test1', 'relative', preprocess) 19 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset,index_features, 20 | index_names, model, device) 21 | 22 | submission = { 23 | 'version': 'rc2', 24 | 'metric': 'recall' 25 | } 26 | group_submission = { 27 | 'version': 'rc2', 28 | 'metric': 'recall_subset' 29 | } 30 | 31 | submission.update(pairid_to_predictions) 32 | group_submission.update(pairid_to_group_predictions) 33 | 
34 | print(f"Saving CIRR test predictions") 35 | with open(f"./submission/recall_submission_{file_name}.json", 'w+') as file: 36 | json.dump(submission, file, sort_keys=True) 37 | 38 | with open(f"./submission/recall_subset_submission_{file_name}.json", 'w+') as file: 39 | json.dump(group_submission, file, sort_keys=True) 40 | 41 | 42 | def generate_cirr_test_dicts(relative_test_dataset, index_features, index_names, model, device): 43 | # Generate predictions 44 | predicted_features, reference_names, group_members, pairs_id = \ 45 | generate_cirr_test_predictions(relative_test_dataset, model, device) 46 | 47 | print(f"Compute CIRR prediction dicts") 48 | 49 | # Normalize the index features 50 | index_features = F.normalize(index_features, dim=-1).float() 51 | 52 | # Compute the distances and sort the results 53 | distances = 1 - predicted_features @ index_features.T 54 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 55 | sorted_index_names = np.array(index_names)[sorted_indices] 56 | 57 | # Delete the reference image from the results 58 | reference_mask = torch.tensor( 59 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names), 60 | -1)) 61 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 62 | sorted_index_names.shape[1] - 1) 63 | # Compute the subset predictions 64 | group_members = np.array(group_members) 65 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 66 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1) 67 | 68 | # Generate prediction dicts 69 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in 70 | zip(pairs_id, sorted_index_names)} 71 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in 72 | zip(pairs_id, sorted_group_names)} 73 | 74 | return 
pairid_to_predictions, pairid_to_group_predictions 75 | 76 | 77 | def generate_cirr_test_predictions(relative_test_dataset: CIRRDataset, model, device) -> Tuple[torch.tensor, List[str], List[List[str]], List[str]]: 78 | print(f"Compute CIRR test predictions") 79 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, 80 | num_workers=multiprocessing.cpu_count(), pin_memory=True) 81 | 82 | # Initialize pairs_id, predicted_features, group_members and reference_names 83 | pairs_id = [] 84 | predicted_features = [] 85 | group_members = [] 86 | reference_names = [] 87 | 88 | for batch_pairs_id, batch_reference_names, captions, batch_group_members, reference_images in tqdm( 89 | relative_test_loader): # Load data 90 | batch_group_members = np.array(batch_group_members).T.tolist() 91 | 92 | # Compute the predicted features 93 | with torch.no_grad(): 94 | reference_images = reference_images.to(device) 95 | batch_predicted_features = model.combine_features(reference_images, captions) 96 | predicted_features.append(batch_predicted_features / batch_predicted_features.norm(dim=-1, keepdim=True)) 97 | 98 | group_members.extend(batch_group_members) 99 | reference_names.extend(batch_reference_names) 100 | pairs_id.extend(batch_pairs_id) 101 | 102 | predicted_features = torch.cat(predicted_features, dim=0) 103 | 104 | return predicted_features, reference_names, group_members, pairs_id 105 | 106 | 107 | def main(): 108 | cfg = Config() 109 | model = TransAgg(cfg) 110 | device = cfg.device 111 | model = model.to(device) 112 | model.load_state_dict(torch.load(cfg.eval_load_path)) 113 | 114 | # model.load_state_dict({k.replace('blip_model.', 'pretrained_model.'): v for k, v in torch.load(cfg.eval_load_path).items()}) 115 | # input_dim = model.clip_model.visual.input_resolution 116 | if cfg.model_name.startswith("blip"): 117 | input_dim = 384 118 | elif cfg.model_name.startswith("clip"): 119 | input_dim = model.pretrained_model.visual.input_resolution 120 | 
121 | preprocess = get_preprocess(cfg, model, input_dim=input_dim) 122 | 123 | model.eval() 124 | 125 | generate_cirr_test_submissions(cfg.submission_name, model, preprocess, device=device) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | 4 | @dataclass 5 | class Config: 6 | dropout: float = 0.5 7 | num_layers: int = 2 8 | model_name: str = "blip" # [blip, clip-Vit-B/32, clip-Vit-L/14] 9 | device: torch.device = torch.device('cuda') 10 | batch_size: int = 64 # you can adjust it according to your GPU memory 11 | encoder: str = 'text' # ['neither', 'text', 'both'] 12 | laion_type: str = 'laion_combined' # ['laion_combined', 'laion_template', 'laion_llm'] choose different dataset 13 | transform: str = 'targetpad' 14 | target_ratio: float = 1.25 15 | learning_rate: float = 1e-4 16 | weight_decay: float = 0.05 17 | adam_epsilon: float = 1e-8 18 | num_epochs: int = 100 19 | save_best: bool = True 20 | use_amp: bool = True 21 | validation_frequency: int = 1 22 | comment: str = "cirr_TransAgg_finetune_blip_text_combined" 23 | dataset: str='cirr' # ['fiq', 'cirr'] 24 | save_path_prefix = "/GPFS/data/yikunliu/image_retrieval_runs/wandb" 25 | # eval related 26 | eval_load_path: str="xxx" 27 | submission_name: str='cirr_test_finetune_blip_text_combined' 28 | 29 | -------------------------------------------------------------------------------- /data/cirr_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import PIL 4 | from PIL import Image 5 | Image.MAX_IMAGE_PIXELS = 2300000000 6 | 7 | 8 | class CIRRDataset(Dataset): 9 | """ 10 | CIRR dataset class which manage CIRR data 11 | The dataset can be used in 'relative' or 'classic' mode: 
12 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 13 | - In 'relative' mode the dataset yield tuples made of: 14 | - (reference_image, target_image, rel_caption) when split == train 15 | - (reference_name, target_name, rel_caption, group_members) when split == val 16 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 17 | """ 18 | 19 | def __init__(self, split: str, mode: str, preprocess: callable): 20 | """ 21 | :param split: dataset split, should be in ['test', 'train', 'val'] 22 | :param mode: dataset mode, should be in ['relative', 'classic']: 23 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 24 | - In 'relative' mode the dataset yield tuples made of: 25 | - (reference_image, target_image, rel_caption) when split == train 26 | - (reference_name, target_name, rel_caption, group_members) when split == val 27 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 28 | :param preprocess: function which preprocesses the image 29 | """ 30 | self.cirr_path_prefix = "/GPFS/data/yikunliu" 31 | self.preprocess = preprocess 32 | self.mode = mode 33 | self.split = split 34 | if self.split == 'test_train': 35 | split = 'train' 36 | 37 | if split not in ['test1', 'train', 'val']: 38 | raise ValueError("split should be in ['test1', 'train', 'val'") 39 | if mode not in ['relative', 'classic']: 40 | raise ValueError("mode should be in ['relative', 'classic']") 41 | 42 | # get triplets made by (reference_image, target_image, relative caption) 43 | with open(f'{self.cirr_path_prefix}/CIRR/cirr/captions/cap.rc2.{split}.json') as f: 44 | self.triplets = json.load(f) 45 | 46 | # get a mapping from image name to relative path 47 | with open(f'{self.cirr_path_prefix}/CIRR/cirr/image_splits/split.rc2.{split}.json') as f: 48 | self.name_to_relpath = json.load(f) 49 | 50 | print(f"CIRR {split} dataset in {mode} mode initialized") 51 | 52 | def __getitem__(self, index): 53 | try: 54 | 
if self.mode == 'relative': 55 | group_members = self.triplets[index]['img_set']['members'] 56 | reference_name = self.triplets[index]['reference'] 57 | rel_caption = self.triplets[index]['caption'].lower() 58 | 59 | if self.split == 'train': 60 | reference_image_path = f"{self.cirr_path_prefix}/NLVR2/images/" + self.name_to_relpath[reference_name][2:] 61 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert('RGB')) 62 | target_hard_name = self.triplets[index]['target_hard'] 63 | target_image_path = f"{self.cirr_path_prefix}/NLVR2/images/" + self.name_to_relpath[target_hard_name][2:] 64 | target_image = self.preprocess(PIL.Image.open(target_image_path).convert("RGB")) 65 | return reference_image, target_image, rel_caption 66 | 67 | elif self.split == 'val': 68 | reference_image_path = f"{self.cirr_path_prefix}/NLVR2/images/" + self.name_to_relpath[reference_name][2:] 69 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert('RGB')) 70 | target_hard_name = self.triplets[index]['target_hard'] 71 | return reference_name, target_hard_name, rel_caption, group_members, reference_image 72 | 73 | elif self.split == 'test1': 74 | reference_image_path = f"{self.cirr_path_prefix}/NLVR2/images/" + self.name_to_relpath[reference_name][2:] 75 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert('RGB')) 76 | pair_id = self.triplets[index]['pairid'] 77 | return pair_id, reference_name, rel_caption, group_members, reference_image 78 | 79 | elif self.mode == 'classic': 80 | image_name = list(self.name_to_relpath.keys())[index] 81 | image_path = f"{self.cirr_path_prefix}/NLVR2/images/" + self.name_to_relpath[image_name][2:] 82 | im = PIL.Image.open(image_path).convert("RGB") 83 | image = self.preprocess(im) 84 | return image_name, image 85 | 86 | else: 87 | raise ValueError("mode should be in ['relative', 'classic']") 88 | 89 | except Exception as e: 90 | print(f"Exception: {e}") 91 | 92 | def 
__len__(self): 93 | if self.mode == 'relative': 94 | return len(self.triplets) 95 | elif self.mode == 'classic': 96 | return len(self.name_to_relpath) 97 | else: 98 | raise ValueError("mode should be in ['relative', 'classic']") 99 | -------------------------------------------------------------------------------- /data/fiq_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from typing import List 3 | import json 4 | import PIL 5 | 6 | class FashionIQDataset(Dataset): 7 | """ 8 | FashionIQ dataset class which manage FashionIQ data. 9 | The dataset can be used in 'relative' or 'classic' mode: 10 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 11 | - In 'relative' mode the dataset yield tuples made of: 12 | - (reference_image, target_image, image_captions) when split == train 13 | - (reference_name, target_name, image_captions) when split == val 14 | - (reference_name, reference_image, image_captions) when split == test 15 | The dataset manage an arbitrary numbers of FashionIQ category, e.g. only dress, dress+toptee+shirt, dress+shirt... 
16 | """ 17 | 18 | def __init__(self, split: str, dress_types: List[str], mode: str, preprocess: callable): 19 | """ 20 | :param split: dataset split, should be in ['test', 'train', 'val'] 21 | :param dress_types: list of fashionIQ category 22 | :param mode: dataset mode, should be in ['relative', 'classic']: 23 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 24 | - In 'relative' mode the dataset yield tuples made of: 25 | - (reference_image, target_image, image_captions) when split == train 26 | - (reference_name, target_name, image_captions) when split == val 27 | - (reference_name, reference_image, image_captions) when split == test 28 | :param preprocess: function which preprocesses the image 29 | """ 30 | self.fiq_path_prefix = "/GPFS/data/yikunliu" 31 | self.mode = mode 32 | self.dress_types = dress_types 33 | self.split = split 34 | 35 | if mode not in ['relative', 'classic']: 36 | raise ValueError("mode should be in ['relative', 'classic']") 37 | if split not in ['test', 'train', 'val']: 38 | raise ValueError("split should be in ['test', 'train', 'val']") 39 | for dress_type in dress_types: 40 | if dress_type not in ['dress', 'shirt', 'toptee']: 41 | raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']") 42 | 43 | self.preprocess = preprocess 44 | 45 | # get triplets made by (reference_image, target_image, a pair of relative captions) 46 | self.triplets: List[dict] = [] 47 | for dress_type in dress_types: 48 | with open(f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/captions/cap.{dress_type}.{split}.json') as f: 49 | self.triplets.extend(json.load(f)) 50 | 51 | # get the image names 52 | self.image_names: list = [] 53 | for dress_type in dress_types: 54 | with open(f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/image_splits/split.{dress_type}.{split}.json') as f: 55 | self.image_names.extend(json.load(f)) 56 | 57 | print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized") 58 | 59 | def 
__getitem__(self, index): 60 | try: 61 | if self.mode == 'relative': 62 | image_captions = self.triplets[index]['captions'] 63 | reference_name = self.triplets[index]['candidate'] 64 | 65 | if self.split == 'train': 66 | reference_image_path = f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/images/{reference_name}.png' 67 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 68 | target_name = self.triplets[index]['target'] 69 | target_image_path = f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/images/{target_name}.png' 70 | target_image = self.preprocess(PIL.Image.open(target_image_path)) 71 | return reference_image, target_image, image_captions 72 | 73 | elif self.split == 'val': 74 | reference_image_path = f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/images/{reference_name}.png' 75 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 76 | target_name = self.triplets[index]['target'] 77 | return reference_name, target_name, image_captions, reference_image 78 | 79 | elif self.split == 'test': 80 | reference_image_path = f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/images/{reference_name}.png' 81 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 82 | return reference_name, reference_image, image_captions 83 | 84 | elif self.mode == 'classic': 85 | image_name = self.image_names[index] 86 | image_path = f'{self.fiq_path_prefix}/Fashion-IQ/fashion-iq/images/{image_name}.png' 87 | image = self.preprocess(PIL.Image.open(image_path)) 88 | return image_name, image 89 | 90 | else: 91 | raise ValueError("mode should be in ['relative', 'classic']") 92 | except Exception as e: 93 | print(f"Exception: {e}") 94 | 95 | def __len__(self): 96 | if self.mode == 'relative': 97 | return len(self.triplets) 98 | elif self.mode == 'classic': 99 | return len(self.image_names) 100 | else: 101 | raise ValueError("mode should be in ['relative', 'classic']") 
-------------------------------------------------------------------------------- /data/laion_dataset_combined.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import PIL 4 | from PIL import Image 5 | from PIL import ImageFile 6 | import os 7 | data_file_path = os.path.dirname(__file__) 8 | Image.MAX_IMAGE_PIXELS = 2300000000 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | 12 | class LaionDataset_Combined(Dataset): 13 | def __init__(self, split: str, preprocess: callable): 14 | self.preprocess = preprocess 15 | self.split = split 16 | 17 | if split not in ['train']: 18 | raise ValueError("split should be in ['train']") 19 | 20 | self.image_path_prefix = "/GPFS/public/laion_coco_metadata_600m/images/" 21 | with open(data_file_path + "/files/laion_combined_info.json") as f: 22 | self.triplets = json.load(f) 23 | 24 | print(f"Laion {split} dataset initialized") 25 | 26 | def __getitem__(self, index): 27 | 28 | reference_image = f"{str(self.triplets[index]['ref_image_id']).zfill(7)}.png" 29 | relative_caption = self.triplets[index]['relative_cap'] 30 | target_image = f"{str(self.triplets[index]['tgt_image_id']).zfill(7)}.png" 31 | 32 | reference_image_path = self.image_path_prefix + reference_image 33 | reference_image = PIL.Image.open(reference_image_path) 34 | if reference_image.mode == 'RGB': 35 | reference_image = reference_image.convert('RGB') 36 | else: 37 | reference_image = reference_image.convert('RGBA') 38 | reference_image = self.preprocess(reference_image) 39 | target_image_path = self.image_path_prefix + target_image 40 | target_image = PIL.Image.open(target_image_path) 41 | if target_image.mode == 'RGB': 42 | target_image = target_image.convert('RGB') 43 | else: 44 | target_image = target_image.convert('RGBA') 45 | target_image = self.preprocess(target_image) 46 | return reference_image, target_image, relative_caption 47 | 48 | def __len__(self): 49 | 
return len(self.triplets) 50 | -------------------------------------------------------------------------------- /data/laion_dataset_llm.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import PIL 4 | from PIL import Image 5 | from PIL import ImageFile 6 | import os 7 | data_file_path = os.path.dirname(__file__) 8 | Image.MAX_IMAGE_PIXELS = 2300000000 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | 12 | class LaionDataset_LLM(Dataset): 13 | def __init__(self, split: str, preprocess: callable): 14 | self.preprocess = preprocess 15 | self.split = split 16 | 17 | if split not in ['train']: 18 | raise ValueError("split should be in ['train']") 19 | 20 | self.image_path_prefix = "/GPFS/public/laion_coco_metadata_600m/images/" 21 | with open(data_file_path + "/files/laion_llm_info.json") as f: 22 | self.image_ids_map = json.load(f) 23 | self.reference_image_ids = list(self.image_ids_map.keys()) 24 | self.values = list(self.image_ids_map.values()) 25 | 26 | print(f"Laion {split} dataset initialized") 27 | 28 | def __getitem__(self, index): 29 | 30 | # reference_image = self.triplets[index]['reference_image'] 31 | reference_image = f"{self.reference_image_ids[index].zfill(7)}.png" 32 | relative_caption = self.values[index]['relative_cap'] 33 | target_image = f"{self.values[index]['tgt_image_id'].zfill(7)}.png" 34 | 35 | reference_image_path = self.image_path_prefix + reference_image 36 | reference_image = PIL.Image.open(reference_image_path) 37 | if reference_image.mode == 'RGB': 38 | reference_image = reference_image.convert('RGB') 39 | else: 40 | reference_image = reference_image.convert('RGBA') 41 | reference_image = self.preprocess(reference_image) 42 | target_image_path = self.image_path_prefix + target_image 43 | target_image = PIL.Image.open(target_image_path) 44 | if target_image.mode == 'RGB': 45 | target_image = target_image.convert('RGB') 46 | else: 47 | 
target_image = target_image.convert('RGBA') 48 | target_image = self.preprocess(target_image) 49 | return reference_image, target_image, relative_caption 50 | 51 | def __len__(self): 52 | return len(self.image_ids_map) 53 | -------------------------------------------------------------------------------- /data/laion_dataset_template.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import PIL 4 | from PIL import Image 5 | from PIL import ImageFile 6 | import os 7 | data_file_path = os.path.dirname(__file__) 8 | Image.MAX_IMAGE_PIXELS = 2300000000 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | 12 | class LaionDataset_Template(Dataset): 13 | def __init__(self, split: str, preprocess: callable): 14 | self.preprocess = preprocess 15 | self.split = split 16 | 17 | if split not in ['train']: 18 | raise ValueError("split should be in ['train']") 19 | 20 | self.image_path_prefix = "/GPFS/public/laion_coco_metadata_600m/images/" 21 | with open(data_file_path + "/files/laion_template_info.json") as f: 22 | self.triplets = json.load(f) 23 | 24 | print(f"Laion {split} dataset initialized") 25 | 26 | def __getitem__(self, index): 27 | 28 | reference_image = f"{str(self.triplets[index]['ref_image_id']).zfill(7)}.png" 29 | relative_caption = self.triplets[index]['relative_cap'] 30 | target_image = f"{str(self.triplets[index]['tgt_image_id']).zfill(7)}.png" 31 | 32 | reference_image_path = self.image_path_prefix + reference_image 33 | reference_image = PIL.Image.open(reference_image_path) 34 | if reference_image.mode == 'RGB': 35 | reference_image = reference_image.convert('RGB') 36 | else: 37 | reference_image = reference_image.convert('RGBA') 38 | reference_image = self.preprocess(reference_image) 39 | target_image_path = self.image_path_prefix + target_image 40 | target_image = PIL.Image.open(target_image_path) 41 | if target_image.mode == 'RGB': 42 | target_image = 
target_image.convert('RGB') 43 | else: 44 | target_image = target_image.convert('RGBA') 45 | target_image = self.preprocess(target_image) 46 | return reference_image, target_image, relative_caption 47 | 48 | def __len__(self): 49 | return len(self.triplets) 50 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch.multiprocessing 2 | torch.multiprocessing.set_sharing_strategy('file_system') 3 | import wandb 4 | import multiprocessing 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch import nn 8 | import random 9 | import numpy as np 10 | from trainer import Trainer 11 | from config import Config 12 | import datetime 13 | from utils import get_model, set_grad, get_preprocess, get_laion_cirr_dataset, get_laion_fiq_dataset, extract_index_features, collate_fn, get_optimizer 14 | 15 | def setup_seed(seed): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | np.random.seed(seed) 19 | random.seed(seed) 20 | # torch.backends.cudnn.deterministic = True 21 | 22 | 23 | def main(cfg): 24 | # setup_seed(0) 25 | 26 | # get the corresponding model 27 | model = get_model(cfg) 28 | set_grad(cfg, model) 29 | model.pretrained_model.eval().float() 30 | 31 | 32 | # input_dim = combiner.blip_model.visual.input_resolution 33 | if cfg.model_name.startswith('blip'): 34 | input_dim = 384 35 | elif cfg.model_name.startswith('clip'): 36 | input_dim = model.pretrained_model.visual.input_resolution 37 | preprocess = get_preprocess(cfg, model, input_dim) 38 | 39 | if cfg.dataset == 'fiq': 40 | val_dress_types = ['dress', 'toptee', 'shirt'] 41 | relative_train_dataset, relative_val_dataset, classic_val_dataset, idx_to_dress_mapping = get_laion_fiq_dataset(preprocess, val_dress_types, cfg.laion_type) 42 | # get dataset and dataloader 43 | elif cfg.dataset == 'cirr': 44 | relative_train_dataset, relative_val_dataset, 
classic_val_dataset = get_laion_cirr_dataset(preprocess, cfg.laion_type) 45 | relative_train_loader = DataLoader(dataset=relative_train_dataset, batch_size=cfg.batch_size, 46 | num_workers=multiprocessing.cpu_count(), pin_memory=True, collate_fn=collate_fn, 47 | drop_last=True, shuffle=True) 48 | 49 | # When fine-tuning only the text encoder we can precompute the index features since they do not change over the epochs 50 | kwargs = {} 51 | if cfg.dataset == 'fiq': 52 | kwargs['val_index_features'] = [] 53 | kwargs['val_index_names'] = [] 54 | kwargs['val_total_index_features'] = [] 55 | kwargs['idx_to_dress_mapping'] = idx_to_dress_mapping 56 | if cfg.dataset == 'cirr' and (cfg.encoder == 'text' or cfg.encoder == 'neither'): 57 | val_index_features, val_index_names, val_total_index_features = extract_index_features(classic_val_dataset, model, return_local=False) 58 | kwargs['val_index_features'], kwargs['val_index_names'], kwargs['val_total_index_features'] = val_index_features, val_index_names, val_total_index_features 59 | elif cfg.dataset == 'fiq' and (cfg.encoder == 'text' or cfg.encoder == 'neither'): 60 | for classic_val_dataset_ in classic_val_dataset: 61 | val_index_features, val_index_names, _ = extract_index_features(classic_val_dataset_, model, return_local=False) 62 | kwargs['val_index_features'].append(val_index_features) 63 | kwargs['val_index_names'].append(val_index_names) 64 | kwargs['val_total_index_features'].append(_) 65 | 66 | # Define the optimizer, the loss and the grad scaler 67 | optimizer = get_optimizer(model, cfg) 68 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=cfg.num_epochs, eta_min=1e-2 * cfg.learning_rate, last_epoch=-1) 69 | crossentropy_criterion = nn.CrossEntropyLoss(ignore_index=-100) 70 | 71 | trainer = Trainer(cfg, model, relative_train_loader, optimizer, lr_scheduler, crossentropy_criterion, classic_val_dataset, relative_val_dataset, **kwargs) 72 | trainer.train() 73 | 74 | """ 75 | if 
you just want to eval 76 | (1) model.load_state_dict(torch.load(model_path)) 77 | (2) trainer.eval_cirr() or trainer.eval_fiq() 78 | """ 79 | 80 | if __name__ == '__main__': 81 | cfg = Config() 82 | now = datetime.datetime.now() 83 | current_time = now.strftime("%Y-%m-%d-%H-%M-%S") 84 | cfg.save_path = f"{cfg.save_path_prefix}/{current_time}_{cfg.comment}_best_arithmetic.pth" 85 | 86 | wandb_config = vars(cfg) 87 | 88 | wandb.init(project='ZeroShot-CIR', notes=cfg.comment, config=wandb_config, name=cfg.comment) 89 | 90 | main(cfg) 91 | 92 | wandb.finish() 93 | -------------------------------------------------------------------------------- /model/BLIP/BLIP.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Code-kunkun/ZS-CIR/5b2a48518ccaef3dc2c2edc72db0523c8053506c/model/BLIP/BLIP.gif -------------------------------------------------------------------------------- /model/BLIP/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /model/BLIP/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 
8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 
32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. 
Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 
100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /model/BLIP/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /model/BLIP/README.md: -------------------------------------------------------------------------------- 1 | ## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation 2 | 3 | ## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications! 4 | 5 | 6 | 7 | This is the PyTorch code of the BLIP paper [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10. 8 | To install the dependencies, run
pip install -r requirements.txt
9 | 10 | Catalog: 11 | - [x] Inference demo 12 | - [x] Pre-trained and finetuned checkpoints 13 | - [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2 14 | - [x] Pre-training code 15 | - [x] Zero-shot video-text retrieval 16 | - [x] Download of bootstrapped pre-training datasets 17 | 18 | 19 | ### Inference demo: 20 | Run our interactive demo using a [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed). 21 | The demo includes code for: 22 | 1. Image captioning 23 | 2. Open-ended visual question answering 24 | 3. Multimodal / unimodal feature extraction 25 | 4. Image-text matching 26 | 27 | Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). 28 | 29 | A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip) 30 | 31 | ### Pre-trained checkpoints: 32 | Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L 33 | --- | :---: | :---: | :---: 34 | 14M | Download| - | - 35 | 129M | Download| Download | Download 36 | 37 | ### Finetuned checkpoints: 38 | Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L 39 | --- | :---: | :---: | :---: 40 | Image-Text Retrieval (COCO) | Download| - | Download 41 | Image-Text Retrieval (Flickr30k) | Download| - | Download 42 | Image Captioning (COCO) | - | Download| Download | 43 | VQA | Download| Download | - 44 | NLVR2 | Download| - | - 45 | 46 | 47 | ### Image-Text Retrieval: 48 | 1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly. 49 | 2. To evaluate the finetuned BLIP model on COCO, run: 50 |
python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
 51 | --config ./configs/retrieval_coco.yaml \
 52 | --output_dir output/retrieval_coco \
 53 | --evaluate
54 | 3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run: 55 |
python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
 56 | --config ./configs/retrieval_coco.yaml \
 57 | --output_dir output/retrieval_coco 
58 | 59 | ### Image-Text Captioning: 60 | 1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly. 61 | 2. To evaluate the finetuned BLIP model on COCO, run: 62 |
python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate
 63 | 3. To evaluate the finetuned BLIP model on NoCaps, generate results with the command below (evaluation must be performed on the official server): 64 |
python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py 
65 | 4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run: 66 |
python -m torch.distributed.run --nproc_per_node=8 train_caption.py 
 67 | 68 | ### VQA: 69 | 1. Download the VQA v2 and Visual Genome datasets from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml. 70 | 2. To evaluate the finetuned BLIP model, generate results with the command below (evaluation must be performed on the official server): 71 |
python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate
72 | 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run: 73 |
python -m torch.distributed.run --nproc_per_node=16 train_vqa.py 
 74 | 75 | ### NLVR2: 76 | 1. Download the NLVR2 dataset from the original website, and set 'image_root' in configs/nlvr.yaml. 77 | 2. To evaluate the finetuned BLIP model, run: 78 |
python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate
79 | 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run: 80 |
python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py 
 81 | 82 | ### Finetune with ViT-L: 83 | To finetune a model with ViT-L, simply change the config file to set 'vit' to 'large'. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). Gradient checkpointing can also be enabled in the config file to reduce GPU memory usage. 84 | 85 | ### Pre-train: 86 | 1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}. 87 | 2. In configs/pretrain.yaml, set 'train_file' to the paths of the json files. 88 | 3. Pre-train the model using 8 A100 GPUs: 89 |
python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/pretrain.yaml --output_dir output/Pretrain 
90 | 91 | ### Zero-shot video-text retrieval: 92 | 1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml. 93 | 2. Install [decord](https://github.com/dmlc/decord) with
pip install decord
94 | 3. To perform zero-shot evaluation, run 95 |
python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py
 96 | 97 | ### Pre-training datasets download: 98 | We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}. 99 | 100 | Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L 101 | --- | :---: | :---: | :---: 102 | CC3M+CC12M+SBU | Download| Download| Download 103 | LAION115M | Download| Download| Download 104 | 105 | ### Citation 106 | If you find this code useful for your research, please consider citing: 107 |
108 | @inproceedings{li2022blip,
109 |       title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation}, 
110 |       author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
111 |       year={2022},
112 |       booktitle={ICML},
113 | }
114 | 115 | ### Acknowledgement 116 | The implementation of BLIP relies on resources from ALBEF, Huggingface Transformers, and timm. We thank the original authors for their open-sourcing. 117 | -------------------------------------------------------------------------------- /model/BLIP/SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /model/BLIP/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | cuda: "11.1" 4 | python_version: "3.8" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "ipython==7.30.1" 10 | - "torchvision==0.11.1" 11 | - "torch==1.10.0" 12 | - "timm==0.4.12" 13 | - "transformers==4.15.0" 14 | - "fairscale==0.4.4" 15 | - "pycocoevalcap==1.2" 16 | 17 | predict: "predict.py:Predictor" 18 | -------------------------------------------------------------------------------- /model/BLIP/configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 
17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /model/BLIP/configs/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | coco_gt_root: 'annotation/coco_gt' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 7 | 8 | # size of vit model; base or large 9 | vit: 'base' 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | batch_size: 32 13 | init_lr: 1e-5 14 | 15 | # vit: 'large' 16 | # vit_grad_ckpt: True 17 | # vit_ckpt_layer: 5 18 | # batch_size: 16 19 | # init_lr: 2e-6 20 | 21 | image_size: 384 22 | 23 | # generation configs 24 | max_length: 20 25 | min_length: 5 26 | num_beams: 3 27 | prompt: 'a picture of ' 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 5 33 | 34 | -------------------------------------------------------------------------------- /model/BLIP/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /model/BLIP/configs/nlvr.yaml: 
-------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/NLVR2/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth' 6 | 7 | #size of vit model; base or large 8 | vit: 'base' 9 | batch_size_train: 16 10 | batch_size_test: 64 11 | vit_grad_ckpt: False 12 | vit_ckpt_layer: 0 13 | max_epoch: 15 14 | 15 | image_size: 384 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-5 20 | min_lr: 0 21 | 22 | -------------------------------------------------------------------------------- /model/BLIP/configs/nocaps.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/nocaps/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 6 | 7 | vit: 'base' 8 | batch_size: 32 9 | 10 | image_size: 384 11 | 12 | max_length: 20 13 | min_length: 5 14 | num_beams: 3 15 | prompt: 'a picture of ' -------------------------------------------------------------------------------- /model/BLIP/configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json', 2 | '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json', 3 | ] 4 | laion_path: '' 5 | 6 | # size of vit model; base or large 7 | vit: 'base' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 224 12 | batch_size: 75 13 | 14 | queue_size: 57600 15 | alpha: 0.4 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-4 20 | min_lr: 1e-6 21 | warmup_lr: 1e-6 22 | lr_decay_rate: 0.9 23 | max_epoch: 20 24 | warmup_steps: 3000 25 | 26 | 27 | 28 | 
-------------------------------------------------------------------------------- /model/BLIP/configs/retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | dataset: 'coco' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 12 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 256 28 | negative_all_rank: True 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /model/BLIP/configs/retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/flickr30k/' 2 | ann_root: 'annotation' 3 | dataset: 'flickr' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 10 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 128 28 | negative_all_rank: False 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 
| min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /model/BLIP/configs/retrieval_msrvtt.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 6 | 7 | # size of vit model; base or large 8 | vit: 'base' 9 | batch_size: 64 10 | k_test: 128 11 | image_size: 384 12 | num_frm_test: 8 -------------------------------------------------------------------------------- /model/BLIP/configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/ 2 | vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/ 3 | train_files: ['vqa_train','vqa_val','vg_qa'] 4 | ann_root: 'annotation' 5 | 6 | # set pretrained as a file path or an url 7 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 8 | 9 | # size of vit model; base or large 10 | vit: 'base' 11 | batch_size_train: 16 12 | batch_size_test: 32 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | init_lr: 2e-5 16 | 17 | image_size: 480 18 | 19 | k_test: 128 20 | inference: 'rank' 21 | 22 | # optimizer 23 | weight_decay: 0.05 24 | min_lr: 0 25 | max_epoch: 10 -------------------------------------------------------------------------------- /model/BLIP/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import 
coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset import pretrain_dataset 12 | import sys 13 | sys.path.append("/GPFS/rhome/yikunliu/graduation_project/model/BLIP/transform") 14 | from randaugment import RandomAugment 15 | 16 | def create_dataset(dataset, config, min_scale=0.5): 17 | 18 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 19 | 20 | transform_train = transforms.Compose([ 21 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 22 | transforms.RandomHorizontalFlip(), 23 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 24 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 25 | transforms.ToTensor(), 26 | normalize, 27 | ]) 28 | transform_test = transforms.Compose([ 29 | transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC), 30 | transforms.ToTensor(), 31 | normalize, 32 | ]) 33 | 34 | if dataset=='pretrain': 35 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 36 | return dataset 37 | 38 | elif dataset=='caption_coco': 39 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 40 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 41 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 42 | return train_dataset, val_dataset, test_dataset 43 | 44 | elif dataset=='nocaps': 45 | val_dataset = nocaps_eval(transform_test, 
config['image_root'], config['ann_root'], 'val') 46 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 47 | return val_dataset, test_dataset 48 | 49 | elif dataset=='retrieval_coco': 50 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 51 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 52 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 53 | return train_dataset, val_dataset, test_dataset 54 | 55 | elif dataset=='retrieval_flickr': 56 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 57 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 58 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 59 | return train_dataset, val_dataset, test_dataset 60 | 61 | elif dataset=='vqa': 62 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 63 | train_files = config['train_files'], split='train') 64 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 65 | return train_dataset, test_dataset 66 | 67 | elif dataset=='nlvr': 68 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 69 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 70 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 71 | return train_dataset, val_dataset, test_dataset 72 | 73 | 74 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 75 | samplers = [] 76 | for dataset,shuffle in zip(datasets,shuffles): 77 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, 
rank=global_rank, shuffle=shuffle) 78 | samplers.append(sampler) 79 | return samplers 80 | 81 | 82 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 83 | loaders = [] 84 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 85 | if is_train: 86 | shuffle = (sampler is None) 87 | drop_last = True 88 | else: 89 | shuffle = False 90 | drop_last = False 91 | loader = DataLoader( 92 | dataset, 93 | batch_size=bs, 94 | num_workers=n_worker, 95 | pin_memory=True, 96 | sampler=sampler, 97 | shuffle=shuffle, 98 | collate_fn=collate_fn, 99 | drop_last=drop_last, 100 | ) 101 | loaders.append(loader) 102 | return loaders 103 | 104 | -------------------------------------------------------------------------------- /model/BLIP/data/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class coco_karpathy_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
coco/images/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 18 | filename = 'coco_karpathy_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class coco_karpathy_caption_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
coco/images/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 61 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | def __len__(self): 70 | return len(self.annotation) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.image_root,ann['image']) 77 | image = Image.open(image_path).convert('RGB') 78 | image = self.transform(image) 79 | 80 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 81 | 82 | return image, int(img_id) 83 | 84 | 85 | class coco_karpathy_retrieval_eval(Dataset): 86 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 87 | ''' 88 | image_root (string): Root directory of images (e.g. 
coco/images/) 89 | ann_root (string): directory to store the annotation file 90 | split (string): val or test 91 | ''' 92 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 93 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 94 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 95 | 96 | download_url(urls[split],ann_root) 97 | 98 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 99 | self.transform = transform 100 | self.image_root = image_root 101 | 102 | self.text = [] 103 | self.image = [] 104 | self.txt2img = {} 105 | self.img2txt = {} 106 | 107 | txt_id = 0 108 | for img_id, ann in enumerate(self.annotation): 109 | self.image.append(ann['image']) 110 | self.img2txt[img_id] = [] 111 | for i, caption in enumerate(ann['caption']): 112 | self.text.append(pre_caption(caption,max_words)) 113 | self.img2txt[img_id].append(txt_id) 114 | self.txt2img[txt_id] = img_id 115 | txt_id += 1 116 | 117 | def __len__(self): 118 | return len(self.annotation) 119 | 120 | def __getitem__(self, index): 121 | 122 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 123 | image = Image.open(image_path).convert('RGB') 124 | image = self.transform(image) 125 | 126 | return image, index -------------------------------------------------------------------------------- /model/BLIP/data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class flickr30k_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
flickr30k/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 18 | filename = 'flickr30k_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class flickr30k_retrieval_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
flickr30k/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 61 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | self.text = [] 70 | self.image = [] 71 | self.txt2img = {} 72 | self.img2txt = {} 73 | 74 | txt_id = 0 75 | for img_id, ann in enumerate(self.annotation): 76 | self.image.append(ann['image']) 77 | self.img2txt[img_id] = [] 78 | for i, caption in enumerate(ann['caption']): 79 | self.text.append(pre_caption(caption,max_words)) 80 | self.img2txt[img_id].append(txt_id) 81 | self.txt2img[txt_id] = img_id 82 | txt_id += 1 83 | 84 | def __len__(self): 85 | return len(self.annotation) 86 | 87 | def __getitem__(self, index): 88 | 89 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 90 | image = Image.open(image_path).convert('RGB') 91 | image = self.transform(image) 92 | 93 | return image, index -------------------------------------------------------------------------------- /model/BLIP/data/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.utils import download_url 7 | 8 | from PIL import Image 9 | 10 | from data.utils import pre_caption 11 | 12 | class nlvr_dataset(Dataset): 13 | def __init__(self, transform, image_root, ann_root, split): 14 | ''' 15 | image_root (string): Root directory of images 16 | ann_root (string): directory to store the annotation file 17 | split 
(string): train, val or test 18 | ''' 19 | urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json', 20 | 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json', 21 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'} 22 | filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'} 23 | 24 | download_url(urls[split],ann_root) 25 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 26 | 27 | self.transform = transform 28 | self.image_root = image_root 29 | 30 | 31 | def __len__(self): 32 | return len(self.annotation) 33 | 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.annotation[index] 38 | 39 | image0_path = os.path.join(self.image_root,ann['images'][0]) 40 | image0 = Image.open(image0_path).convert('RGB') 41 | image0 = self.transform(image0) 42 | 43 | image1_path = os.path.join(self.image_root,ann['images'][1]) 44 | image1 = Image.open(image1_path).convert('RGB') 45 | image1 = self.transform(image1) 46 | 47 | sentence = pre_caption(ann['sentence'], 40) 48 | 49 | if ann['label']=='True': 50 | label = 1 51 | else: 52 | label = 0 53 | 54 | words = sentence.split(' ') 55 | 56 | if 'left' not in words and 'right' not in words: 57 | if random.random()<0.5: 58 | return image0, image1, sentence, label 59 | else: 60 | return image1, image0, sentence, label 61 | else: 62 | if random.random()<0.5: 63 | return image0, image1, sentence, label 64 | else: 65 | new_words = [] 66 | for word in words: 67 | if word=='left': 68 | new_words.append('right') 69 | elif word=='right': 70 | new_words.append('left') 71 | else: 72 | new_words.append(word) 73 | 74 | sentence = ' '.join(new_words) 75 | return image1, image0, sentence, label 76 | 77 | 78 | -------------------------------------------------------------------------------- /model/BLIP/data/nocaps_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | class nocaps_eval(Dataset): 10 | def __init__(self, transform, image_root, ann_root, split): 11 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json', 12 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'} 13 | filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'} 14 | 15 | download_url(urls[split],ann_root) 16 | 17 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 18 | self.transform = transform 19 | self.image_root = image_root 20 | 21 | def __len__(self): 22 | return len(self.annotation) 23 | 24 | def __getitem__(self, index): 25 | 26 | ann = self.annotation[index] 27 | 28 | image_path = os.path.join(self.image_root,ann['image']) 29 | image = Image.open(image_path).convert('RGB') 30 | image = self.transform(image) 31 | 32 | return image, int(ann['img_id']) -------------------------------------------------------------------------------- /model/BLIP/data/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from data.utils import pre_caption 13 | import os,glob 14 | 15 | class pretrain_dataset(Dataset): 16 | def __init__(self, ann_file, laion_path, transform): 17 | 18 | self.ann_pretrain = [] 19 | for f in ann_file: 20 | print('loading '+f) 21 | ann = json.load(open(f,'r')) 22 | self.ann_pretrain += ann 23 | 24 | self.laion_path = laion_path 25 | if self.laion_path: 26 | self.laion_files = 
glob.glob(os.path.join(laion_path,'*.json')) 27 | 28 | print('loading '+self.laion_files[0]) 29 | with open(self.laion_files[0],'r') as f: 30 | self.ann_laion = json.load(f) 31 | 32 | self.annotation = self.ann_pretrain + self.ann_laion 33 | else: 34 | self.annotation = self.ann_pretrain 35 | 36 | self.transform = transform 37 | 38 | 39 | def reload_laion(self, epoch): 40 | n = epoch%len(self.laion_files) 41 | print('loading '+self.laion_files[n]) 42 | with open(self.laion_files[n],'r') as f: 43 | self.ann_laion = json.load(f) 44 | 45 | self.annotation = self.ann_pretrain + self.ann_laion 46 | 47 | 48 | def __len__(self): 49 | return len(self.annotation) 50 | 51 | def __getitem__(self, index): 52 | 53 | ann = self.annotation[index] 54 | 55 | image = Image.open(ann['image']).convert('RGB') 56 | image = self.transform(image) 57 | caption = pre_caption(ann['caption'],30) 58 | 59 | return image, caption -------------------------------------------------------------------------------- /model/BLIP/data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils 9 | 10 | def pre_caption(caption,max_words=50): 11 | caption = re.sub( 12 | r"([.!\"()*#:;~])", 13 | ' ', 14 | caption.lower(), 15 | ) 16 | caption = re.sub( 17 | r"\s{2,}", 18 | ' ', 19 | caption, 20 | ) 21 | caption = caption.rstrip('\n') 22 | caption = caption.strip(' ') 23 | 24 | #truncate caption 25 | caption_words = caption.split(' ') 26 | if len(caption_words)>max_words: 27 | caption = ' '.join(caption_words[:max_words]) 28 | 29 | return caption 30 | 31 | def pre_question(question,max_ques_words=50): 32 | question = re.sub( 33 | r"([.!\"()*#:;~])", 34 | '', 35 | question.lower(), 36 | ) 37 | question = question.rstrip(' ') 38 | 39 | #truncate question 40 | question_words = question.split(' ') 41 | if len(question_words)>max_ques_words: 42 | question = ' 
'.join(question_words[:max_ques_words]) 43 | 44 | return question 45 | 46 | 47 | def save_result(result, result_dir, filename, remove_duplicate=''): 48 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 49 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 50 | 51 | json.dump(result,open(result_file,'w')) 52 | 53 | dist.barrier() 54 | 55 | if utils.is_main_process(): 56 | # combine results from all processes 57 | result = [] 58 | 59 | for rank in range(utils.get_world_size()): 60 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 61 | res = json.load(open(result_file,'r')) 62 | result += res 63 | 64 | if remove_duplicate: 65 | result_new = [] 66 | id_list = [] 67 | for res in result: 68 | if res[remove_duplicate] not in id_list: 69 | id_list.append(res[remove_duplicate]) 70 | result_new.append(res) 71 | result = result_new 72 | 73 | json.dump(result,open(final_result_file,'w')) 74 | print('result file saved to %s'%final_result_file) 75 | 76 | return final_result_file 77 | 78 | 79 | 80 | # from pycocotools.coco import COCO 81 | # from pycocoevalcap.eval import COCOEvalCap 82 | # from torchvision.datasets.utils import download_url 83 | 84 | # def coco_caption_eval(coco_gt_root, results_file, split): 85 | # urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 86 | # 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 87 | # filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 88 | 89 | # download_url(urls[split],coco_gt_root) 90 | # annotation_file = os.path.join(coco_gt_root,filenames[split]) 91 | 92 | # # create coco object and coco_result object 93 | # coco = COCO(annotation_file) 94 | # coco_result = coco.loadRes(results_file) 95 | 96 | # # create coco_eval object by taking coco and coco_result 97 | # coco_eval = COCOEvalCap(coco, coco_result) 98 | 
99 | # # evaluate on a subset of images by setting 100 | # # coco_eval.params['image_id'] = coco_result.getImgIds() 101 | # # please remove this line when evaluating the full validation set 102 | # # coco_eval.params['image_id'] = coco_result.getImgIds() 103 | 104 | # # evaluate results 105 | # # SPICE will take a few minutes the first time, but speeds up due to caching 106 | # coco_eval.evaluate() 107 | 108 | # # print output evaluation scores 109 | # for metric, score in coco_eval.eval.items(): 110 | # print(f'{metric}: {score:.3f}') 111 | 112 | # return coco_eval -------------------------------------------------------------------------------- /model/BLIP/data/video_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision.datasets.utils import download_url 3 | 4 | from PIL import Image 5 | import torch 6 | import numpy as np 7 | import random 8 | import decord 9 | from decord import VideoReader 10 | import json 11 | import os 12 | from data.utils import pre_caption 13 | 14 | decord.bridge.set_bridge("torch") 15 | 16 | class ImageNorm(object): 17 | """Apply Normalization to Image Pixels on GPU 18 | """ 19 | def __init__(self, mean, std): 20 | self.mean = torch.tensor(mean).view(1, 3, 1, 1) 21 | self.std = torch.tensor(std).view(1, 3, 1, 1) 22 | 23 | def __call__(self, img): 24 | 25 | if torch.max(img) > 1 and self.mean.max() <= 1: 26 | img.div_(255.) 
27 | return img.sub_(self.mean).div_(self.std) 28 | 29 | def load_jsonl(filename): 30 | with open(filename, "r") as f: 31 | return [json.loads(l.strip("\n")) for l in f.readlines()] 32 | 33 | 34 | class VideoDataset(Dataset): 35 | 36 | def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'): 37 | ''' 38 | image_root (string): Root directory of video 39 | ann_root (string): directory to store the annotation file 40 | ''' 41 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl' 42 | filename = 'msrvtt_test.jsonl' 43 | 44 | download_url(url,ann_root) 45 | self.annotation = load_jsonl(os.path.join(ann_root,filename)) 46 | 47 | self.num_frm = num_frm 48 | self.frm_sampling_strategy = frm_sampling_strategy 49 | self.max_img_size = max_img_size 50 | self.video_root = video_root 51 | self.video_fmt = video_fmt 52 | self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 53 | 54 | self.text = [pre_caption(ann['caption'],40) for ann in self.annotation] 55 | self.txt2video = [i for i in range(len(self.annotation))] 56 | self.video2txt = self.txt2video 57 | 58 | 59 | def __len__(self): 60 | return len(self.annotation) 61 | 62 | def __getitem__(self, index): 63 | 64 | ann = self.annotation[index] 65 | 66 | video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt) 67 | 68 | vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size) 69 | 70 | video = self.img_norm(vid_frm_array.float()) 71 | 72 | return video, ann['clip_name'] 73 | 74 | 75 | 76 | def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1): 77 | try: 78 | if not height or not width: 79 | vr = VideoReader(video_path) 80 | else: 81 | vr = VideoReader(video_path, width=width, height=height) 82 | 83 | vlen = len(vr) 84 | 85 | if 
start_time or end_time: 86 | assert fps > 0, 'must provide video fps if specifying start and end time.' 87 | 88 | start_idx = min(int(start_time * fps), vlen) 89 | end_idx = min(int(end_time * fps), vlen) 90 | else: 91 | start_idx, end_idx = 0, vlen 92 | 93 | if self.frm_sampling_strategy == 'uniform': 94 | frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int) 95 | elif self.frm_sampling_strategy == 'rand': 96 | frame_indices = sorted(random.sample(range(vlen), self.num_frm)) 97 | elif self.frm_sampling_strategy == 'headtail': 98 | frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2)) 99 | frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2)) 100 | frame_indices = frame_indices_head + frame_indices_tail 101 | else: 102 | raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy)) 103 | 104 | raw_sample_frms = vr.get_batch(frame_indices) 105 | except Exception as e: 106 | return None 107 | 108 | raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2) 109 | 110 | return raw_sample_frms 111 | -------------------------------------------------------------------------------- /model/BLIP/data/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from data.utils import pre_question 9 | 10 | from torchvision.datasets.utils import download_url 11 | 12 | class vqa_dataset(Dataset): 13 | def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"): 14 | self.split = split 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | 20 | if split=='train': 21 | urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json', 22 | 
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json', 23 | 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'} 24 | 25 | self.annotation = [] 26 | for f in train_files: 27 | download_url(urls[f],ann_root) 28 | self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r')) 29 | else: 30 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root) 31 | self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r')) 32 | 33 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root) 34 | self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r')) 35 | 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | if ann['dataset']=='vqa': 45 | image_path = os.path.join(self.vqa_root,ann['image']) 46 | elif ann['dataset']=='vg': 47 | image_path = os.path.join(self.vg_root,ann['image']) 48 | 49 | image = Image.open(image_path).convert('RGB') 50 | image = self.transform(image) 51 | 52 | if self.split == 'test': 53 | question = pre_question(ann['question']) 54 | question_id = ann['question_id'] 55 | return image, question, question_id 56 | 57 | 58 | elif self.split=='train': 59 | 60 | question = pre_question(ann['question']) 61 | 62 | if ann['dataset']=='vqa': 63 | answer_weight = {} 64 | for answer in ann['answer']: 65 | if answer in answer_weight.keys(): 66 | answer_weight[answer] += 1/len(ann['answer']) 67 | else: 68 | answer_weight[answer] = 1/len(ann['answer']) 69 | 70 | answers = list(answer_weight.keys()) 71 | weights = list(answer_weight.values()) 72 | 73 | elif ann['dataset']=='vg': 74 | answers = [ann['answer']] 75 | weights = [0.2] 76 | 77 | return image, question, answers, weights 78 | 79 | 80 | def vqa_collate_fn(batch): 81 | image_list, 
question_list, answer_list, weight_list, n = [], [], [], [], [] 82 | for image, question, answer, weights in batch: 83 | image_list.append(image) 84 | question_list.append(question) 85 | weight_list += weights 86 | answer_list += answer 87 | n.append(len(answer)) 88 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n -------------------------------------------------------------------------------- /model/BLIP/eval_nocaps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from data import create_dataset, create_sampler, create_loader 28 | from data.utils import save_result 29 | 30 | @torch.no_grad() 31 | def evaluate(model, data_loader, device, config): 32 | # evaluate 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | print_freq = 10 38 | 39 | result = [] 40 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 41 | 42 | image = image.to(device) 43 | 44 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 45 | min_length=config['min_length'], repetition_penalty=1.1) 46 | 47 | for caption, img_id in zip(captions, 
image_id): 48 | result.append({"image_id": img_id.item(), "caption": caption}) 49 | 50 | return result 51 | 52 | 53 | def main(args, config): 54 | utils.init_distributed_mode(args) 55 | 56 | device = torch.device(args.device) 57 | 58 | # fix the seed for reproducibility 59 | seed = args.seed + utils.get_rank() 60 | torch.manual_seed(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | cudnn.benchmark = True 64 | 65 | #### Dataset #### 66 | print("Creating captioning dataset") 67 | val_dataset, test_dataset = create_dataset('nocaps', config) 68 | 69 | if args.distributed: 70 | num_tasks = utils.get_world_size() 71 | global_rank = utils.get_rank() 72 | samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank) 73 | else: 74 | samplers = [None,None] 75 | 76 | val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers, 77 | batch_size=[config['batch_size']]*2,num_workers=[4,4], 78 | is_trains=[False, False], collate_fns=[None,None]) 79 | 80 | #### Model #### 81 | print("Creating model") 82 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 83 | prompt=config['prompt']) 84 | 85 | model = model.to(device) 86 | 87 | model_without_ddp = model 88 | if args.distributed: 89 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 90 | model_without_ddp = model.module 91 | 92 | val_result = evaluate(model_without_ddp, val_loader, device, config) 93 | val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id') 94 | test_result = evaluate(model_without_ddp, test_loader, device, config) 95 | test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id') 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--config', default='./configs/nocaps.yaml') 101 | parser.add_argument('--output_dir', default='output/NoCaps') 102 
| parser.add_argument('--device', default='cuda') 103 | parser.add_argument('--seed', default=42, type=int) 104 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 106 | parser.add_argument('--distributed', default=True, type=bool) 107 | args = parser.parse_args() 108 | 109 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 110 | 111 | args.result_dir = os.path.join(args.output_dir, 'result') 112 | 113 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 114 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 115 | 116 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 117 | 118 | main(args, config) -------------------------------------------------------------------------------- /model/BLIP/eval_retrieval_video.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_retrieval import blip_retrieval 26 | import utils 27 | from data.video_dataset import VideoDataset 28 | 29 | 30 | @torch.no_grad() 31 | def evaluation(model, data_loader, tokenizer, device, config): 32 | # test 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | 38 | print('Computing features for evaluation...') 39 | start_time = time.time() 40 | 41 | texts = data_loader.dataset.text 42 | num_text = len(texts) 43 | text_bs = 256 44 | text_ids = [] 45 | text_embeds = [] 46 | text_atts = [] 47 | for i in range(0, num_text, text_bs): 48 | text = texts[i: min(num_text, i+text_bs)] 49 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device) 50 | text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') 51 | text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:])) 52 | text_embeds.append(text_embed) 53 | text_ids.append(text_input.input_ids) 54 | text_atts.append(text_input.attention_mask) 55 | 56 | text_embeds = torch.cat(text_embeds,dim=0) 57 | text_ids = torch.cat(text_ids,dim=0) 58 | text_atts = torch.cat(text_atts,dim=0) 59 | text_ids[:,0] = tokenizer.additional_special_tokens_ids[0] 60 | 61 | video_feats = [] 62 | video_embeds = [] 63 | for video, video_id in 
data_loader: 64 | 65 | B,N,C,W,H = video.size() 66 | video = video.view(-1,C,W,H) 67 | video = video.to(device,non_blocking=True) 68 | video_feat = model.visual_encoder(video) 69 | video_embed = model.vision_proj(video_feat[:,0,:]) 70 | video_embed = video_embed.view(B,N,-1).mean(dim=1) 71 | video_embed = F.normalize(video_embed,dim=-1) 72 | 73 | video_feat = video_feat.view(B,-1,video_feat.shape[-1]) 74 | video_feats.append(video_feat.cpu()) 75 | video_embeds.append(video_embed) 76 | 77 | video_feats = torch.cat(video_feats,dim=0) 78 | video_embeds = torch.cat(video_embeds,dim=0) 79 | 80 | sims_matrix = video_embeds @ text_embeds.t() 81 | score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device) 82 | 83 | num_tasks = utils.get_world_size() 84 | rank = utils.get_rank() 85 | step = sims_matrix.size(0)//num_tasks + 1 86 | start = rank*step 87 | end = min(sims_matrix.size(0),start+step) 88 | 89 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 90 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 91 | 92 | encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True) 93 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) 94 | output = model.text_encoder(text_ids[topk_idx], 95 | attention_mask = text_atts[topk_idx], 96 | encoder_hidden_states = encoder_output, 97 | encoder_attention_mask = encoder_att, 98 | return_dict = True, 99 | ) 100 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 101 | score_matrix_v2t[start+i,topk_idx] = score + topk_sim 102 | 103 | sims_matrix = sims_matrix.t() 104 | score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device) 105 | 106 | step = sims_matrix.size(0)//num_tasks + 1 107 | start = rank*step 108 | end = min(sims_matrix.size(0),start+step) 109 | 110 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 111 | 112 | topk_sim, topk_idx 
= sims.topk(k=config['k_test'], dim=0) 113 | encoder_output = video_feats[topk_idx].to(device,non_blocking=True) 114 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) 115 | output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1), 116 | attention_mask = text_atts[start+i].repeat(config['k_test'],1), 117 | encoder_hidden_states = encoder_output, 118 | encoder_attention_mask = encoder_att, 119 | return_dict = True, 120 | ) 121 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 122 | score_matrix_t2v[start+i,topk_idx] = score + topk_sim 123 | 124 | if args.distributed: 125 | dist.barrier() 126 | torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM) 127 | torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM) 128 | 129 | total_time = time.time() - start_time 130 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 131 | print('Evaluation time {}'.format(total_time_str)) 132 | 133 | return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy() 134 | 135 | 136 | 137 | @torch.no_grad() 138 | def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt): 139 | 140 | #Video->Text 141 | ranks = np.zeros(scores_v2t.shape[0]) 142 | for index,score in enumerate(scores_v2t): 143 | inds = np.argsort(score)[::-1] 144 | ranks[index] = np.where(inds == vid2txt[index])[0][0] 145 | 146 | # Compute metrics 147 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 148 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 149 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 150 | 151 | #Text->Video 152 | ranks = np.zeros(scores_t2v.shape[0]) 153 | 154 | for index,score in enumerate(scores_t2v): 155 | inds = np.argsort(score)[::-1] 156 | ranks[index] = np.where(inds == txt2vmg[index])[0][0] 157 | 158 | mdR = np.median(ranks+1) 159 | 160 | # Compute metrics 161 | vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 162 | vr5 
= 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 163 | vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 164 | 165 | tr_mean = (tr1 + tr5 + tr10) / 3 166 | vr_mean = (vr1 + vr5 + vr10) / 3 167 | r_mean = (tr_mean + vr_mean) / 2 168 | 169 | eval_result = {'txt_r1': tr1, 170 | 'txt_r5': tr5, 171 | 'txt_r10': tr10, 172 | 'txt_r_mean': tr_mean, 173 | 'vid_r1': vr1, 174 | 'vid_r5': vr5, 175 | 'vid_r10': vr10, 176 | 'vid_r_mean': vr_mean, 177 | 'vid_mdR': mdR, 178 | 'r_mean': r_mean} 179 | return eval_result 180 | 181 | 182 | 183 | 184 | def main(args, config): 185 | utils.init_distributed_mode(args) 186 | 187 | device = torch.device(args.device) 188 | 189 | # fix the seed for reproducibility 190 | seed = args.seed + utils.get_rank() 191 | torch.manual_seed(seed) 192 | np.random.seed(seed) 193 | random.seed(seed) 194 | cudnn.benchmark = True 195 | 196 | #### Dataset #### 197 | print("Creating retrieval dataset") 198 | test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'], 199 | max_img_size=config['image_size'], frm_sampling_strategy='uniform') 200 | 201 | test_loader = DataLoader( 202 | test_dataset, 203 | batch_size=config['batch_size'], 204 | num_workers=4, 205 | pin_memory=True, 206 | drop_last=False, 207 | shuffle=False, 208 | ) 209 | 210 | #### Model #### 211 | print("Creating model") 212 | model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit']) 213 | 214 | model = model.to(device) 215 | 216 | model_without_ddp = model 217 | if args.distributed: 218 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 219 | model_without_ddp = model.module 220 | 221 | score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config) 222 | 223 | if utils.is_main_process(): 224 | 225 | test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt) 226 | 
print(test_result) 227 | 228 | log_stats = {**{f'{k}': v for k, v in test_result.items()},} 229 | with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f: 230 | f.write(json.dumps(log_stats) + "\n") 231 | 232 | 233 | if __name__ == '__main__': 234 | parser = argparse.ArgumentParser() 235 | parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml') 236 | parser.add_argument('--output_dir', default='output/Retrieval_msrvtt') 237 | parser.add_argument('--device', default='cuda') 238 | parser.add_argument('--seed', default=42, type=int) 239 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 240 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 241 | parser.add_argument('--distributed', default=True, type=bool) 242 | args = parser.parse_args() 243 | 244 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 245 | 246 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 247 | 248 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 249 | 250 | main(args, config) -------------------------------------------------------------------------------- /model/BLIP/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Code-kunkun/ZS-CIR/5b2a48518ccaef3dc2c2edc72db0523c8053506c/model/BLIP/models/__init__.py -------------------------------------------------------------------------------- /model/BLIP/models/blip_itm.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__(self, 12 | med_config = 
'/GPFS/rhome/yikunliu/graduation_project/model/BLIP/configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | ): 19 | """ 20 | Args: 21 | med_config (str): path for the mixture of encoder-decoder model's configuration file 22 | image_size (int): input image size 23 | vit (str): model size of vision transformer 24 | """ 25 | super().__init__() 26 | 27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 28 | self.tokenizer = init_tokenizer() 29 | med_config = BertConfig.from_json_file(med_config) 30 | med_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 32 | 33 | text_width = self.text_encoder.config.hidden_size 34 | 35 | self.vision_proj = nn.Linear(vision_width, embed_dim) 36 | self.text_proj = nn.Linear(text_width, embed_dim) 37 | 38 | self.itm_head = nn.Linear(text_width, 2) 39 | 40 | 41 | def forward(self, image, caption, match_head='itm'): 42 | 43 | image_embeds = self.visual_encoder(image) 44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 45 | 46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 47 | return_tensors="pt").to(image.device) 48 | 49 | 50 | if match_head=='itm': 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = image_embeds, 54 | encoder_attention_mask = image_atts, 55 | return_dict = True, 56 | ) 57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 58 | return itm_output 59 | 60 | elif match_head=='itc': 61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 62 | return_dict = True, mode = 'text') 63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 
65 | 66 | sim = image_feat @ text_feat.t() 67 | return sim 68 | 69 | 70 | def blip_itm(pretrained='',**kwargs): 71 | model = BLIP_ITM(**kwargs) 72 | if pretrained: 73 | model,msg = load_checkpoint(model,pretrained) 74 | assert(len(msg.missing_keys)==0) 75 | return model 76 | -------------------------------------------------------------------------------- /model/BLIP/models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig 2 | from models.nlvr_encoder import BertModel 3 | from models.vit import interpolate_pos_embed 4 | from models.blip import create_vit, init_tokenizer, is_url 5 | 6 | from timm.models.hub import download_cached_file 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | class BLIP_NLVR(nn.Module): 15 | def __init__(self, 16 | med_config = 'configs/med_config.json', 17 | image_size = 480, 18 | vit = 'base', 19 | vit_grad_ckpt = False, 20 | vit_ckpt_layer = 0, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | self.cls_head = nn.Sequential( 37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 38 | nn.ReLU(), 39 | nn.Linear(self.text_encoder.config.hidden_size, 2) 40 | ) 41 | 42 | def forward(self, image, text, targets, train=True): 43 | 44 | image_embeds = self.visual_encoder(image) 45 | 
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 47 | 48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 49 | text.input_ids[:,0] = self.tokenizer.enc_token_id 50 | 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = [image0_embeds,image1_embeds], 54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 55 | image_atts[image0_embeds.size(0):]], 56 | return_dict = True, 57 | ) 58 | hidden_state = output.last_hidden_state[:,0,:] 59 | prediction = self.cls_head(hidden_state) 60 | 61 | if train: 62 | loss = F.cross_entropy(prediction, targets) 63 | return loss 64 | else: 65 | return prediction 66 | 67 | def blip_nlvr(pretrained='',**kwargs): 68 | model = BLIP_NLVR(**kwargs) 69 | if pretrained: 70 | model,msg = load_checkpoint(model,pretrained) 71 | print("missing keys:") 72 | print(msg.missing_keys) 73 | return model 74 | 75 | 76 | def load_checkpoint(model,url_or_filename): 77 | if is_url(url_or_filename): 78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 79 | checkpoint = torch.load(cached_file, map_location='cpu') 80 | elif os.path.isfile(url_or_filename): 81 | checkpoint = torch.load(url_or_filename, map_location='cpu') 82 | else: 83 | raise RuntimeError('checkpoint url or path is invalid') 84 | state_dict = checkpoint['model'] 85 | 86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 87 | 88 | for key in list(state_dict.keys()): 89 | if 'crossattention.self.' in key: 90 | new_key0 = key.replace('self','self0') 91 | new_key1 = key.replace('self','self1') 92 | state_dict[new_key0] = state_dict[key] 93 | state_dict[new_key1] = state_dict[key] 94 | elif 'crossattention.output.dense.' 
in key: 95 | new_key0 = key.replace('dense','dense0') 96 | new_key1 = key.replace('dense','dense1') 97 | state_dict[new_key0] = state_dict[key] 98 | state_dict[new_key1] = state_dict[key] 99 | 100 | msg = model.load_state_dict(state_dict,strict=False) 101 | print('load checkpoint from %s'%url_or_filename) 102 | return model,msg 103 | -------------------------------------------------------------------------------- /model/BLIP/models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel, BertLMHeadModel 2 | from models.blip import create_vit, init_tokenizer, load_checkpoint 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import BertTokenizer 8 | import numpy as np 9 | 10 | class BLIP_VQA(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 480, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | ): 18 | """ 19 | Args: 20 | med_config (str): path for the mixture of encoder-decoder model's configuration file 21 | image_size (int): input image size 22 | vit (str): model size of vision transformer 23 | """ 24 | super().__init__() 25 | 26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 27 | self.tokenizer = init_tokenizer() 28 | 29 | encoder_config = BertConfig.from_json_file(med_config) 30 | encoder_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 32 | 33 | decoder_config = BertConfig.from_json_file(med_config) 34 | self.text_decoder = BertLMHeadModel(config=decoder_config) 35 | 36 | 37 | def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): 38 | 39 | image_embeds = self.visual_encoder(image) 40 | image_atts = 
torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 41 | 42 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 43 | return_tensors="pt").to(image.device) 44 | question.input_ids[:,0] = self.tokenizer.enc_token_id 45 | 46 | if train: 47 | ''' 48 | n: number of answers for each question 49 | weights: weight for each answer 50 | ''' 51 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 52 | answer.input_ids[:,0] = self.tokenizer.bos_token_id 53 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 54 | 55 | question_output = self.text_encoder(question.input_ids, 56 | attention_mask = question.attention_mask, 57 | encoder_hidden_states = image_embeds, 58 | encoder_attention_mask = image_atts, 59 | return_dict = True) 60 | 61 | question_states = [] 62 | question_atts = [] 63 | for b, n in enumerate(n): 64 | question_states += [question_output.last_hidden_state[b]]*n 65 | question_atts += [question.attention_mask[b]]*n 66 | question_states = torch.stack(question_states,0) 67 | question_atts = torch.stack(question_atts,0) 68 | 69 | answer_output = self.text_decoder(answer.input_ids, 70 | attention_mask = answer.attention_mask, 71 | encoder_hidden_states = question_states, 72 | encoder_attention_mask = question_atts, 73 | labels = answer_targets, 74 | return_dict = True, 75 | reduction = 'none', 76 | ) 77 | 78 | loss = weights * answer_output.loss 79 | loss = loss.sum()/image.size(0) 80 | 81 | return loss 82 | 83 | 84 | else: 85 | question_output = self.text_encoder(question.input_ids, 86 | attention_mask = question.attention_mask, 87 | encoder_hidden_states = image_embeds, 88 | encoder_attention_mask = image_atts, 89 | return_dict = True) 90 | 91 | if inference=='generate': 92 | num_beams = 3 93 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) 94 | question_atts = 
torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) 95 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} 96 | 97 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) 98 | 99 | outputs = self.text_decoder.generate(input_ids=bos_ids, 100 | max_length=10, 101 | min_length=1, 102 | num_beams=num_beams, 103 | eos_token_id=self.tokenizer.sep_token_id, 104 | pad_token_id=self.tokenizer.pad_token_id, 105 | **model_kwargs) 106 | 107 | answers = [] 108 | for output in outputs: 109 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 110 | answers.append(answer) 111 | return answers 112 | 113 | elif inference=='rank': 114 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 115 | answer.input_ids, answer.attention_mask, k_test) 116 | return max_ids 117 | 118 | 119 | 120 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 121 | 122 | num_ques = question_states.size(0) 123 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token 124 | 125 | start_output = self.text_decoder(start_ids, 126 | encoder_hidden_states = question_states, 127 | encoder_attention_mask = question_atts, 128 | return_dict = True, 129 | reduction = 'none') 130 | logits = start_output.logits[:,0,:] # first token's logit 131 | 132 | # topk_probs: top-k probability 133 | # topk_ids: [num_question, k] 134 | answer_first_token = answer_ids[:,1] 135 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 136 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 137 | 138 | # answer input: [num_question*k, answer_len] 139 | input_ids = [] 140 | input_atts = [] 141 | for b, topk_id in enumerate(topk_ids): 142 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 143 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 144 | input_ids = 
torch.cat(input_ids,dim=0) 145 | input_atts = torch.cat(input_atts,dim=0) 146 | 147 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 148 | 149 | # repeat encoder's output for top-k answers 150 | question_states = tile(question_states, 0, k) 151 | question_atts = tile(question_atts, 0, k) 152 | 153 | output = self.text_decoder(input_ids, 154 | attention_mask = input_atts, 155 | encoder_hidden_states = question_states, 156 | encoder_attention_mask = question_atts, 157 | labels = targets_ids, 158 | return_dict = True, 159 | reduction = 'none') 160 | 161 | log_probs_sum = -output.loss 162 | log_probs_sum = log_probs_sum.view(num_ques,k) 163 | 164 | max_topk_ids = log_probs_sum.argmax(dim=1) 165 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] 166 | 167 | return max_ids 168 | 169 | 170 | def blip_vqa(pretrained='',**kwargs): 171 | model = BLIP_VQA(**kwargs) 172 | if pretrained: 173 | model,msg = load_checkpoint(model,pretrained) 174 | # assert(len(msg.missing_keys)==0) 175 | return model 176 | 177 | 178 | def tile(x, dim, n_tile): 179 | init_dim = x.size(dim) 180 | repeat_idx = [1] * x.dim() 181 | repeat_idx[dim] = n_tile 182 | x = x.repeat(*(repeat_idx)) 183 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 184 | return torch.index_select(x, dim, order_index.to(x.device)) 185 | 186 | -------------------------------------------------------------------------------- /model/BLIP/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download the weights in ./checkpoints beforehand for fast inference 3 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth 4 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth 5 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth 6 | """ 7 | 8 | from 
pathlib import Path 9 | 10 | from PIL import Image 11 | import torch 12 | from torchvision import transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | import cog 15 | 16 | from models.blip import blip_decoder 17 | from models.blip_vqa import blip_vqa 18 | from models.blip_itm import blip_itm 19 | 20 | 21 | class Predictor(cog.Predictor): 22 | def setup(self): 23 | self.device = "cuda:0" 24 | 25 | self.models = { 26 | 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth', 27 | image_size=384, vit='base'), 28 | 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth', 29 | image_size=480, vit='base'), 30 | 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth', 31 | image_size=384, vit='base') 32 | } 33 | 34 | @cog.input( 35 | "image", 36 | type=Path, 37 | help="input image", 38 | ) 39 | @cog.input( 40 | "task", 41 | type=str, 42 | default='image_captioning', 43 | options=['image_captioning', 'visual_question_answering', 'image_text_matching'], 44 | help="Choose a task.", 45 | ) 46 | @cog.input( 47 | "question", 48 | type=str, 49 | default=None, 50 | help="Type question for the input image for visual question answering task.", 51 | ) 52 | @cog.input( 53 | "caption", 54 | type=str, 55 | default=None, 56 | help="Type caption for the input image for image text matching task.", 57 | ) 58 | def predict(self, image, task, question, caption): 59 | if task == 'visual_question_answering': 60 | assert question is not None, 'Please type a question for visual question answering task.' 61 | if task == 'image_text_matching': 62 | assert caption is not None, 'Please type a caption for mage text matching task.' 
63 | 64 | im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device) 65 | model = self.models[task] 66 | model.eval() 67 | model = model.to(self.device) 68 | 69 | if task == 'image_captioning': 70 | with torch.no_grad(): 71 | caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5) 72 | return 'Caption: ' + caption[0] 73 | 74 | if task == 'visual_question_answering': 75 | with torch.no_grad(): 76 | answer = model(im, question, train=False, inference='generate') 77 | return 'Answer: ' + answer[0] 78 | 79 | # image_text_matching 80 | itm_output = model(im, caption, match_head='itm') 81 | itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1] 82 | itc_score = model(im, caption, match_head='itc') 83 | return f'The image and text is matched with a probability of {itm_score.item():.4f}.\n' \ 84 | f'The image feature and text feature has a cosine similarity of {itc_score.item():.4f}.' 85 | 86 | 87 | def load_image(image, image_size, device): 88 | raw_image = Image.open(str(image)).convert('RGB') 89 | 90 | w, h = raw_image.size 91 | 92 | transform = transforms.Compose([ 93 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 96 | ]) 97 | image = transform(raw_image).unsqueeze(0).to(device) 98 | return image 99 | -------------------------------------------------------------------------------- /model/BLIP/pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_pretrain import blip_pretrain 26 | import utils 27 | from utils import warmup_lr_schedule, step_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | 30 | def train(model, data_loader, optimizer, epoch, device, config): 31 | # train 32 | model.train() 33 | 34 | metric_logger = utils.MetricLogger(delimiter=" ") 35 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 36 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 37 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 38 | metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 39 | 40 | header = 'Train Epoch: [{}]'.format(epoch) 41 | print_freq = 50 42 | 43 | if config['laion_path']: 44 | data_loader.dataset.reload_laion(epoch) 45 | 46 | data_loader.sampler.set_epoch(epoch) 47 | 48 | for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 49 | 50 | if epoch==0: 51 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 52 | 53 | optimizer.zero_grad() 54 | 55 | image = image.to(device,non_blocking=True) 56 | 57 | # ramp up alpha in the first 2 epochs 58 | alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader))) 59 | 60 
| loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha) 61 | loss = loss_ita + loss_itm + loss_lm 62 | 63 | loss.backward() 64 | optimizer.step() 65 | 66 | metric_logger.update(loss_ita=loss_ita.item()) 67 | metric_logger.update(loss_itm=loss_itm.item()) 68 | metric_logger.update(loss_lm=loss_lm.item()) 69 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 70 | 71 | 72 | # gather the stats from all processes 73 | metric_logger.synchronize_between_processes() 74 | print("Averaged stats:", metric_logger.global_avg()) 75 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 76 | 77 | 78 | def main(args, config): 79 | utils.init_distributed_mode(args) 80 | 81 | device = torch.device(args.device) 82 | 83 | # fix the seed for reproducibility 84 | seed = args.seed + utils.get_rank() 85 | torch.manual_seed(seed) 86 | np.random.seed(seed) 87 | random.seed(seed) 88 | cudnn.benchmark = True 89 | 90 | #### Dataset #### 91 | print("Creating dataset") 92 | datasets = [create_dataset('pretrain', config, min_scale=0.2)] 93 | print('number of training samples: %d'%len(datasets[0])) 94 | 95 | num_tasks = utils.get_world_size() 96 | global_rank = utils.get_rank() 97 | samplers = create_sampler(datasets, [True], num_tasks, global_rank) 98 | 99 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] 100 | 101 | #### Model #### 102 | print("Creating model") 103 | model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 104 | vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size']) 105 | 106 | model = model.to(device) 107 | 108 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 109 | 110 | start_epoch = 0 111 | if args.checkpoint: 112 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 113 | 
state_dict = checkpoint['model'] 114 | model.load_state_dict(state_dict) 115 | 116 | optimizer.load_state_dict(checkpoint['optimizer']) 117 | start_epoch = checkpoint['epoch']+1 118 | print('resume checkpoint from %s'%args.checkpoint) 119 | 120 | model_without_ddp = model 121 | if args.distributed: 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 123 | model_without_ddp = model.module 124 | 125 | print("Start training") 126 | start_time = time.time() 127 | for epoch in range(start_epoch, config['max_epoch']): 128 | 129 | step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) 130 | 131 | train_stats = train(model, data_loader, optimizer, epoch, device, config) 132 | if utils.is_main_process(): 133 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 134 | 'epoch': epoch, 135 | } 136 | save_obj = { 137 | 'model': model_without_ddp.state_dict(), 138 | 'optimizer': optimizer.state_dict(), 139 | 'config': config, 140 | 'epoch': epoch, 141 | } 142 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 143 | 144 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 145 | f.write(json.dumps(log_stats) + "\n") 146 | 147 | dist.barrier() 148 | 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | print('Training time {}'.format(total_time_str)) 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('--config', default='./configs/pretrain.yaml') 157 | parser.add_argument('--output_dir', default='output/Pretrain') 158 | parser.add_argument('--checkpoint', default='') 159 | parser.add_argument('--evaluate', action='store_true') 160 | parser.add_argument('--device', default='cuda') 161 | parser.add_argument('--seed', default=42, type=int) 162 | parser.add_argument('--world_size', default=1, type=int, help='number of 
distributed processes') 163 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 164 | parser.add_argument('--distributed', default=True, type=bool) 165 | args = parser.parse_args() 166 | 167 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 168 | 169 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 170 | 171 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 172 | 173 | main(args, config) -------------------------------------------------------------------------------- /model/BLIP/requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | transformers==4.15.0 3 | fairscale==0.4.4 4 | pycocoevalcap 5 | -------------------------------------------------------------------------------- /model/BLIP/train_caption.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | from data.utils import save_result, coco_caption_eval 30 | 31 | def train(model, data_loader, optimizer, epoch, device): 32 | # train 33 | model.train() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 38 | header = 'Train Caption Epoch: [{}]'.format(epoch) 39 | print_freq = 50 40 | 41 | for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 42 | image = image.to(device) 43 | 44 | loss = model(image, caption) 45 | 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | metric_logger.update(loss=loss.item()) 51 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 52 | 53 | # gather the stats from all processes 54 | metric_logger.synchronize_between_processes() 55 | print("Averaged stats:", metric_logger.global_avg()) 56 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 57 | 58 | 59 | @torch.no_grad() 60 | def evaluate(model, data_loader, device, config): 61 | # evaluate 62 | model.eval() 63 | 64 | 
metric_logger = utils.MetricLogger(delimiter=" ") 65 | header = 'Caption generation:' 66 | print_freq = 10 67 | 68 | result = [] 69 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 70 | 71 | image = image.to(device) 72 | 73 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 74 | min_length=config['min_length']) 75 | 76 | for caption, img_id in zip(captions, image_id): 77 | result.append({"image_id": img_id.item(), "caption": caption}) 78 | 79 | return result 80 | 81 | 82 | def main(args, config): 83 | utils.init_distributed_mode(args) 84 | 85 | device = torch.device(args.device) 86 | 87 | # fix the seed for reproducibility 88 | seed = args.seed + utils.get_rank() 89 | torch.manual_seed(seed) 90 | np.random.seed(seed) 91 | random.seed(seed) 92 | cudnn.benchmark = True 93 | 94 | #### Dataset #### 95 | print("Creating captioning dataset") 96 | train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config) 97 | 98 | if args.distributed: 99 | num_tasks = utils.get_world_size() 100 | global_rank = utils.get_rank() 101 | samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank) 102 | else: 103 | samplers = [None, None, None] 104 | 105 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, 106 | batch_size=[config['batch_size']]*3,num_workers=[4,4,4], 107 | is_trains=[True, False, False], collate_fns=[None,None,None]) 108 | 109 | #### Model #### 110 | print("Creating model") 111 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 112 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 113 | prompt=config['prompt']) 114 | 115 | model = model.to(device) 116 | 117 | model_without_ddp = model 118 | if args.distributed: 119 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 120 | model_without_ddp = model.module 121 | 122 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 123 | 124 | best = 0 125 | best_epoch = 0 126 | 127 | print("Start training") 128 | start_time = time.time() 129 | for epoch in range(0, config['max_epoch']): 130 | if not args.evaluate: 131 | if args.distributed: 132 | train_loader.sampler.set_epoch(epoch) 133 | 134 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 135 | 136 | train_stats = train(model, train_loader, optimizer, epoch, device) 137 | 138 | val_result = evaluate(model_without_ddp, val_loader, device, config) 139 | val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id') 140 | 141 | test_result = evaluate(model_without_ddp, test_loader, device, config) 142 | test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id') 143 | 144 | if utils.is_main_process(): 145 | coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val') 146 | coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test') 147 | 148 | if args.evaluate: 149 | log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()}, 150 | **{f'test_{k}': v for k, v in coco_test.eval.items()}, 151 | } 152 | with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: 153 | f.write(json.dumps(log_stats) + "\n") 154 | else: 155 | save_obj = { 156 | 'model': model_without_ddp.state_dict(), 157 | 'optimizer': optimizer.state_dict(), 158 | 'config': config, 159 | 'epoch': epoch, 160 | } 161 | 162 | if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best: 163 | best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] 164 | best_epoch = epoch 165 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 166 | 167 | 
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 168 | **{f'val_{k}': v for k, v in coco_val.eval.items()}, 169 | **{f'test_{k}': v for k, v in coco_test.eval.items()}, 170 | 'epoch': epoch, 171 | 'best_epoch': best_epoch, 172 | } 173 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 174 | f.write(json.dumps(log_stats) + "\n") 175 | 176 | if args.evaluate: 177 | break 178 | dist.barrier() 179 | 180 | total_time = time.time() - start_time 181 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 182 | print('Training time {}'.format(total_time_str)) 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--config', default='./configs/caption_coco.yaml') 188 | parser.add_argument('--output_dir', default='output/Caption_coco') 189 | parser.add_argument('--evaluate', action='store_true') 190 | parser.add_argument('--device', default='cuda') 191 | parser.add_argument('--seed', default=42, type=int) 192 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 193 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 194 | parser.add_argument('--distributed', default=True, type=bool) 195 | args = parser.parse_args() 196 | 197 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 198 | 199 | args.result_dir = os.path.join(args.output_dir, 'result') 200 | 201 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 202 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 203 | 204 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 205 | 206 | main(args, config) -------------------------------------------------------------------------------- /model/BLIP/train_nlvr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
# * SPDX-License-Identifier: BSD-3-Clause
# * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
# * By Junnan Li
# '''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import json
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models.blip_nlvr import blip_nlvr

import utils
from utils import cosine_lr_schedule, warmup_lr_schedule
from data import create_dataset, create_sampler, create_loader


def train(model, data_loader, optimizer, epoch, device, config):
    """Run one training epoch of the NLVR model.

    Each batch is unpacked as ``(image0, image1, text, targets)``; the two image
    tensors are concatenated along dim 0 and fed to the model together with the
    text, and the returned loss is back-propagated.

    Args:
        model: the (possibly DDP-wrapped) NLVR model.
        data_loader: training loader.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, only used in the progress header.
        device: device the images and targets are moved to.
        config: unused here; kept so callers' signatures stay valid.

    Returns:
        dict mapping each meter name ('lr', 'loss') to its global average,
        formatted as a string with 4 decimals.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
        # Both images of each pair go through the model as one stacked batch.
        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        loss = model(images, text, targets=targets, train=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(loss=loss.item())

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(model, data_loader, device, config):
    """Measure classification accuracy of the NLVR model over ``data_loader``.

    Args:
        model: the (possibly DDP-wrapped) NLVR model.
        data_loader: evaluation loader yielding ``(image0, image1, text, targets)``.
        device: device the images and targets are moved to.
        config: unused here; kept for signature compatibility.

    Returns:
        dict keyed by meter name (normally just 'acc') holding the global
        average across processes, formatted as a string with 4 decimals.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")

    header = 'Evaluation:'
    print_freq = 50

    for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        prediction = model(images, text, targets=targets, train=False)

        # presumably prediction is (batch, num_classes) scores - TODO confirm.
        _, pred_class = prediction.max(1)
        accuracy = (targets == pred_class).sum() / targets.size(0)

        # Weight each batch by its size so the global average is per-sample.
        metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


def main(args, config):
    """Fine-tune (or, with ``args.evaluate``, only evaluate) BLIP on NLVR.

    Per-epoch stats are appended as JSON lines to ``<output_dir>/log.txt``;
    during training the checkpoint with the best validation accuracy is saved
    to ``<output_dir>/checkpoint_best.pth``.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = create_dataset('nlvr', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False, False], num_tasks, global_rank)
    else:
        samplers = [None, None, None]

    batch_size = [config['batch_size_train'], config['batch_size_test'], config['batch_size_test']]
    train_loader, val_loader, test_loader = create_loader(datasets, samplers, batch_size=batch_size,
                                                          num_workers=[4, 4, 4], is_trains=[True, False, False],
                                                          collate_fns=[None, None, None])

    #### Model ####
    print("Creating model")
    model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
                      vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    print("Start training")
    start_time = time.time()
    best = 0
    best_epoch = 0
    log_path = os.path.join(args.output_dir, "log.txt")

    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        val_stats = evaluate(model, val_loader, device, config)
        test_stats = evaluate(model, test_loader, device, config)

        if utils.is_main_process():
            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             }
            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             'epoch': epoch,
                             }

                val_acc = float(val_stats['acc'])
                if val_acc > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'config': config,
                        'epoch': epoch,
                    }
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                    best = val_acc
                    best_epoch = epoch

            with open(log_path, "a") as f:
                f.write(json.dumps(log_stats) + "\n")

        if args.evaluate:
            break

        # dist.barrier() needs an initialised process group; a single-process
        # run has none, so only synchronise when one exists.
        if utils.is_dist_avail_and_initialized():
            dist.barrier()

    if utils.is_main_process():
        with open(log_path, "a") as f:
            # Terminate the line so later runs appending to log.txt start on a
            # fresh line instead of being glued onto this one.
            f.write("best epoch: %d\n" % best_epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nlvr.yaml')
    parser.add_argument('--output_dir', default='output/NLVR')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): type=bool turns any non-empty value (including "False") into True.
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    # yaml.Loader can construct arbitrary Python objects: only load trusted config files.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    main(args, config)

# --------------------------------------------------------------------------------
# /model/BLIP/train_vqa.py:
# --------------------------------------------------------------------------------
# '''
# * Copyright (c) 2022, salesforce.com, inc.
# * All rights reserved.
# * SPDX-License-Identifier: BSD-3-Clause
# * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
# * By Junnan Li
# '''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models.blip_vqa import blip_vqa
import utils
from utils import cosine_lr_schedule
from data import create_dataset, create_sampler, create_loader
from data.vqa_dataset import vqa_collate_fn
from data.utils import save_result


def train(model, data_loader, optimizer, epoch, device):
    """Run one VQA training epoch.

    Each batch is unpacked as ``(image, question, answer, weights, n)``; the
    model is called with ``train=True`` and returns the loss that is
    back-propagated.

    Returns:
        dict mapping each meter name ('lr', 'loss') to its global average,
        formatted as a string with 3 decimals.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    for image, question, answer, weights, n in metric_logger.log_every(data_loader, print_freq, header):
        image, weights = image.to(device, non_blocking=True), weights.to(device, non_blocking=True)

        loss = model(image, question, answer, train=True, n=n, weights=weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluation(model, data_loader, device, config):
    """Produce VQA answers for every question in ``data_loader``.

    ``config['inference']`` selects the decoding mode:

    * ``'generate'``: the model generates an answer per question;
    * ``'rank'``: the model picks answers from ``data_loader.dataset.answer_list``
      (``config['k_test']`` is forwarded to the model).

    Returns:
        list of ``{"question_id": int, "answer": ...}`` dicts.

    Raises:
        ValueError: if ``config['inference']`` is neither 'generate' nor 'rank'.
    """
    model.eval()

    inference = config['inference']
    if inference not in ('generate', 'rank'):
        # An unknown mode would otherwise run the whole loader and return [].
        raise ValueError(
            f"config['inference'] must be 'generate' or 'rank', got {inference!r}")

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Generate VQA test result:'
    print_freq = 50

    result = []

    if inference == 'rank':
        answer_list = data_loader.dataset.answer_list
        answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
        # The first token of every candidate is overwritten with the BOS token.
        answer_candidates.input_ids[:, 0] = model.tokenizer.bos_token_id

    for image, question, question_id in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device, non_blocking=True)

        if inference == 'generate':
            answers = model(image, question, train=False, inference='generate')

            for answer, ques_id in zip(answers, question_id):
                ques_id = int(ques_id.item())
                result.append({"question_id": ques_id, "answer": answer})

        else:  # inference == 'rank'
            answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])

            for ques_id, answer_id in zip(question_id, answer_ids):
                result.append({"question_id": int(ques_id.item()), "answer": answer_list[answer_id]})

    return result


def main(args, config):
    """Fine-tune BLIP on VQA, then answer the test split.

    Unless ``args.evaluate`` is set, trains for ``config['max_epoch']`` epochs,
    appending per-epoch stats to ``<output_dir>/log.txt`` and saving
    ``checkpoint_XX.pth`` every epoch. Finally runs :func:`evaluation` on the
    test loader and stores the answers with ``save_result`` as 'vqa_result'.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating vqa datasets")
    datasets = create_dataset('vqa', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    train_loader, test_loader = create_loader(datasets, samplers,
                                              batch_size=[config['batch_size_train'], config['batch_size_test']],
                                              num_workers=[4, 4], is_trains=[True, False],
                                              collate_fns=[vqa_collate_fn, None])
    #### Model ####
    print("Creating model")
    model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        if args.evaluate:
            # Evaluation-only run: skip training and go straight to inference.
            break

        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

        train_stats = train(model, train_loader, optimizer, epoch, device)

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         }
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth' % epoch))

        # dist.barrier() needs an initialised process group; a single-process
        # run has none, so only synchronise when one exists.
        if utils.is_dist_avail_and_initialized():
            dist.barrier()

    vqa_result = evaluation(model_without_ddp, test_loader, device, config)
    save_result(vqa_result, args.result_dir, 'vqa_result')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/vqa.yaml')
    parser.add_argument('--output_dir', default='output/VQA')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): type=bool turns any non-empty value (including "False") into True.
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    # yaml.Loader can construct arbitrary Python objects: only load trusted config files.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    main(args, config)

# --------------------------------------------------------------------------------
# /model/BLIP/transform/randaugment.py:
# --------------------------------------------------------------------------------
import cv2
import numpy as np


## aug functions
def identity_func(img):
    '''Return ``img`` unchanged.'''
    return img


def autocontrast_func(img, cutoff=0):
    '''
        same output as PIL.ImageOps.autocontrast

        Each channel is stretched linearly so that its lowest level maps to 0
        and its highest level to 255, after ignoring ``cutoff`` percent of the
        pixels at each end of the histogram.
    '''
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            # Convert to Python ints: for a uint8 image ch.max()/ch.min() are
            # numpy uint8 scalars, and ``-low`` below would wrap around
            # (e.g. -10 -> 246), pushing the whole table above 255.
            high, low = int(ch.max()), int(ch.min())
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            # Flat channel: identity table.
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    '''
        same output as PIL.ImageOps.equalize
        PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        # Pixels outside the last occupied bin, spread over 255 output levels.
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
        like PIL, rotate by degree, not radians

        Rotates about the image centre, keeps the original size and fills the
        uncovered area with ``fill``.
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out


def solarize_func(img, thresh=128):
    '''
        same output as PIL.ImageOps.solarize: pixel values >= thresh are
        inverted (v -> 255 - v), values below thresh are kept.
    '''
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Color
    '''
    # PIL blends every pixel with its grey level: factor * img + (1 - factor) * grey.
    # With grey = img @ w for the weight column w = [0.114, 0.587, 0.299], that
    # blend equals img @ M with M = factor * I + (1 - factor) * W, where every
    # column of W is w. The constant below is factor * (I - W) + W written out,
    # which avoids computing the grey image separately.
    # NOTE(review): these are the ITU-R 601 luma weights in B, G, R order, while
    # arrays converted from PIL images are RGB - confirm the intended order.
    M = (
        np.float32([
            [0.886, -0.114, -0.114],
            [-0.587, 0.413, -0.587],
            [-0.299, -0.299, 0.701]]) * factor
        + np.float32([[0.114], [0.587], [0.299]])
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
        same output as PIL.ImageEnhance.Contrast
    """
    # Mean grey level of the image (same channel weights as color_func).
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([(
        el - mean) * factor + mean
        for el in range(256)
    ]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Brightness
    '''
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    '''
        Matches PIL.ImageEnhance.Sharpness except on the 4 image borders;
        the central area is the same.
    '''
    # 3x3 smoothing kernel: centre weight 5, neighbours 1, normalised by 13.
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        # factor > 1 extrapolates and can leave [0, 255]; clip so values
        # saturate instead of overflowing in the uint8 cast.
        out = out.clip(0, 255).astype(np.uint8)
    return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
    '''Shear along x with affine matrix [[1, factor, 0], [0, 1, 0]]; uncovered pixels get ``fill``.'''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
        (affine matrix [[1, 0, -offset], [0, 1, 0]])
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
        (affine matrix [[1, 0, 0], [0, 1, -offset]])
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    '''
        same output as PIL.ImageOps.posterize: keep only the top ``bits``
        bits of every pixel value (bits == 0 yields an all-zero image).
    '''
    # Mask to 8 bits before building the uint8 scalar: for bits == 0 the shift
    # gives 65280, which does not fit in uint8 (NumPy >= 2 raises OverflowError).
    out = np.bitwise_and(img, np.uint8((255 << (8 - bits)) & 0xFF))
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    '''Shear along y with affine matrix [[1, 0, 0], [factor, 1, 0]]; uncovered pixels get ``fill``.'''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    '''Fill a square of side about ``pad_size``, centred on a random pixel and
    clipped to the image, with ``replace``. Works on a copy of ``img``.'''
    replace = np.array(replace, dtype=np.uint8)
    H, W = img.shape[0], img.shape[1]
    rh, rw = np.random.random(2)
    pad_size = pad_size // 2
    ch, cw = int(rh * H), int(rw * W)
    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
    out = img.copy()
    out[x1:x2, y1:y2, :] = replace
    return out


### level to args
def enhance_level_to_args(MAX_LEVEL):
    '''Map a level in [0, MAX_LEVEL] to an enhancement factor in [0.1, 1.9].'''
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    '''Map a level to a shear factor of magnitude up to 0.3 with a random sign.'''
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    '''Map a level to a shift of up to ``translate_const`` pixels with a random sign.'''
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    '''Map a level to a cutout size of up to ``cutout_const`` pixels.'''
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    '''Map a level to a solarize threshold int(level / MAX_LEVEL * 256);
    at MAX_LEVEL the threshold is 256, so no pixel is inverted.'''
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level, )
    return level_to_args


def none_level_to_args(level):
    '''Ops that take no magnitude argument.'''
    return ()


def posterize_level_to_args(MAX_LEVEL):
    '''Map a level to a posterize bit count int(level / MAX_LEVEL * 4).'''
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level, )
    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    '''Map a level to a rotation of up to 30 degrees with a random sign.'''
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


# Op name -> image function; each takes the image followed by the arguments
# produced by the matching entry of ``arg_dict``.
func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /model/BLIP/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 3 | 
"""Decay the learning rate""" 4 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 5 | for param_group in optimizer.param_groups: 6 | param_group['lr'] = lr 7 | 8 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 9 | """Warmup the learning rate""" 10 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 11 | for param_group in optimizer.param_groups: 12 | param_group['lr'] = lr 13 | 14 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 15 | """Decay the learning rate""" 16 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = lr 19 | 20 | import numpy as np 21 | import io 22 | import os 23 | import time 24 | from collections import defaultdict, deque 25 | import datetime 26 | 27 | import torch 28 | import torch.distributed as dist 29 | 30 | class SmoothedValue(object): 31 | """Track a series of values and provide access to smoothed values over a 32 | window or the global series average. 33 | """ 34 | 35 | def __init__(self, window_size=20, fmt=None): 36 | if fmt is None: 37 | fmt = "{median:.4f} ({global_avg:.4f})" 38 | self.deque = deque(maxlen=window_size) 39 | self.total = 0.0 40 | self.count = 0 41 | self.fmt = fmt 42 | 43 | def update(self, value, n=1): 44 | self.deque.append(value) 45 | self.count += n 46 | self.total += value * n 47 | 48 | def synchronize_between_processes(self): 49 | """ 50 | Warning: does not synchronize the deque! 
51 | """ 52 | if not is_dist_avail_and_initialized(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 55 | dist.barrier() 56 | dist.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self): 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self): 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self): 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self): 81 | return self.deque[-1] 82 | 83 | def __str__(self): 84 | return self.fmt.format( 85 | median=self.median, 86 | avg=self.avg, 87 | global_avg=self.global_avg, 88 | max=self.max, 89 | value=self.value) 90 | 91 | 92 | class MetricLogger(object): 93 | def __init__(self, delimiter="\t"): 94 | self.meters = defaultdict(SmoothedValue) 95 | self.delimiter = delimiter 96 | 97 | def update(self, **kwargs): 98 | for k, v in kwargs.items(): 99 | if isinstance(v, torch.Tensor): 100 | v = v.item() 101 | assert isinstance(v, (float, int)) 102 | self.meters[k].update(v) 103 | 104 | def __getattr__(self, attr): 105 | if attr in self.meters: 106 | return self.meters[attr] 107 | if attr in self.__dict__: 108 | return self.__dict__[attr] 109 | raise AttributeError("'{}' object has no attribute '{}'".format( 110 | type(self).__name__, attr)) 111 | 112 | def __str__(self): 113 | loss_str = [] 114 | for name, meter in self.meters.items(): 115 | loss_str.append( 116 | "{}: {}".format(name, str(meter)) 117 | ) 118 | return self.delimiter.join(loss_str) 119 | 120 | def global_avg(self): 121 | loss_str = [] 122 | for name, meter in self.meters.items(): 123 | loss_str.append( 124 | "{}: {:.4f}".format(name, meter.global_avg) 125 | ) 126 | return self.delimiter.join(loss_str) 127 | 128 | def 
synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | def add_meter(self, name, meter): 133 | self.meters[name] = meter 134 | 135 | def log_every(self, iterable, print_freq, header=None): 136 | i = 0 137 | if not header: 138 | header = '' 139 | start_time = time.time() 140 | end = time.time() 141 | iter_time = SmoothedValue(fmt='{avg:.4f}') 142 | data_time = SmoothedValue(fmt='{avg:.4f}') 143 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 144 | log_msg = [ 145 | header, 146 | '[{0' + space_fmt + '}/{1}]', 147 | 'eta: {eta}', 148 | '{meters}', 149 | 'time: {time}', 150 | 'data: {data}' 151 | ] 152 | if torch.cuda.is_available(): 153 | log_msg.append('max mem: {memory:.0f}') 154 | log_msg = self.delimiter.join(log_msg) 155 | MB = 1024.0 * 1024.0 156 | for obj in iterable: 157 | data_time.update(time.time() - end) 158 | yield obj 159 | iter_time.update(time.time() - end) 160 | if i % print_freq == 0 or i == len(iterable) - 1: 161 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 162 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 163 | if torch.cuda.is_available(): 164 | print(log_msg.format( 165 | i, len(iterable), eta=eta_string, 166 | meters=str(self), 167 | time=str(iter_time), data=str(data_time), 168 | memory=torch.cuda.max_memory_allocated() / MB)) 169 | else: 170 | print(log_msg.format( 171 | i, len(iterable), eta=eta_string, 172 | meters=str(self), 173 | time=str(iter_time), data=str(data_time))) 174 | i += 1 175 | end = time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 178 | print('{} Total time: {} ({:.4f} s / it)'.format( 179 | header, total_time_str, total_time / len(iterable))) 180 | 181 | 182 | class AttrDict(dict): 183 | def __init__(self, *args, **kwargs): 184 | super(AttrDict, self).__init__(*args, **kwargs) 185 | self.__dict__ = self 186 | 187 | 188 | def 
compute_acc(logits, label, reduction='mean'): 189 | ret = (torch.argmax(logits, dim=1) == label).float() 190 | if reduction == 'none': 191 | return ret.detach() 192 | elif reduction == 'mean': 193 | return ret.mean().item() 194 | 195 | def compute_n_params(model, return_str=True): 196 | tot = 0 197 | for p in model.parameters(): 198 | w = 1 199 | for x in p.shape: 200 | w *= x 201 | tot += w 202 | if return_str: 203 | if tot >= 1e6: 204 | return '{:.1f}M'.format(tot / 1e6) 205 | else: 206 | return '{:.1f}K'.format(tot / 1e3) 207 | else: 208 | return tot 209 | 210 | def setup_for_distributed(is_master): 211 | """ 212 | This function disables printing when not in master process 213 | """ 214 | import builtins as __builtin__ 215 | builtin_print = __builtin__.print 216 | 217 | def print(*args, **kwargs): 218 | force = kwargs.pop('force', False) 219 | if is_master or force: 220 | builtin_print(*args, **kwargs) 221 | 222 | __builtin__.print = print 223 | 224 | 225 | def is_dist_avail_and_initialized(): 226 | if not dist.is_available(): 227 | return False 228 | if not dist.is_initialized(): 229 | return False 230 | return True 231 | 232 | 233 | def get_world_size(): 234 | if not is_dist_avail_and_initialized(): 235 | return 1 236 | return dist.get_world_size() 237 | 238 | 239 | def get_rank(): 240 | if not is_dist_avail_and_initialized(): 241 | return 0 242 | return dist.get_rank() 243 | 244 | 245 | def is_main_process(): 246 | return get_rank() == 0 247 | 248 | 249 | def save_on_master(*args, **kwargs): 250 | if is_main_process(): 251 | torch.save(*args, **kwargs) 252 | 253 | 254 | def init_distributed_mode(args): 255 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 256 | args.rank = int(os.environ["RANK"]) 257 | args.world_size = int(os.environ['WORLD_SIZE']) 258 | args.gpu = int(os.environ['LOCAL_RANK']) 259 | elif 'SLURM_PROCID' in os.environ: 260 | args.rank = int(os.environ['SLURM_PROCID']) 261 | args.gpu = args.rank % torch.cuda.device_count() 262 | else: 
263 | print('Not using distributed mode') 264 | args.distributed = False 265 | return 266 | 267 | args.distributed = True 268 | 269 | torch.cuda.set_device(args.gpu) 270 | args.dist_backend = 'nccl' 271 | print('| distributed init (rank {}, word {}): {}'.format( 272 | args.rank, args.world_size, args.dist_url), flush=True) 273 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 274 | world_size=args.world_size, rank=args.rank) 275 | torch.distributed.barrier() 276 | setup_for_distributed(args.rank == 0) 277 | 278 | -------------------------------------------------------------------------------- /model/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /model/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Code-kunkun/ZS-CIR/5b2a48518ccaef3dc2c2edc72db0523c8053506c/model/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /model/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | 
warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == 
expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | 
model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /model/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, 
token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.clip import clip 4 | import torch.nn.functional as F 5 | from model.BLIP.models.blip_retrieval import blip_retrieval 6 | 7 | 8 | class TransAgg(nn.Module): 9 | def __init__(self, cfg): 10 | 
super().__init__() 11 | self.device = cfg.device 12 | self.model_name = cfg.model_name 13 | if self.model_name == 'blip': 14 | self.pretrained_model = blip_retrieval(pretrained="/GPFS/data/yikunliu/cache/model_base_retrieval_coco.pth") 15 | self.feature_dim = 256 16 | elif self.model_name == 'clip-Vit-B/32': 17 | self.pretrained_model, self.preprocess = clip.load("/GPFS/data/yikunliu/cache/ViT-B-32.pt", device=cfg.device, jit=False) 18 | self.feature_dim = self.pretrained_model.visual.output_dim 19 | elif self.model_name == 'clip-Vit-L/14': 20 | self.pretrained_model, self.preprocess = clip.load("/GPFS/data/yikunliu/cache/ViT-L-14.pt", device=cfg.device, jit=False) 21 | self.feature_dim = self.pretrained_model.visual.output_dim 22 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.feature_dim, nhead=8, dropout=cfg.dropout, batch_first=True, norm_first=True, activation="gelu") 23 | self.fusion = nn.TransformerEncoder(encoder_layer, num_layers=cfg.num_layers) 24 | self.logit_scale = 100 25 | self.dropout = nn.Dropout(cfg.dropout) 26 | self.combiner_layer = nn.Linear(self.feature_dim + self.feature_dim, (self.feature_dim + self.feature_dim) * 4) 27 | self.weighted_layer = nn.Linear(self.feature_dim, 3) 28 | self.output_layer = nn.Linear((self.feature_dim + self.feature_dim) * 4, self.feature_dim) 29 | self.sep_token = nn.Parameter(torch.randn(1, 1, self.feature_dim)) 30 | 31 | 32 | def forward(self, texts, reference_images, target_images): 33 | img_text_rep = self.combine_features(reference_images, texts) 34 | target_features, _ = self.pretrained_model.encode_image(target_images) 35 | target_features = F.normalize(target_features, dim=-1) 36 | logits = self.logit_scale * (img_text_rep @ target_features.T) 37 | return logits 38 | 39 | def combine_features(self, reference_images, texts): 40 | reference_image_features, reference_total_image_features = self.pretrained_model.encode_image(reference_images, return_local=True) 41 | batch_size = 
reference_image_features.size(0) 42 | reference_total_image_features = reference_total_image_features.float() 43 | if self.model_name.startswith('blip'): 44 | tokenized_texts = self.pretrained_model.tokenizer(texts, padding='max_length', truncation=True, max_length=35, 45 | return_tensors='pt').to(self.device) 46 | mask = (tokenized_texts.attention_mask == 0) 47 | elif self.model_name.startswith('clip'): 48 | tokenized_texts = clip.tokenize(texts, truncate=True).to(reference_image_features.device) 49 | mask = (tokenized_texts == 0) 50 | 51 | text_features, total_text_features = self.pretrained_model.encode_text(tokenized_texts) 52 | 53 | num_patches = reference_total_image_features.size(1) 54 | sep_token = self.sep_token.repeat(batch_size, 1, 1) 55 | 56 | combine_features = torch.cat((total_text_features, sep_token, reference_total_image_features), dim=1) 57 | 58 | image_mask = torch.zeros(batch_size, num_patches + 1).to(reference_image_features.device) 59 | mask = torch.cat((mask, image_mask), dim=1) 60 | 61 | img_text_rep = self.fusion(combine_features, src_key_padding_mask=mask) 62 | 63 | if self.model_name.startswith('blip'): 64 | multimodal_img_rep = img_text_rep[:, 36, :] 65 | multimodal_text_rep = img_text_rep[:, 0, :] 66 | elif self.model_name.startswith('clip'): 67 | multimodal_img_rep = img_text_rep[:, 78, :] 68 | multimodal_text_rep = img_text_rep[torch.arange(batch_size), tokenized_texts.argmax(dim=-1), :] 69 | 70 | concate = torch.cat((multimodal_img_rep, multimodal_text_rep), dim=-1) 71 | f_U = self.output_layer(self.dropout(F.relu(self.combiner_layer(concate)))) 72 | weighted = self.weighted_layer(f_U) # (batch_size, 3) 73 | 74 | query_rep = weighted[:, 0:1] * text_features + weighted[:, 1:2] * f_U + weighted[:, 2:3] * reference_image_features 75 | 76 | query_rep = F.normalize(query_rep, dim=-1) 77 | 78 | return query_rep 79 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | accelerate==0.17.1 3 | aiofiles==23.1.0 4 | aiohttp==3.8.3 5 | aiosignal==1.2.0 6 | altair==4.2.2 7 | antlr4-python3-runtime==4.9.3 8 | anyio==3.5.0 9 | appdirs==1.4.4 10 | argon2-cffi-bindings==21.2.0 11 | argon2-cffi==21.3.0 12 | asttokens==2.0.5 13 | async-timeout==4.0.2 14 | attrs==22.1.0 15 | backcall==0.2.0 16 | beautifulsoup4==4.11.1 17 | bert-serving-client==1.10.0 18 | bert-serving-server==1.10.0 19 | bitsandbytes==0.37.2 20 | bleach==4.1.0 21 | blessings==1.7 22 | blinker==1.4 23 | brotlipy==0.7.0 24 | cachetools==4.2.2 25 | certifi==2022.12.7 26 | cffi==1.15.1 27 | charset-normalizer==2.1.1 28 | click==8.0.4 29 | clip==1.0 30 | comet-ml==3.32.2 31 | comm==0.1.2 32 | configobj==5.0.8 33 | contourpy==1.0.7 34 | cryptography==38.0.4 35 | cycler==0.11.0 36 | debugpy==1.5.1 37 | decorator==5.1.1 38 | defusedxml==0.7.1 39 | docker-pycreds==0.4.0 40 | dulwich==0.21.3 41 | einops==0.6.0 42 | entrypoints==0.4 43 | everett==3.1.0 44 | executing==0.8.3 45 | fairscale==0.4.13 46 | faiss==1.7.3 47 | fastapi==0.95.0 48 | fastjsonschema==2.16.2 49 | ffmpy==0.3.0 50 | filelock==3.9.0 51 | fire==0.5.0 52 | flit-core==3.6.0 53 | fonttools==4.38.0 54 | frozenlist==1.3.3 55 | fsspec==2023.1.0 56 | ftfy==5.8 57 | fvcore==0.1.5.post20221221 58 | gitdb==4.0.10 59 | gitpython==3.1.31 60 | google-auth-oauthlib==0.4.4 61 | google-auth==2.6.0 62 | gpustat==0.6.0 63 | gputil==1.4.0 64 | gradio==3.23.0 65 | grpcio==1.42.0 66 | h11==0.14.0 67 | hiq-python==1.1.9 68 | httpcore==0.16.3 69 | httpx==0.23.3 70 | huggingface-hub==0.13.3 71 | hydra-core==1.3.2 72 | idna==3.4 73 | importlib-metadata==4.11.3 74 | importlib-resources==5.12.0 75 | iopath==0.1.10 76 | ipykernel==6.19.2 77 | ipython-genutils==0.2.0 78 | ipython==8.10.0 79 | ipywidgets==7.6.5 80 | jedi==0.18.1 81 | jinja2==3.1.2 82 | joblib==1.2.0 83 | jsonschema==4.17.3 84 | jupyter-client==7.4.9 85 | jupyter-core==5.2.0 86 | 
jupyter-server==1.23.4 87 | jupyterlab-pygments==0.1.2 88 | jupyterlab-widgets==1.0.0 89 | kiwisolver==1.4.4 90 | linkify-it-py==2.0.0 91 | lxml==4.9.1 92 | markdown-it-py==2.2.0 93 | markdown==3.4.1 94 | markupsafe==2.1.1 95 | matplotlib-inline==0.1.6 96 | matplotlib==3.7.0 97 | mdit-py-plugins==0.3.3 98 | mdurl==0.1.2 99 | mistune==0.8.4 100 | mkl-fft==1.3.1 101 | mkl-random==1.2.2 102 | mkl-service==2.4.0 103 | multidict==6.0.2 104 | nbclassic==0.5.2 105 | nbclient==0.5.13 106 | nbconvert==6.5.4 107 | nbformat==5.7.0 108 | nest-asyncio==1.5.6 109 | nltk==3.8.1 110 | notebook-shim==0.2.2 111 | notebook==6.5.2 112 | numpy==1.24.2 113 | nvidia-ml-py3==7.352.0 114 | oauthlib==3.2.1 115 | omegaconf==2.3.0 116 | openai==0.27.4 117 | opencv-python==4.7.0.72 118 | orjson==3.8.8 119 | packaging==22.0 120 | pandas==1.5.3 121 | pandocfilters==1.5.0 122 | parso==0.8.3 123 | pathtools==0.1.2 124 | pexpect==4.8.0 125 | pickleshare==0.7.5 126 | pillow==9.4.0 127 | pip==22.3.1 128 | platformdirs==2.5.2 129 | portalocker==2.7.0 130 | prometheus-client==0.14.1 131 | prompt-toolkit==3.0.36 132 | protobuf==3.20.3 133 | psutil==5.9.0 134 | ptyprocess==0.7.0 135 | pure-eval==0.2.2 136 | py-itree==0.0.18 137 | pyasn1-modules==0.2.8 138 | pyasn1==0.4.8 139 | pycocotools==2.0.6 140 | pycparser==2.21 141 | pydantic==1.10.7 142 | pydub==0.25.1 143 | pygments==2.11.2 144 | pyjwt==2.4.0 145 | pyllama==0.0.8 146 | pyopenssl==22.0.0 147 | pyparsing==3.0.9 148 | pyrsistent==0.18.0 149 | pysocks==1.7.1 150 | python-box==6.1.0 151 | python-dateutil==2.8.2 152 | python-multipart==0.0.6 153 | pytz==2022.7.1 154 | pyyaml==6.0 155 | pyzmq==23.2.0 156 | regex==2022.10.31 157 | requests-oauthlib==1.3.0 158 | requests-toolbelt==0.10.1 159 | requests==2.28.2 160 | rfc3986==1.5.0 161 | rsa==4.7.2 162 | scikit-learn==1.2.2 163 | scipy==1.10.1 164 | semantic-version==2.10.0 165 | send2trash==1.8.0 166 | sentence-transformers==2.2.2 167 | sentencepiece==0.1.97 168 | sentry-sdk==1.15.0 169 | 
setproctitle==1.3.2 170 | setuptools==65.6.3 171 | simplejson==3.18.3 172 | six==1.16.0 173 | smmap==5.0.0 174 | sniffio==1.2.0 175 | soupsieve==2.3.2.post1 176 | stack-data==0.2.0 177 | starlette==0.26.1 178 | tabulate==0.9.0 179 | tenacity==8.2.2 180 | tensorboard-data-server==0.6.1 181 | tensorboard-plugin-wit==1.8.1 182 | tensorboard==2.10.0 183 | termcolor==2.2.0 184 | terminado==0.17.1 185 | threadpoolctl==3.1.0 186 | timm==0.6.12 187 | tinycss2==1.2.1 188 | tokenizers==0.13.2 189 | toolz==0.12.0 190 | torch==1.12.1+cu113 191 | torchmetrics==0.11.4 192 | torchvision==0.13.1+cu113 193 | tornado==6.2 194 | tqdm==4.64.1 195 | traitlets==5.7.1 196 | transformers==4.27.3 197 | typing-extensions==4.5.0 198 | uc-micro-py==1.0.1 199 | urllib3==1.26.15 200 | uvicorn==0.21.1 201 | wandb==0.13.10 202 | warmup-scheduler==0.3 203 | wcwidth==0.2.5 204 | webencodings==0.5.1 205 | websocket-client==0.58.0 206 | websockets==10.4 207 | werkzeug==2.2.2 208 | wheel==0.38.4 209 | widgetsnbextension==3.5.2 210 | wrapt==1.14.1 211 | wurlitzer==3.0.3 212 | yacs==0.1.8 213 | yarl==1.8.1 214 | zipp==3.11.0 -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 2 | import PIL 3 | import torchvision.transforms.functional as Ft 4 | 5 | def _convert_image_to_rgb(image): 6 | return image.convert("RGB") 7 | 8 | class SquarePad: 9 | """ 10 | Square pad the input image with zero padding 11 | """ 12 | 13 | def __init__(self, size: int): 14 | """ 15 | For having a consistent preprocess pipeline with CLIP we need to have the preprocessing output dimension as 16 | a parameter 17 | :param size: preprocessing output dimension 18 | """ 19 | self.size = size 20 | 21 | def __call__(self, image): 22 | w, h = image.size 23 | max_wh = max(w, h) 24 | hp = int((max_wh - w) / 2) 25 | vp = int((max_wh 
- h) / 2) 26 | padding = [hp, vp, hp, vp] 27 | return Ft.pad(image, padding, 0, 'constant') 28 | 29 | 30 | class TargetPad: 31 | """ 32 | Pad the image if its aspect ratio is above a target ratio. 33 | Pad the image to match such target ratio 34 | """ 35 | 36 | def __init__(self, target_ratio: float, size: int): 37 | """ 38 | :param target_ratio: target ratio 39 | :param size: preprocessing output dimension 40 | """ 41 | self.size = size 42 | self.target_ratio = target_ratio 43 | 44 | def __call__(self, image): 45 | w, h = image.size 46 | actual_ratio = max(w, h) / min(w, h) 47 | if actual_ratio < self.target_ratio: # check if the ratio is above or below the target ratio 48 | return image 49 | scaled_max_wh = max(w, h) / self.target_ratio # rescale the pad to match the target ratio 50 | hp = max(int((scaled_max_wh - w) / 2), 0) 51 | vp = max(int((scaled_max_wh - h) / 2), 0) 52 | padding = [hp, vp, hp, vp] 53 | return Ft.pad(image, padding, 0, 'constant') 54 | 55 | 56 | def squarepad_transform(dim: int): 57 | """ 58 | CLIP-like preprocessing transform on a square padded image 59 | :param dim: image output dimension 60 | :return: CLIP-like torchvision Compose transform 61 | """ 62 | return Compose([ 63 | SquarePad(dim), 64 | Resize(dim, interpolation=PIL.Image.BICUBIC), 65 | CenterCrop(dim), 66 | _convert_image_to_rgb, 67 | ToTensor(), 68 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 69 | ]) 70 | 71 | 72 | def targetpad_transform(target_ratio: float, dim: int): 73 | """ 74 | CLIP-like preprocessing transform computed after using TargetPad pad 75 | :param target_ratio: target ratio for TargetPad 76 | :param dim: image output dimension 77 | :return: CLIP-like torchvision Compose transform 78 | """ 79 | return Compose([ 80 | TargetPad(target_ratio, dim), 81 | Resize(dim, interpolation=PIL.Image.BICUBIC), 82 | CenterCrop(dim), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), 
(0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for building the TransAgg model, its optimizer, the
2 | preprocessing pipelines and datasets, for extracting index image features,
3 | and for tracking training progress."""
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | import random
8 | from torch import optim
9 | from model.model import TransAgg
10 | from transform import targetpad_transform, squarepad_transform
11 | from data.cirr_dataset import CIRRDataset
12 | from data.fiq_dataset import FashionIQDataset
13 | from data.laion_dataset_template import LaionDataset_Template
14 | from data.laion_dataset_llm import LaionDataset_LLM
15 | from data.laion_dataset_combined import LaionDataset_Combined
16 |
17 | # Default learning rate of the parameter group holding the pretrained backbone's
18 | # weights in get_optimizer (the other group uses cfg.learning_rate).
19 | PRETRAINED_LR = 1e-6
20 |
21 | # Maps the `laion_type` config value to its LAION training dataset class.
22 | _LAION_DATASETS = {
23 |     'laion_template': LaionDataset_Template,
24 |     'laion_llm': LaionDataset_LLM,
25 |     'laion_combined': LaionDataset_Combined,
26 | }
27 |
28 |
29 | def get_model(cfg):
30 |     """Instantiate TransAgg from `cfg` and move it to `cfg.device`."""
31 |     model = TransAgg(cfg)
32 |     return model.to(cfg.device)
33 |
34 |
35 | def set_grad(cfg, model):
36 |     """Freeze parts of `model.pretrained_model` according to `cfg.encoder`.
37 |
38 |     - 'text': freeze the image branch (BLIP: visual_encoder and vision_proj,
39 |       CLIP: visual) so only the text side of the backbone is fine-tuned.
40 |     - 'both': freeze nothing.
41 |     - 'neither': freeze every parameter of the pretrained backbone.
42 |
43 |     Parameters outside `model.pretrained_model` are never frozen here.
44 |
45 |     :raises ValueError: if `cfg.encoder` is not one of the values above.
46 |     """
47 |     if cfg.encoder == 'text':
48 |         print('Only the text encoder will be fine-tuned')
49 |         backbone = model.pretrained_model
50 |         if cfg.model_name.startswith("blip"):
51 |             frozen_modules = [backbone.visual_encoder, backbone.vision_proj]
52 |         elif cfg.model_name.startswith('clip'):
53 |             frozen_modules = [backbone.visual]
54 |         else:
55 |             # NOTE(review): other backbones get nothing frozen - confirm intended.
56 |             frozen_modules = []
57 |     elif cfg.encoder == 'both':
58 |         print('Both encoders will be fine-tuned')
59 |         frozen_modules = []
60 |     elif cfg.encoder == 'neither':
61 |         frozen_modules = [model.pretrained_model]
62 |     else:
63 |         raise ValueError("encoder parameter should be in ['text', 'both', 'neither']")
64 |
65 |     for module in frozen_modules:
66 |         for param in module.parameters():
67 |             param.requires_grad = False
68 |
69 |
70 | def get_preprocess(cfg, model, input_dim):
71 |     """Return the image preprocessing pipeline selected by `cfg.transform`.
72 |
73 |     - 'clip': the model's own `preprocess` attribute.
74 |     - 'squarepad': squarepad_transform(input_dim).
75 |     - 'targetpad': targetpad_transform(cfg.target_ratio, input_dim).
76 |
77 |     :raises ValueError: if `cfg.transform` is not one of the values above.
78 |     """
79 |     if cfg.transform == "clip":
80 |         preprocess = model.preprocess
81 |         print('CLIP default preprocess pipeline is used')
82 |     elif cfg.transform == "squarepad":
83 |         preprocess = squarepad_transform(input_dim)
84 |         print('Square pad preprocess pipeline is used')
85 |     elif cfg.transform == "targetpad":
86 |         target_ratio = cfg.target_ratio
87 |         preprocess = targetpad_transform(target_ratio, input_dim)
88 |         print(f'Target pad with {target_ratio = } preprocess pipeline is used')
89 |     else:
90 |         raise ValueError("Preprocess transform should be in ['clip', 'squarepad', 'targetpad']")
91 |
92 |     return preprocess
93 |
94 |
95 | def _get_laion_train_dataset(preprocess, laion_type):
96 |     """Build the 'train' split of the LAION dataset variant named by `laion_type`.
97 |
98 |     Shared by get_laion_cirr_dataset and get_laion_fiq_dataset.
99 |
100 |     :raises ValueError: if `laion_type` is not a key of _LAION_DATASETS.
101 |     """
102 |     dataset_cls = _LAION_DATASETS.get(laion_type)
103 |     if dataset_cls is None:
104 |         raise ValueError("laion_type should be in ['laion_template', 'laion_llm', 'laion_combined']")
105 |     return dataset_cls('train', preprocess)
106 |
107 |
108 | def get_laion_cirr_dataset(preprocess, laion_type):
109 |     """Build the LAION training set and the CIRR validation sets.
110 |
111 |     :return: (relative_train_dataset, relative_val_dataset, classic_val_dataset)
112 |     :raises ValueError: if `laion_type` is not supported.
113 |     """
114 |     relative_val_dataset = CIRRDataset('val', 'relative', preprocess)
115 |     classic_val_dataset = CIRRDataset('val', 'classic', preprocess)
116 |     relative_train_dataset = _get_laion_train_dataset(preprocess, laion_type)
117 |     return relative_train_dataset, relative_val_dataset, classic_val_dataset
118 |
119 |
120 | def get_laion_fiq_dataset(preprocess, val_dress_types, laion_type):
121 |     """Build the LAION training set and one FashionIQ validation pair per dress type.
122 |
123 |     :return: (relative_train_dataset, relative_val_datasets, classic_val_datasets,
124 |              idx_to_dress_mapping); the i-th entries of both validation lists
125 |              belong to idx_to_dress_mapping[i], i.e. val_dress_types[i].
126 |     :raises ValueError: if `laion_type` is not supported.
127 |     """
128 |     relative_train_dataset = _get_laion_train_dataset(preprocess, laion_type)
129 |
130 |     idx_to_dress_mapping = {}
131 |     relative_val_datasets = []
132 |     classic_val_datasets = []
133 |     for idx, dress_type in enumerate(val_dress_types):
134 |         idx_to_dress_mapping[idx] = dress_type
135 |         relative_val_datasets.append(FashionIQDataset('val', [dress_type], 'relative', preprocess))
136 |         classic_val_datasets.append(FashionIQDataset('val', [dress_type], 'classic', preprocess))
137 |     return relative_train_dataset, relative_val_datasets, classic_val_datasets, idx_to_dress_mapping
138 |
139 |
140 | def collate_fn(batch: list):
141 |     """
142 |     Discard None samples from a batch, then collate the rest with torch's default_collate.
143 |     Meant to be passed as `collate_fn` to a torch DataLoader. If every sample is None the
144 |     remaining list is empty, which default_collate cannot collate.
145 |     :param batch: input_batch
146 |     :return: output_batch = input_batch - None_values
147 |     """
148 |     batch = [sample for sample in batch if sample is not None]
149 |     return torch.utils.data.dataloader.default_collate(batch)
150 |
151 |
152 | def extract_index_features(dataset, model, return_local=True):
153 |     """Encode every image of `dataset` with `model.pretrained_model.encode_image`.
154 |
155 |     :param dataset: yields (name, image) samples; None samples are dropped by collate_fn
156 |     :param model: must expose `feature_dim`, `device` and `pretrained_model.encode_image`
157 |     :param return_local: forwarded to encode_image; when True its second output is collected too
158 |     :return: (index_features, index_names, index_total_features), with
159 |              index_total_features set to None when return_local is False
160 |     """
161 |     feature_dim = model.feature_dim
162 |     classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=8,
163 |                                     pin_memory=True, collate_fn=collate_fn)
164 |     # Collect per-batch outputs and stack once at the end; re-stacking a growing
165 |     # tensor on every batch copies all previously extracted features each time.
166 |     # The empty default-dtype seed keeps the same dtype promotion as before and
167 |     # yields a (0, feature_dim) result for an empty dataset.
168 |     feature_batches = [torch.empty((0, feature_dim)).to(model.device, non_blocking=True)]
169 |     total_feature_batches = []
170 |     index_names = []
171 |
172 |     with torch.no_grad():
173 |         for names, images in tqdm(classic_val_loader):
174 |             images = images.to(model.device, non_blocking=True)
175 |             batch_features, batch_total_features = model.pretrained_model.encode_image(images, return_local)
176 |             feature_batches.append(batch_features)
177 |             if return_local:
178 |                 total_feature_batches.append(batch_total_features)
179 |             index_names.extend(names)
180 |
181 |     index_features = torch.vstack(feature_batches)
182 |     if return_local:
183 |         index_total_features = torch.cat(total_feature_batches, dim=0).to(model.device, non_blocking=True)
184 |     else:
185 |         index_total_features = None
186 |     return index_features, index_names, index_total_features
187 |
188 |
189 | def get_optimizer(model, cfg, pretrained_lr=PRETRAINED_LR):
190 |     """Build AdamW with separate parameter groups for new and pretrained weights.
191 |
192 |     Only parameters with requires_grad are optimized. Those outside
193 |     `model.pretrained_model` use `cfg.learning_rate`; those inside it use
194 |     `pretrained_lr`. Both groups use `cfg.weight_decay`; eps is `cfg.adam_epsilon`.
195 |     :param pretrained_lr: learning rate of the pretrained-backbone parameter group
196 |     """
197 |     # A set gives O(1) membership tests instead of scanning a list per parameter.
198 |     pretrained_ids = {id(param) for param in model.pretrained_model.parameters()}
199 |     new_params, pretrained_params = [], []
200 |     for param in model.parameters():
201 |         if not param.requires_grad:
202 |             continue
203 |         if id(param) in pretrained_ids:
204 |             pretrained_params.append(param)
205 |         else:
206 |             new_params.append(param)
207 |
208 |     optimizer_grouped_parameters = [
209 |         {'params': new_params, 'weight_decay': cfg.weight_decay},
210 |         {'params': pretrained_params, 'weight_decay': cfg.weight_decay, 'lr': pretrained_lr},
211 |     ]
212 |     return optim.AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)
213 |
214 |
215 | def update_train_running_results(train_running_results: dict, loss: torch.Tensor, images_in_batch: int):
216 |     """Add `loss.item() * images_in_batch` and `images_in_batch` to the running totals in place."""
217 |     train_running_results['accumulated_train_loss'] += loss.item() * images_in_batch
218 |     train_running_results["images_in_epoch"] += images_in_batch
219 |
220 |
221 | def set_train_bar_description(train_bar, epoch: int, num_epochs: int, train_running_results: dict):
222 |     """Show the running average training loss of the epoch in the tqdm bar description."""
223 |     accumulated_loss = train_running_results['accumulated_train_loss']
224 |     images_seen = train_running_results['images_in_epoch']
225 |     # Report 0 instead of raising ZeroDivisionError before any image was counted.
226 |     mean_loss = accumulated_loss / images_seen if images_seen else 0.0
227 |     if mean_loss < 0:
228 |         # Debugging aid: dump the raw totals when the average loss goes negative.
229 |         print(accumulated_loss, images_seen)
230 |     train_bar.set_description(
231 |         desc=f"[{epoch}/{num_epochs}] "
232 |              f"train loss : {mean_loss:.3f} "
233 |     )
234 |
235 |
236 | def generate_randomized_fiq_caption(flattened_captions: list[str]) -> list[str]:
237 |     """
238 |     Randomize the FashionIQ training captions pair by pair in one of four equally likely ways:
239 |     (a) "cap1 and cap2", (b) "cap2 and cap1", (c) cap1, (d) cap2.
240 |     Each caption is stripped of surrounding '.', '?', ',' and spaces, and the leading
241 |     caption of each result goes through str.capitalize().
242 |     :param flattened_captions: the list of captions to randomize; its length is 2*batch_size
243 |         since two captions are associated with each triplet
244 |     :return: the randomized caption list (with length = batch_size)
245 |     """
246 |     strip_chars = '.?, '
247 |     captions = []
248 |     for i in range(0, len(flattened_captions), 2):
249 |         first = flattened_captions[i].strip(strip_chars)
250 |         # Contiguous half-open quarters of [0, 1), so every draw - including exactly
251 |         # 0.25 or 0.5 - selects its own case instead of falling through to (d).
252 |         random_num = random.random()
253 |         if random_num < 0.25:
254 |             captions.append(f"{first.capitalize()} and {flattened_captions[i + 1].strip(strip_chars)}")
255 |         elif random_num < 0.5:
256 |             captions.append(f"{flattened_captions[i + 1].strip(strip_chars).capitalize()} and {first}")
257 |         elif random_num < 0.75:
258 |             captions.append(first.capitalize())
259 |         else:
260 |             captions.append(flattened_captions[i + 1].strip(strip_chars).capitalize())
261 |     return captions
262 |
263 |
--------------------------------------------------------------------------------