├── .gitignore ├── .vscode └── settings.json ├── README.md ├── __init__.py ├── adv_encode.py ├── blip_2_node_test.py ├── blip_node.py ├── cog.yaml ├── configs ├── bert_config.json ├── caption_coco.yaml ├── med_config.json ├── nlvr.yaml ├── nocaps.yaml ├── pretrain.yaml ├── retrieval_coco.yaml ├── retrieval_flickr.yaml ├── retrieval_msrvtt.yaml └── vqa.yaml ├── data ├── __init__.py ├── coco_karpathy_dataset.py ├── flickr30k_dataset.py ├── nlvr_dataset.py ├── nocaps_dataset.py ├── pretrain_dataset.py ├── utils.py ├── video_dataset.py └── vqa_dataset.py ├── eval_nocaps.py ├── eval_retrieval_video.py ├── example_node.py ├── models ├── __init__.py ├── blip.py ├── blip_itm.py ├── blip_nlvr.py ├── blip_pretrain.py ├── blip_retrieval.py ├── blip_vqa.py ├── med.py ├── nlvr_encoder.py └── vit.py ├── predict.py ├── pretrain.py ├── requirements.txt ├── train_caption.py ├── train_nlvr.py ├── train_retrieval.py ├── train_vqa.py ├── transform └── randaugment.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | models/__pycache__/ 3 | *.pyc 4 | *.pth -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.defaultFormatter": "ms-python.black-formatter" 4 | }, 5 | "python.formatting.provider": "none" 6 | } 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A ComfyUI Node for adding BLIP in CLIPTextEncode 2 | 3 | ## Announcement: [BLIP](https://github.com/salesforce/BLIP) is now officially integrated into CLIPTextEncode 4 | 5 | ### Dependencies 6 | - [x] Fairscale>=0.4.4 (**NOT** in ComfyUI) 7 | - [x] Transformers==4.26.1 (already in ComfyUI) 8 | - [x] Timm>=0.4.12 (already in ComfyUI) 9 | - [x] Gitpython (already 
in ComfyUI) 10 | 11 | ### Local Installation 12 | Inside ComfyUI_windows_portable\python_embeded, run: 13 |
python.exe -m pip install fairscale 14 | 15 | Then, inside ComfyUI_windows_portable\ComfyUI\custom_nodes\, run: 16 |
git clone https://github.com/paulo-coronado/comfy_clip_blip_node 17 | 18 | ### Google Colab Installation 19 | Add a cell with the following code: 20 |
21 | !pip install fairscale 22 | !cd custom_nodes && git clone https://github.com/paulo-coronado/comfy_clip_blip_node 23 |24 | 25 | ### How to use 26 | 1. Add the CLIPTextEncodeBLIP node; 27 | 2. Connect the node with an image and select a value for min_length and max_length; 28 | 3. Optional: if you want to embed the BLIP text in a prompt, use the keyword **BLIP_TEXT** (e.g. "a photo of BLIP_TEXT", medium shot, intricate details, highly detailed). 29 | 30 | ### Acknowledgement 31 | The implementation of **CLIPTextEncodeBLIP** relies on resources from BLIP, ALBEF, Huggingface Transformers, and timm. We thank the original authors for their open-sourcing. 32 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .blip_node import NODE_CLASS_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS'] 4 | -------------------------------------------------------------------------------- /adv_encode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import itertools 4 | 5 | def _grouper(n, iterable): 6 | it = iter(iterable) 7 | while True: 8 | chunk = list(itertools.islice(it, n)) 9 | if not chunk: 10 | return 11 | yield chunk 12 | 13 | def _norm_mag(w, n): 14 | d = w - 1 15 | return 1 + np.sign(d) * np.sqrt(np.abs(d)**2 / n) 16 | #return np.sign(w) * np.sqrt(np.abs(w)**2 / n) 17 | 18 | def divide_length(word_ids, weights): 19 | sums = dict(zip(*np.unique(word_ids, return_counts=True))) 20 | sums[0] = 1 21 | weights = [[_norm_mag(w, sums[id]) if id != 0 else 1.0 22 | for w, id in zip(x, y)] for x, y in zip(weights, word_ids)] 23 | return weights 24 | 25 | def shift_mean_weight(word_ids, weights): 26 | delta = 1 - np.mean([w for x, y in zip(weights, word_ids) for w, id in zip(x,y) if id != 0]) 27 | weights = [[w if id == 0 else w+delta 28 | for w, id in zip(x, y)] for x, y in 
zip(weights, word_ids)] 29 | return weights 30 | 31 | def scale_to_norm(weights, word_ids, w_max): 32 | top = np.max(weights) 33 | w_max = min(top, w_max) 34 | weights = [[w_max if id == 0 else (w/top) * w_max 35 | for w, id in zip(x, y)] for x, y in zip(weights, word_ids)] 36 | return weights 37 | 38 | def from_zero(weights, base_emb): 39 | weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device) 40 | weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape) 41 | return base_emb * weight_tensor 42 | 43 | def mask_word_id(tokens, word_ids, target_id, mask_token): 44 | new_tokens = [[mask_token if wid == target_id else t 45 | for t, wid in zip(x,y)] for x,y in zip(tokens, word_ids)] 46 | mask = np.array(word_ids) == target_id 47 | return (new_tokens, mask) 48 | 49 | def batched_clip_encode(tokens, clip, num_chunks): 50 | embs = [] 51 | for e in _grouper(32, tokens): 52 | enc = clip.encode_from_tokens(e) 53 | enc = enc.reshape((len(e), clip.tokenizer.max_length, -1)) 54 | embs.append(enc) 55 | embs = torch.cat(embs) 56 | embs = embs.reshape((len(tokens) // num_chunks, clip.tokenizer.max_length * num_chunks, -1)) 57 | return embs 58 | 59 | def from_masked(tokens, weights, word_ids, base_emb, clip): 60 | wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True) 61 | weight_dict = dict((id,w) 62 | for id,w in zip(wids ,np.array(weights).reshape(-1)[inds]) 63 | if w != 1.0) 64 | 65 | if len(weight_dict) == 0: 66 | return torch.zeros_like(base_emb) 67 | 68 | weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device) 69 | weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape) 70 | 71 | #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0) 72 | #TODO: find most suitable masking token here 73 | m_token = (266, 1.0) 74 | 75 | masked_tokens = [] 76 | masks = [] 77 | 78 | #create prompts 79 | for id, w in weight_dict.items(): 80 | masked, m = 
mask_word_id(tokens, word_ids, id, m_token) 81 | masked_tokens.extend(masked) 82 | 83 | m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device) 84 | m = m.reshape(1,-1,1).expand(base_emb.shape) 85 | masks.append(m) 86 | 87 | #batch process prompts 88 | embs = batched_clip_encode(masked_tokens, clip, len(tokens)) 89 | masks = torch.cat(masks) 90 | 91 | embs = (base_emb.expand(embs.shape) - embs) * masks 92 | embs = embs.sum(axis=0, keepdim=True) 93 | return ((weight_tensor - 1) * embs) 94 | 95 | def mask_inds(tokens, inds, mask_token): 96 | clip_len = len(tokens[0]) 97 | inds_set = set(inds) 98 | new_tokens = [[mask_token if i*clip_len + j in inds_set else t 99 | for j, t in enumerate(x)] for i, x in enumerate(tokens)] 100 | return new_tokens 101 | 102 | def down_weight(tokens, weights, word_ids, base_emb, clip): 103 | w, w_inv = np.unique(weights,return_inverse=True) 104 | 105 | if np.sum(w < 1) == 0: 106 | return base_emb 107 | #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0) 108 | #using the comma token as a masking token seems to work better than aos tokens for SD 1.x 109 | m_token = (266, 1.0) 110 | 111 | masked_tokens = [] 112 | 113 | masked_current = tokens 114 | for i in range(len(w)): 115 | if w[i] >= 1: 116 | continue 117 | masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token) 118 | masked_tokens.extend(masked_current) 119 | 120 | embs = batched_clip_encode(masked_tokens, clip, len(tokens)) 121 | embs = torch.cat([base_emb, embs]) 122 | w = w[w<=1.0] 123 | w_mix = np.diff([0] + w.tolist()) 124 | w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1,1,1)) 125 | 126 | weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True) 127 | return weighted_emb 128 | 129 | def scale_emb_to_mag(base_emb, weighted_emb): 130 | norm_base = torch.linalg.norm(base_emb) 131 | norm_weighted = torch.linalg.norm(weighted_emb) 132 | embeddings_final = (norm_base / norm_weighted) * 
weighted_emb 133 | return embeddings_final 134 | 135 | def recover_dist(base_emb, weighted_emb): 136 | fixed_std = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean()) 137 | embeddings_final = fixed_std + (base_emb.mean() - fixed_std.mean()) 138 | return embeddings_final 139 | 140 | def A1111_renorm(base_emb, weighted_emb): 141 | embeddings_final = (base_emb.mean() / weighted_emb.mean()) * weighted_emb 142 | return embeddings_final 143 | 144 | def advanced_encode_from_tokens(clip, tokenized, token_normalization, weight_interpretation, w_max=1.0): 145 | tokens = [[t for t,_,_ in x] for x in tokenized] 146 | weights = [[w for _,w,_ in x] for x in tokenized] 147 | word_ids = [[wid for _,_,wid in x] for x in tokenized] 148 | 149 | #weight normalization 150 | #==================== 151 | 152 | #distribute down/up weights over word lengths 153 | if token_normalization.startswith("length"): 154 | weights = divide_length(word_ids, weights) 155 | 156 | #make mean of word tokens 1 157 | if token_normalization.endswith("mean"): 158 | weights = shift_mean_weight(word_ids, weights) 159 | 160 | #weight interpretation 161 | #===================== 162 | 163 | if weight_interpretation == "comfy": 164 | weighted_tokens = [[(t,w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)] 165 | weighted_emb = clip.encode_from_tokens(weighted_tokens) 166 | else: 167 | unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokenized] 168 | base_emb = clip.encode_from_tokens(unweighted_tokens) 169 | 170 | if weight_interpretation == "A1111": 171 | weighted_emb = from_zero(weights, base_emb) 172 | weighted_emb = A1111_renorm(base_emb, weighted_emb) 173 | 174 | if weight_interpretation == "compel": 175 | pos_tokens = [[(t,w) if w >= 1.0 else (t,1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)] 176 | weighted_emb = clip.encode_from_tokens(pos_tokens) 177 | weighted_emb = down_weight(pos_tokens, weights, word_ids, weighted_emb, clip) 178 | 179 | if 
weight_interpretation == "comfy++": 180 | weighted_emb = down_weight(unweighted_tokens, weights, word_ids, base_emb, clip) 181 | weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights] 182 | weighted_emb += from_masked(unweighted_tokens, weights, word_ids, base_emb, clip) 183 | 184 | if weight_interpretation == "down_weight": 185 | weights = scale_to_norm(weights, word_ids, w_max) 186 | weighted_emb = down_weight(unweighted_tokens, weights, word_ids, base_emb, clip) 187 | 188 | return weighted_emb 189 | 190 | def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0): 191 | tokenized = clip.tokenize(text, return_word_ids=True) 192 | return advanced_encode_from_tokens(clip, tokenized, token_normalization, weight_interpretation, w_max) -------------------------------------------------------------------------------- /blip_2_node_test.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image # Add this line 7 | 8 | 9 | # Freeze PIP modules 10 | def packages(versions=False): 11 | import subprocess 12 | import sys 13 | return [( r.decode().split('==')[0] if not versions else r.decode() ) for r in subprocess.check_output([sys.executable, '-m', 'pip', 'freeze']).split()] 14 | 15 | # Tensor to PIL 16 | def tensor2pil(image): 17 | return Image.fromarray(np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) 18 | 19 | # Convert PIL to Tensor 20 | def pil2tensor(image): 21 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 22 | 23 | class BlipConcat: 24 | def __init__(self): 25 | pass 26 | 27 | @classmethod 28 | def INPUT_TYPES(s): 29 | return { 30 | "required": { 31 | "clip": ("CLIP",), 32 | "image": ("IMAGE",), 33 | "max_length": ("INT", { 34 | "default": 0, 35 | "min": 0, # minimum value 36 | "max": 500, # maximum value 37 | "step": 1 # slider's step 38 | }), 39 | "string_field": ("STRING", { 40 | "multiline": True, #True if you want the field to look like the one on the ClipTextEncode node 41 | "default": "{{BLIP_TEXT}}" 42 | }), 43 | }, 44 | } 45 | 46 | RETURN_TYPES = ("CONDITIONING",) 47 | 48 | FUNCTION = "blip" 49 | 50 | CATEGORY = "conditioning" 51 | 52 | def blip(self, clip, image, max_length, string_field): 53 | # Check if Transformers is installed and update it to the latest version from GitHub 54 | if 'transformers' not in packages(): 55 | print("\033[34mBLIP-2:\033[0m Installing transformers...") 56 | subprocess.check_call([sys.executable, '-m', 'pip', '-q', 'install', '--upgrade', 'git+https://github.com/huggingface/transformers.git']) 57 | 58 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 59 | 60 | image = tensor2pil(image).resize((596, 437)) 61 | 62 | # Load model 63 | processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 64 | model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) 65 | 66 | # Move model to GPU if available 67 | device = "cuda" if torch.cuda.is_available() else "cpu" 68 | model.to(device) 69 | 70 | # Encode text 71 | inputs = processor(image, return_tensors="pt").to(device, torch.float16) 72 | 73 | print("\033[34mBLIP-2:\033[0m Generating text...") 74 | 75 | # Generate text 76 | generated_ids = model.generate(**inputs, max_new_tokens=max_length) 77 | 
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 78 | 79 | print(generated_text) 80 | 81 | prompt = 'photo of a dining room with a table and chairs' 82 | 83 | return ([[clip.encode(prompt), {}]], ) 84 | 85 | 86 | # A dictionary that contains all nodes you want to export with their names 87 | # NOTE: names should be globally unique 88 | NODE_CLASS_MAPPINGS = { 89 | "CLIPTextEncodeBLIP-2": BlipConcat 90 | } -------------------------------------------------------------------------------- /blip_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.transforms.functional import InterpolationMode 9 | from .adv_encode import advanced_encode 10 | 11 | # Add the ComfyUI directory to the system path 12 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) 13 | sys.path.append(".." 
+ os.sep + "ComfyUI") 14 | 15 | NODE_FILE = os.path.abspath(__file__) 16 | BLIP_NODE_ROOT = os.path.dirname(NODE_FILE) 17 | MODELS_DIR = os.path.join( 18 | ( 19 | os.getcwd() + os.sep + "ComfyUI" 20 | if not os.getcwd().startswith(("/content", "/workspace")) 21 | else os.getcwd() 22 | ), 23 | "models", 24 | ) 25 | 26 | 27 | # Freeze PIP modules 28 | def packages(versions=False): 29 | import subprocess 30 | import sys 31 | 32 | return [ 33 | (r.decode().split("==")[0] if not versions else r.decode()) 34 | for r in subprocess.check_output( 35 | [sys.executable, "-m", "pip", "freeze"] 36 | ).split() 37 | ] 38 | 39 | 40 | # Tensor to PIL 41 | def tensor2pil(image): 42 | return Image.fromarray( 43 | np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) 44 | ) 45 | 46 | 47 | # Convert PIL to Tensor 48 | def pil2tensor(image): 49 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 50 | 51 | 52 | def transformImage_legacy(input_image, image_size, device): 53 | raw_image = input_image.convert("RGB") 54 | raw_image = raw_image.resize((image_size, image_size)) 55 | transform = transforms.Compose( 56 | [ 57 | transforms.Resize(raw_image.size, interpolation=InterpolationMode.BICUBIC), 58 | transforms.ToTensor(), 59 | transforms.Normalize( 60 | (0.48145466, 0.4578275, 0.40821073), 61 | (0.26862954, 0.26130258, 0.27577711), 62 | ), 63 | ] 64 | ) 65 | image = transform(raw_image).unsqueeze(0).to(device) 66 | return image 67 | 68 | 69 | def transformImage(input_image, image_size, device): 70 | raw_image = input_image.convert("RGB") 71 | raw_image = raw_image.resize((image_size, image_size)) 72 | transform = transforms.Compose( 73 | [ 74 | transforms.Resize(raw_image.size, interpolation=InterpolationMode.BICUBIC), 75 | transforms.ToTensor(), 76 | transforms.Normalize( 77 | (0.48145466, 0.4578275, 0.40821073), 78 | (0.26862954, 0.26130258, 0.27577711), 79 | ), 80 | ] 81 | ) 82 | image = transform(raw_image).unsqueeze(0).to(device) 83 | 
return image.view( 84 | 1, -1, image_size, image_size 85 | ) # Change the shape of the output tensor 86 | 87 | 88 | class BlipConcat: 89 | def __init__(self): 90 | pass 91 | 92 | @classmethod 93 | def INPUT_TYPES(s): 94 | return { 95 | "required": { 96 | "clip": ("CLIP",), 97 | "image": ("IMAGE",), 98 | "min_length": ( 99 | "INT", 100 | { 101 | "default": 5, 102 | "min": 0, # minimum value 103 | "max": 200, # maximum value 104 | "step": 1, # slider's step 105 | }, 106 | ), 107 | "max_length": ( 108 | "INT", 109 | { 110 | "default": 20, 111 | "min": 0, # minimum value 112 | "max": 200, # maximum value 113 | "step": 1, # slider's step 114 | }, 115 | ), 116 | "token_normalization": (["none", "mean", "length", "length+mean"],), 117 | "weight_interpretation": (["comfy", "A1111", "compel", "comfy++"],), 118 | "string_field": ( 119 | "STRING", 120 | { 121 | "multiline": True, # True if you want the field to look like the one on the ClipTextEncode node 122 | "default": "{{BLIP_TEXT}}", 123 | }, 124 | ), 125 | }, 126 | } 127 | 128 | RETURN_TYPES = ("CONDITIONING",) 129 | 130 | FUNCTION = "blip" 131 | 132 | CATEGORY = "conditioning" 133 | 134 | def blip( 135 | self, 136 | clip, 137 | image, 138 | min_length, 139 | max_length, 140 | token_normalization, 141 | weight_interpretation, 142 | string_field, 143 | ): 144 | print(f"\033[34mStarting BLIP...\033[0m") 145 | 146 | # Change the current working directory to BLIP_NODE_ROOT 147 | os.chdir(BLIP_NODE_ROOT) 148 | 149 | # Add BLIP_NODE_ROOT to the Python path 150 | sys.path.insert(0, BLIP_NODE_ROOT) 151 | 152 | from models.blip import blip_decoder 153 | 154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 155 | 156 | image = tensor2pil(image) 157 | size = 384 158 | 159 | if "transformers==4.26.1" in packages(True): 160 | print("Using Legacy `transformImaage()`") 161 | tensor = transformImage_legacy(image, size, device) 162 | else: 163 | tensor = transformImage(image, size, device) 164 | 165 | blip_dir = 
os.path.join(MODELS_DIR, "blip") 166 | if not os.path.exists(blip_dir): 167 | os.mkdir(blip_dir) 168 | 169 | torch.hub.set_dir(blip_dir) 170 | 171 | model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" 172 | 173 | model = blip_decoder(pretrained=model_url, image_size=size, vit="base") 174 | model.eval() 175 | model = model.to(device) 176 | 177 | with torch.no_grad(): 178 | caption = model.generate( 179 | tensor, 180 | sample=False, 181 | num_beams=1, 182 | min_length=min_length, 183 | max_length=max_length, 184 | ) 185 | text = string_field.replace("BLIP_TEXT", caption[0]) 186 | print(f"\033[34mPrompt:\033[0m", text) 187 | 188 | # Encode text with custom weights 189 | embeddings_final = advanced_encode( 190 | clip, text, token_normalization, weight_interpretation, w_max=1.0 191 | ) 192 | 193 | return ([[embeddings_final, {}]],) 194 | 195 | 196 | NODE_CLASS_MAPPINGS = {"CLIPTextEncodeBLIP": BlipConcat} 197 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | cuda: "11.1" 4 | python_version: "3.8" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "ipython==7.30.1" 10 | - "torchvision==0.11.1" 11 | - "torch==1.10.0" 12 | - "timm==0.4.12" 13 | - "transformers==4.15.0" 14 | - "fairscale==0.4.4" 15 | - "pycocoevalcap==1.2" 16 | 17 | predict: "predict.py:Predictor" 18 | -------------------------------------------------------------------------------- /configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 
| "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | coco_gt_root: 'annotation/coco_gt' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 7 | 8 | # size of vit model; base or large 9 | vit: 'base' 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | batch_size: 32 13 | init_lr: 1e-5 14 | 15 | # vit: 'large' 16 | # vit_grad_ckpt: True 17 | # vit_ckpt_layer: 5 18 | # batch_size: 16 19 | # init_lr: 2e-6 20 | 21 | image_size: 384 22 | 23 | # generation configs 24 | max_length: 20 25 | min_length: 5 26 | num_beams: 3 27 | prompt: 'a picture of ' 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 5 33 | 34 | -------------------------------------------------------------------------------- /configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | 
-------------------------------------------------------------------------------- /configs/nlvr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/NLVR2/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth' 6 | 7 | #size of vit model; base or large 8 | vit: 'base' 9 | batch_size_train: 16 10 | batch_size_test: 64 11 | vit_grad_ckpt: False 12 | vit_ckpt_layer: 0 13 | max_epoch: 15 14 | 15 | image_size: 384 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-5 20 | min_lr: 0 21 | 22 | -------------------------------------------------------------------------------- /configs/nocaps.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/nocaps/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 6 | 7 | vit: 'base' 8 | batch_size: 32 9 | 10 | image_size: 384 11 | 12 | max_length: 20 13 | min_length: 5 14 | num_beams: 3 15 | prompt: 'a picture of ' -------------------------------------------------------------------------------- /configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json', 2 | '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json', 3 | ] 4 | laion_path: '' 5 | 6 | # size of vit model; base or large 7 | vit: 'base' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 224 12 | batch_size: 75 13 | 14 | queue_size: 57600 15 | alpha: 0.4 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-4 20 | min_lr: 1e-6 21 | warmup_lr: 1e-6 22 | 
lr_decay_rate: 0.9 23 | max_epoch: 20 24 | warmup_steps: 3000 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /configs/retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | dataset: 'coco' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 12 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 256 28 | negative_all_rank: True 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /configs/retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/flickr30k/' 2 | ann_root: 'annotation' 3 | dataset: 'flickr' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 10 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 128 28 | 
negative_all_rank: False 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /configs/retrieval_msrvtt.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 6 | 7 | # size of vit model; base or large 8 | vit: 'base' 9 | batch_size: 64 10 | k_test: 128 11 | image_size: 384 12 | num_frm_test: 8 -------------------------------------------------------------------------------- /configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/ 2 | vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/ 3 | train_files: ['vqa_train','vqa_val','vg_qa'] 4 | ann_root: 'annotation' 5 | 6 | # set pretrained as a file path or an url 7 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 8 | 9 | # size of vit model; base or large 10 | vit: 'base' 11 | batch_size_train: 16 12 | batch_size_test: 32 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | init_lr: 2e-5 16 | 17 | image_size: 480 18 | 19 | k_test: 128 20 | inference: 'rank' 21 | 22 | # optimizer 23 | weight_decay: 0.05 24 | min_lr: 0 25 | max_epoch: 10 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms.functional import InterpolationMode 5 | 6 | 
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 34 | return dataset 35 | 36 | elif dataset=='caption_coco': 37 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 38 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 39 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 40 | return train_dataset, val_dataset, test_dataset 41 | 42 | elif dataset=='nocaps': 43 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 44 | test_dataset = 
nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 45 | return val_dataset, test_dataset 46 | 47 | elif dataset=='retrieval_coco': 48 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 49 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 50 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 51 | return train_dataset, val_dataset, test_dataset 52 | 53 | elif dataset=='retrieval_flickr': 54 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 55 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 56 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 57 | return train_dataset, val_dataset, test_dataset 58 | 59 | elif dataset=='vqa': 60 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 61 | train_files = config['train_files'], split='train') 62 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 63 | return train_dataset, test_dataset 64 | 65 | elif dataset=='nlvr': 66 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 67 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 68 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 69 | return train_dataset, val_dataset, test_dataset 70 | 71 | 72 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 73 | samplers = [] 74 | for dataset,shuffle in zip(datasets,shuffles): 75 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 76 | samplers.append(sampler) 77 | 
return samplers 78 | 79 | 80 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 81 | loaders = [] 82 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 83 | if is_train: 84 | shuffle = (sampler is None) 85 | drop_last = True 86 | else: 87 | shuffle = False 88 | drop_last = False 89 | loader = DataLoader( 90 | dataset, 91 | batch_size=bs, 92 | num_workers=n_worker, 93 | pin_memory=True, 94 | sampler=sampler, 95 | shuffle=shuffle, 96 | collate_fn=collate_fn, 97 | drop_last=drop_last, 98 | ) 99 | loaders.append(loader) 100 | return loaders 101 | 102 | -------------------------------------------------------------------------------- /data/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class coco_karpathy_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
coco/images/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 18 | filename = 'coco_karpathy_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class coco_karpathy_caption_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
coco/images/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 61 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | def __len__(self): 70 | return len(self.annotation) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.image_root,ann['image']) 77 | image = Image.open(image_path).convert('RGB') 78 | image = self.transform(image) 79 | 80 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 81 | 82 | return image, int(img_id) 83 | 84 | 85 | class coco_karpathy_retrieval_eval(Dataset): 86 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 87 | ''' 88 | image_root (string): Root directory of images (e.g. 
coco/images/) 89 | ann_root (string): directory to store the annotation file 90 | split (string): val or test 91 | ''' 92 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 93 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 94 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 95 | 96 | download_url(urls[split],ann_root) 97 | 98 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 99 | self.transform = transform 100 | self.image_root = image_root 101 | 102 | self.text = [] 103 | self.image = [] 104 | self.txt2img = {} 105 | self.img2txt = {} 106 | 107 | txt_id = 0 108 | for img_id, ann in enumerate(self.annotation): 109 | self.image.append(ann['image']) 110 | self.img2txt[img_id] = [] 111 | for i, caption in enumerate(ann['caption']): 112 | self.text.append(pre_caption(caption,max_words)) 113 | self.img2txt[img_id].append(txt_id) 114 | self.txt2img[txt_id] = img_id 115 | txt_id += 1 116 | 117 | def __len__(self): 118 | return len(self.annotation) 119 | 120 | def __getitem__(self, index): 121 | 122 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 123 | image = Image.open(image_path).convert('RGB') 124 | image = self.transform(image) 125 | 126 | return image, index -------------------------------------------------------------------------------- /data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class flickr30k_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
flickr30k/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 18 | filename = 'flickr30k_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class flickr30k_retrieval_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
flickr30k/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 61 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | self.text = [] 70 | self.image = [] 71 | self.txt2img = {} 72 | self.img2txt = {} 73 | 74 | txt_id = 0 75 | for img_id, ann in enumerate(self.annotation): 76 | self.image.append(ann['image']) 77 | self.img2txt[img_id] = [] 78 | for i, caption in enumerate(ann['caption']): 79 | self.text.append(pre_caption(caption,max_words)) 80 | self.img2txt[img_id].append(txt_id) 81 | self.txt2img[txt_id] = img_id 82 | txt_id += 1 83 | 84 | def __len__(self): 85 | return len(self.annotation) 86 | 87 | def __getitem__(self, index): 88 | 89 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 90 | image = Image.open(image_path).convert('RGB') 91 | image = self.transform(image) 92 | 93 | return image, index -------------------------------------------------------------------------------- /data/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.utils import download_url 7 | 8 | from PIL import Image 9 | 10 | from data.utils import pre_caption 11 | 12 | class nlvr_dataset(Dataset): 13 | def __init__(self, transform, image_root, ann_root, split): 14 | ''' 15 | image_root (string): Root directory of images 16 | ann_root (string): directory to store the annotation file 17 | split (string): 
train, val or test 18 | ''' 19 | urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json', 20 | 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json', 21 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'} 22 | filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'} 23 | 24 | download_url(urls[split],ann_root) 25 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 26 | 27 | self.transform = transform 28 | self.image_root = image_root 29 | 30 | 31 | def __len__(self): 32 | return len(self.annotation) 33 | 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.annotation[index] 38 | 39 | image0_path = os.path.join(self.image_root,ann['images'][0]) 40 | image0 = Image.open(image0_path).convert('RGB') 41 | image0 = self.transform(image0) 42 | 43 | image1_path = os.path.join(self.image_root,ann['images'][1]) 44 | image1 = Image.open(image1_path).convert('RGB') 45 | image1 = self.transform(image1) 46 | 47 | sentence = pre_caption(ann['sentence'], 40) 48 | 49 | if ann['label']=='True': 50 | label = 1 51 | else: 52 | label = 0 53 | 54 | words = sentence.split(' ') 55 | 56 | if 'left' not in words and 'right' not in words: 57 | if random.random()<0.5: 58 | return image0, image1, sentence, label 59 | else: 60 | return image1, image0, sentence, label 61 | else: 62 | if random.random()<0.5: 63 | return image0, image1, sentence, label 64 | else: 65 | new_words = [] 66 | for word in words: 67 | if word=='left': 68 | new_words.append('right') 69 | elif word=='right': 70 | new_words.append('left') 71 | else: 72 | new_words.append(word) 73 | 74 | sentence = ' '.join(new_words) 75 | return image1, image0, sentence, label 76 | 77 | 78 | -------------------------------------------------------------------------------- /data/nocaps_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | class nocaps_eval(Dataset): 10 | def __init__(self, transform, image_root, ann_root, split): 11 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json', 12 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'} 13 | filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'} 14 | 15 | download_url(urls[split],ann_root) 16 | 17 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 18 | self.transform = transform 19 | self.image_root = image_root 20 | 21 | def __len__(self): 22 | return len(self.annotation) 23 | 24 | def __getitem__(self, index): 25 | 26 | ann = self.annotation[index] 27 | 28 | image_path = os.path.join(self.image_root,ann['image']) 29 | image = Image.open(image_path).convert('RGB') 30 | image = self.transform(image) 31 | 32 | return image, int(ann['img_id']) -------------------------------------------------------------------------------- /data/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from data.utils import pre_caption 13 | import os,glob 14 | 15 | class pretrain_dataset(Dataset): 16 | def __init__(self, ann_file, laion_path, transform): 17 | 18 | self.ann_pretrain = [] 19 | for f in ann_file: 20 | print('loading '+f) 21 | ann = json.load(open(f,'r')) 22 | self.ann_pretrain += ann 23 | 24 | self.laion_path = laion_path 25 | if self.laion_path: 26 | self.laion_files = 
glob.glob(os.path.join(laion_path,'*.json')) 27 | 28 | print('loading '+self.laion_files[0]) 29 | with open(self.laion_files[0],'r') as f: 30 | self.ann_laion = json.load(f) 31 | 32 | self.annotation = self.ann_pretrain + self.ann_laion 33 | else: 34 | self.annotation = self.ann_pretrain 35 | 36 | self.transform = transform 37 | 38 | 39 | def reload_laion(self, epoch): 40 | n = epoch%len(self.laion_files) 41 | print('loading '+self.laion_files[n]) 42 | with open(self.laion_files[n],'r') as f: 43 | self.ann_laion = json.load(f) 44 | 45 | self.annotation = self.ann_pretrain + self.ann_laion 46 | 47 | 48 | def __len__(self): 49 | return len(self.annotation) 50 | 51 | def __getitem__(self, index): 52 | 53 | ann = self.annotation[index] 54 | 55 | image = Image.open(ann['image']).convert('RGB') 56 | image = self.transform(image) 57 | caption = pre_caption(ann['caption'],30) 58 | 59 | return image, caption -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils 9 | 10 | def pre_caption(caption,max_words=50): 11 | caption = re.sub( 12 | r"([.!\"()*#:;~])", 13 | ' ', 14 | caption.lower(), 15 | ) 16 | caption = re.sub( 17 | r"\s{2,}", 18 | ' ', 19 | caption, 20 | ) 21 | caption = caption.rstrip('\n') 22 | caption = caption.strip(' ') 23 | 24 | #truncate caption 25 | caption_words = caption.split(' ') 26 | if len(caption_words)>max_words: 27 | caption = ' '.join(caption_words[:max_words]) 28 | 29 | return caption 30 | 31 | def pre_question(question,max_ques_words=50): 32 | question = re.sub( 33 | r"([.!\"()*#:;~])", 34 | '', 35 | question.lower(), 36 | ) 37 | question = question.rstrip(' ') 38 | 39 | #truncate question 40 | question_words = question.split(' ') 41 | if len(question_words)>max_ques_words: 42 | question = ' 
'.join(question_words[:max_ques_words]) 43 | 44 | return question 45 | 46 | 47 | def save_result(result, result_dir, filename, remove_duplicate=''): 48 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 49 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 50 | 51 | json.dump(result,open(result_file,'w')) 52 | 53 | dist.barrier() 54 | 55 | if utils.is_main_process(): 56 | # combine results from all processes 57 | result = [] 58 | 59 | for rank in range(utils.get_world_size()): 60 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 61 | res = json.load(open(result_file,'r')) 62 | result += res 63 | 64 | if remove_duplicate: 65 | result_new = [] 66 | id_list = [] 67 | for res in result: 68 | if res[remove_duplicate] not in id_list: 69 | id_list.append(res[remove_duplicate]) 70 | result_new.append(res) 71 | result = result_new 72 | 73 | json.dump(result,open(final_result_file,'w')) 74 | print('result file saved to %s'%final_result_file) 75 | 76 | return final_result_file 77 | 78 | 79 | 80 | from pycocotools.coco import COCO 81 | from pycocoevalcap.eval import COCOEvalCap 82 | from torchvision.datasets.utils import download_url 83 | 84 | def coco_caption_eval(coco_gt_root, results_file, split): 85 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 86 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 87 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 88 | 89 | download_url(urls[split],coco_gt_root) 90 | annotation_file = os.path.join(coco_gt_root,filenames[split]) 91 | 92 | # create coco object and coco_result object 93 | coco = COCO(annotation_file) 94 | coco_result = coco.loadRes(results_file) 95 | 96 | # create coco_eval object by taking coco and coco_result 97 | coco_eval = COCOEvalCap(coco, coco_result) 98 | 99 | # evaluate on a subset 
of images by setting 100 | # coco_eval.params['image_id'] = coco_result.getImgIds() 101 | # please remove this line when evaluating the full validation set 102 | # coco_eval.params['image_id'] = coco_result.getImgIds() 103 | 104 | # evaluate results 105 | # SPICE will take a few minutes the first time, but speeds up due to caching 106 | coco_eval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in coco_eval.eval.items(): 110 | print(f'{metric}: {score:.3f}') 111 | 112 | return coco_eval -------------------------------------------------------------------------------- /data/video_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision.datasets.utils import download_url 3 | 4 | from PIL import Image 5 | import torch 6 | import numpy as np 7 | import random 8 | import decord 9 | from decord import VideoReader 10 | import json 11 | import os 12 | from data.utils import pre_caption 13 | 14 | decord.bridge.set_bridge("torch") 15 | 16 | class ImageNorm(object): 17 | """Apply Normalization to Image Pixels on GPU 18 | """ 19 | def __init__(self, mean, std): 20 | self.mean = torch.tensor(mean).view(1, 3, 1, 1) 21 | self.std = torch.tensor(std).view(1, 3, 1, 1) 22 | 23 | def __call__(self, img): 24 | 25 | if torch.max(img) > 1 and self.mean.max() <= 1: 26 | img.div_(255.) 
27 | return img.sub_(self.mean).div_(self.std) 28 | 29 | def load_jsonl(filename): 30 | with open(filename, "r") as f: 31 | return [json.loads(l.strip("\n")) for l in f.readlines()] 32 | 33 | 34 | class VideoDataset(Dataset): 35 | 36 | def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'): 37 | ''' 38 | image_root (string): Root directory of video 39 | ann_root (string): directory to store the annotation file 40 | ''' 41 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl' 42 | filename = 'msrvtt_test.jsonl' 43 | 44 | download_url(url,ann_root) 45 | self.annotation = load_jsonl(os.path.join(ann_root,filename)) 46 | 47 | self.num_frm = num_frm 48 | self.frm_sampling_strategy = frm_sampling_strategy 49 | self.max_img_size = max_img_size 50 | self.video_root = video_root 51 | self.video_fmt = video_fmt 52 | self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 53 | 54 | self.text = [pre_caption(ann['caption'],40) for ann in self.annotation] 55 | self.txt2video = [i for i in range(len(self.annotation))] 56 | self.video2txt = self.txt2video 57 | 58 | 59 | def __len__(self): 60 | return len(self.annotation) 61 | 62 | def __getitem__(self, index): 63 | 64 | ann = self.annotation[index] 65 | 66 | video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt) 67 | 68 | vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size) 69 | 70 | video = self.img_norm(vid_frm_array.float()) 71 | 72 | return video, ann['clip_name'] 73 | 74 | 75 | 76 | def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1): 77 | try: 78 | if not height or not width: 79 | vr = VideoReader(video_path) 80 | else: 81 | vr = VideoReader(video_path, width=width, height=height) 82 | 83 | vlen = len(vr) 84 | 85 | if 
start_time or end_time: 86 | assert fps > 0, 'must provide video fps if specifying start and end time.' 87 | 88 | start_idx = min(int(start_time * fps), vlen) 89 | end_idx = min(int(end_time * fps), vlen) 90 | else: 91 | start_idx, end_idx = 0, vlen 92 | 93 | if self.frm_sampling_strategy == 'uniform': 94 | frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int) 95 | elif self.frm_sampling_strategy == 'rand': 96 | frame_indices = sorted(random.sample(range(vlen), self.num_frm)) 97 | elif self.frm_sampling_strategy == 'headtail': 98 | frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2)) 99 | frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2)) 100 | frame_indices = frame_indices_head + frame_indices_tail 101 | else: 102 | raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy)) 103 | 104 | raw_sample_frms = vr.get_batch(frame_indices) 105 | except Exception as e: 106 | return None 107 | 108 | raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2) 109 | 110 | return raw_sample_frms 111 | -------------------------------------------------------------------------------- /data/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from data.utils import pre_question 9 | 10 | from torchvision.datasets.utils import download_url 11 | 12 | class vqa_dataset(Dataset): 13 | def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"): 14 | self.split = split 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | 20 | if split=='train': 21 | urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json', 22 | 
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json', 23 | 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'} 24 | 25 | self.annotation = [] 26 | for f in train_files: 27 | download_url(urls[f],ann_root) 28 | self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r')) 29 | else: 30 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root) 31 | self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r')) 32 | 33 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root) 34 | self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r')) 35 | 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | if ann['dataset']=='vqa': 45 | image_path = os.path.join(self.vqa_root,ann['image']) 46 | elif ann['dataset']=='vg': 47 | image_path = os.path.join(self.vg_root,ann['image']) 48 | 49 | image = Image.open(image_path).convert('RGB') 50 | image = self.transform(image) 51 | 52 | if self.split == 'test': 53 | question = pre_question(ann['question']) 54 | question_id = ann['question_id'] 55 | return image, question, question_id 56 | 57 | 58 | elif self.split=='train': 59 | 60 | question = pre_question(ann['question']) 61 | 62 | if ann['dataset']=='vqa': 63 | answer_weight = {} 64 | for answer in ann['answer']: 65 | if answer in answer_weight.keys(): 66 | answer_weight[answer] += 1/len(ann['answer']) 67 | else: 68 | answer_weight[answer] = 1/len(ann['answer']) 69 | 70 | answers = list(answer_weight.keys()) 71 | weights = list(answer_weight.values()) 72 | 73 | elif ann['dataset']=='vg': 74 | answers = [ann['answer']] 75 | weights = [0.2] 76 | 77 | return image, question, answers, weights 78 | 79 | 80 | def vqa_collate_fn(batch): 81 | image_list, 
question_list, answer_list, weight_list, n = [], [], [], [], [] 82 | for image, question, answer, weights in batch: 83 | image_list.append(image) 84 | question_list.append(question) 85 | weight_list += weights 86 | answer_list += answer 87 | n.append(len(answer)) 88 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n -------------------------------------------------------------------------------- /eval_nocaps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from data import create_dataset, create_sampler, create_loader 28 | from data.utils import save_result 29 | 30 | @torch.no_grad() 31 | def evaluate(model, data_loader, device, config): 32 | # evaluate 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | print_freq = 10 38 | 39 | result = [] 40 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 41 | 42 | image = image.to(device) 43 | 44 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 45 | min_length=config['min_length'], repetition_penalty=1.1) 46 | 47 | for caption, img_id in zip(captions, image_id): 
48 | result.append({"image_id": img_id.item(), "caption": caption}) 49 | 50 | return result 51 | 52 | 53 | def main(args, config): 54 | utils.init_distributed_mode(args) 55 | 56 | device = torch.device(args.device) 57 | 58 | # fix the seed for reproducibility 59 | seed = args.seed + utils.get_rank() 60 | torch.manual_seed(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | cudnn.benchmark = True 64 | 65 | #### Dataset #### 66 | print("Creating captioning dataset") 67 | val_dataset, test_dataset = create_dataset('nocaps', config) 68 | 69 | if args.distributed: 70 | num_tasks = utils.get_world_size() 71 | global_rank = utils.get_rank() 72 | samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank) 73 | else: 74 | samplers = [None,None] 75 | 76 | val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers, 77 | batch_size=[config['batch_size']]*2,num_workers=[4,4], 78 | is_trains=[False, False], collate_fns=[None,None]) 79 | 80 | #### Model #### 81 | print("Creating model") 82 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 83 | prompt=config['prompt']) 84 | 85 | model = model.to(device) 86 | 87 | model_without_ddp = model 88 | if args.distributed: 89 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 90 | model_without_ddp = model.module 91 | 92 | val_result = evaluate(model_without_ddp, val_loader, device, config) 93 | val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id') 94 | test_result = evaluate(model_without_ddp, test_loader, device, config) 95 | test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id') 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--config', default='./configs/nocaps.yaml') 101 | parser.add_argument('--output_dir', default='output/NoCaps') 102 | 
parser.add_argument('--device', default='cuda') 103 | parser.add_argument('--seed', default=42, type=int) 104 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 106 | parser.add_argument('--distributed', default=True, type=bool) 107 | args = parser.parse_args() 108 | 109 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 110 | 111 | args.result_dir = os.path.join(args.output_dir, 'result') 112 | 113 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 114 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 115 | 116 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 117 | 118 | main(args, config) -------------------------------------------------------------------------------- /eval_retrieval_video.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_retrieval import blip_retrieval 26 | import utils 27 | from data.video_dataset import VideoDataset 28 | 29 | 30 | @torch.no_grad() 31 | def evaluation(model, data_loader, tokenizer, device, config): 32 | # test 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | 38 | print('Computing features for evaluation...') 39 | start_time = time.time() 40 | 41 | texts = data_loader.dataset.text 42 | num_text = len(texts) 43 | text_bs = 256 44 | text_ids = [] 45 | text_embeds = [] 46 | text_atts = [] 47 | for i in range(0, num_text, text_bs): 48 | text = texts[i: min(num_text, i+text_bs)] 49 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device) 50 | text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') 51 | text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:])) 52 | text_embeds.append(text_embed) 53 | text_ids.append(text_input.input_ids) 54 | text_atts.append(text_input.attention_mask) 55 | 56 | text_embeds = torch.cat(text_embeds,dim=0) 57 | text_ids = torch.cat(text_ids,dim=0) 58 | text_atts = torch.cat(text_atts,dim=0) 59 | text_ids[:,0] = tokenizer.additional_special_tokens_ids[0] 60 | 61 | video_feats = [] 62 | video_embeds = [] 63 | for video, video_id in 
data_loader: 64 | 65 | B,N,C,W,H = video.size() 66 | video = video.view(-1,C,W,H) 67 | video = video.to(device,non_blocking=True) 68 | video_feat = model.visual_encoder(video) 69 | video_embed = model.vision_proj(video_feat[:,0,:]) 70 | video_embed = video_embed.view(B,N,-1).mean(dim=1) 71 | video_embed = F.normalize(video_embed,dim=-1) 72 | 73 | video_feat = video_feat.view(B,-1,video_feat.shape[-1]) 74 | video_feats.append(video_feat.cpu()) 75 | video_embeds.append(video_embed) 76 | 77 | video_feats = torch.cat(video_feats,dim=0) 78 | video_embeds = torch.cat(video_embeds,dim=0) 79 | 80 | sims_matrix = video_embeds @ text_embeds.t() 81 | score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device) 82 | 83 | num_tasks = utils.get_world_size() 84 | rank = utils.get_rank() 85 | step = sims_matrix.size(0)//num_tasks + 1 86 | start = rank*step 87 | end = min(sims_matrix.size(0),start+step) 88 | 89 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 90 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 91 | 92 | encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True) 93 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) 94 | output = model.text_encoder(text_ids[topk_idx], 95 | attention_mask = text_atts[topk_idx], 96 | encoder_hidden_states = encoder_output, 97 | encoder_attention_mask = encoder_att, 98 | return_dict = True, 99 | ) 100 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 101 | score_matrix_v2t[start+i,topk_idx] = score + topk_sim 102 | 103 | sims_matrix = sims_matrix.t() 104 | score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device) 105 | 106 | step = sims_matrix.size(0)//num_tasks + 1 107 | start = rank*step 108 | end = min(sims_matrix.size(0),start+step) 109 | 110 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 111 | 112 | topk_sim, topk_idx 
= sims.topk(k=config['k_test'], dim=0) 113 | encoder_output = video_feats[topk_idx].to(device,non_blocking=True) 114 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) 115 | output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1), 116 | attention_mask = text_atts[start+i].repeat(config['k_test'],1), 117 | encoder_hidden_states = encoder_output, 118 | encoder_attention_mask = encoder_att, 119 | return_dict = True, 120 | ) 121 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 122 | score_matrix_t2v[start+i,topk_idx] = score + topk_sim 123 | 124 | if args.distributed: 125 | dist.barrier() 126 | torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM) 127 | torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM) 128 | 129 | total_time = time.time() - start_time 130 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 131 | print('Evaluation time {}'.format(total_time_str)) 132 | 133 | return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy() 134 | 135 | 136 | 137 | @torch.no_grad() 138 | def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt): 139 | 140 | #Video->Text 141 | ranks = np.zeros(scores_v2t.shape[0]) 142 | for index,score in enumerate(scores_v2t): 143 | inds = np.argsort(score)[::-1] 144 | ranks[index] = np.where(inds == vid2txt[index])[0][0] 145 | 146 | # Compute metrics 147 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 148 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 149 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 150 | 151 | #Text->Video 152 | ranks = np.zeros(scores_t2v.shape[0]) 153 | 154 | for index,score in enumerate(scores_t2v): 155 | inds = np.argsort(score)[::-1] 156 | ranks[index] = np.where(inds == txt2vmg[index])[0][0] 157 | 158 | mdR = np.median(ranks+1) 159 | 160 | # Compute metrics 161 | vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 162 | vr5 
= 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 163 | vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 164 | 165 | tr_mean = (tr1 + tr5 + tr10) / 3 166 | vr_mean = (vr1 + vr5 + vr10) / 3 167 | r_mean = (tr_mean + vr_mean) / 2 168 | 169 | eval_result = {'txt_r1': tr1, 170 | 'txt_r5': tr5, 171 | 'txt_r10': tr10, 172 | 'txt_r_mean': tr_mean, 173 | 'vid_r1': vr1, 174 | 'vid_r5': vr5, 175 | 'vid_r10': vr10, 176 | 'vid_r_mean': vr_mean, 177 | 'vid_mdR': mdR, 178 | 'r_mean': r_mean} 179 | return eval_result 180 | 181 | 182 | 183 | 184 | def main(args, config): 185 | utils.init_distributed_mode(args) 186 | 187 | device = torch.device(args.device) 188 | 189 | # fix the seed for reproducibility 190 | seed = args.seed + utils.get_rank() 191 | torch.manual_seed(seed) 192 | np.random.seed(seed) 193 | random.seed(seed) 194 | cudnn.benchmark = True 195 | 196 | #### Dataset #### 197 | print("Creating retrieval dataset") 198 | test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'], 199 | max_img_size=config['image_size'], frm_sampling_strategy='uniform') 200 | 201 | test_loader = DataLoader( 202 | test_dataset, 203 | batch_size=config['batch_size'], 204 | num_workers=4, 205 | pin_memory=True, 206 | drop_last=False, 207 | shuffle=False, 208 | ) 209 | 210 | #### Model #### 211 | print("Creating model") 212 | model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit']) 213 | 214 | model = model.to(device) 215 | 216 | model_without_ddp = model 217 | if args.distributed: 218 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 219 | model_without_ddp = model.module 220 | 221 | score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config) 222 | 223 | if utils.is_main_process(): 224 | 225 | test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt) 226 | 
print(test_result) 227 | 228 | log_stats = {**{f'{k}': v for k, v in test_result.items()},} 229 | with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f: 230 | f.write(json.dumps(log_stats) + "\n") 231 | 232 | 233 | if __name__ == '__main__': 234 | parser = argparse.ArgumentParser() 235 | parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml') 236 | parser.add_argument('--output_dir', default='output/Retrieval_msrvtt') 237 | parser.add_argument('--device', default='cuda') 238 | parser.add_argument('--seed', default=42, type=int) 239 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 240 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 241 | parser.add_argument('--distributed', default=True, type=bool) 242 | args = parser.parse_args() 243 | 244 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 245 | 246 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 247 | 248 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 249 | 250 | main(args, config) -------------------------------------------------------------------------------- /example_node.py: -------------------------------------------------------------------------------- 1 | class Example: 2 | """ 3 | A example node 4 | 5 | Class methods 6 | ------------- 7 | INPUT_TYPES (dict): 8 | Tell the main program input parameters of nodes. 9 | 10 | Attributes 11 | ---------- 12 | RETURN_TYPES (`tuple`): 13 | The type of each element in the output tulple. 14 | RETURN_NAMES (`tuple`): 15 | Optional: The name of each output in the output tulple. 16 | FUNCTION (`str`): 17 | The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() 18 | OUTPUT_NODE ([`bool`]): 19 | If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. 
20 | The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. 21 | Assumed to be False if not present. 22 | CATEGORY (`str`): 23 | The category the node should appear in the UI. 24 | execute(s) -> tuple || None: 25 | The entry point method. The name of this method must be the same as the value of property `FUNCTION`. 26 | For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. 27 | """ 28 | def __init__(self): 29 | pass 30 | 31 | @classmethod 32 | def INPUT_TYPES(s): 33 | """ 34 | Return a dictionary which contains config for all input fields. 35 | Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". 36 | Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. 37 | The type can be a list for selection. 38 | 39 | Returns: `dict`: 40 | - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` 41 | - Value input_fields (`dict`): Contains input fields config: 42 | * Key field_name (`string`): Name of a entry-point method's argument 43 | * Value field_config (`tuple`): 44 | + First value is a string indicate the type of field or a list for selection. 45 | + Secound value is a config for type "INT", "STRING" or "FLOAT". 46 | """ 47 | return { 48 | "required": { 49 | "image": ("IMAGE",), 50 | "int_field": ("INT", { 51 | "default": 0, 52 | "min": 0, #Minimum value 53 | "max": 4096, #Maximum value 54 | "step": 64 #Slider's step 55 | }), 56 | "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 57 | "print_to_screen": (["enable", "disable"],), 58 | "string_field": ("STRING", { 59 | "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node 60 | "default": "Hello World!" 
61 | }), 62 | }, 63 | } 64 | 65 | RETURN_TYPES = ("IMAGE",) 66 | #RETURN_NAMES = ("image_output_name",) 67 | 68 | FUNCTION = "test" 69 | 70 | #OUTPUT_NODE = False 71 | 72 | CATEGORY = "Example" 73 | 74 | def test(self, image, string_field, int_field, float_field, print_to_screen): 75 | if print_to_screen == "enable": 76 | print(f"""Your input contains: 77 | string_field aka input text: {string_field} 78 | int_field: {int_field} 79 | float_field: {float_field} 80 | """) 81 | #do some processing on the image, in this example I just invert it 82 | image = 1.0 - image 83 | return (image,) 84 | 85 | 86 | # A dictionary that contains all nodes you want to export with their names 87 | # NOTE: names should be globally unique 88 | NODE_CLASS_MAPPINGS = { 89 | "Example": Example 90 | } 91 | 92 | # A dictionary that contains the friendly/humanly readable titles for the nodes 93 | NODE_DISPLAY_NAME_MAPPINGS = { 94 | "Example": "Example Node" 95 | } 96 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paulo-coronado/comfy_clip_blip_node/a441ff42758d6092872221476f7ac7f9bdea8512/models/__init__.py -------------------------------------------------------------------------------- /models/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | 10 | warnings.filterwarnings("ignore") 11 | 12 | import os 13 | from urllib.parse import urlparse 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from timm.models.hub import download_cached_file 18 | from torch import nn 19 | from transformers import BertTokenizer 20 | 21 | from models.med import BertConfig, BertLMHeadModel, BertModel 22 | from models.vit import VisionTransformer, interpolate_pos_embed 23 | 24 | 25 | class BLIP_Base(nn.Module): 26 | def __init__(self, 27 | med_config = 'configs/med_config.json', 28 | image_size = 224, 29 | vit = 'base', 30 | vit_grad_ckpt = False, 31 | vit_ckpt_layer = 0, 32 | ): 33 | """ 34 | Args: 35 | med_config (str): path for the mixture of encoder-decoder model's configuration file 36 | image_size (int): input image size 37 | vit (str): model size of vision transformer 38 | """ 39 | super().__init__() 40 | 41 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 42 | self.tokenizer = init_tokenizer() 43 | med_config = BertConfig.from_json_file(med_config) 44 | med_config.encoder_width = vision_width 45 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 46 | 47 | 48 | def forward(self, image, caption, mode): 49 | 50 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 51 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 52 | 53 | if mode=='image': 54 | # return image features 55 | image_embeds = self.visual_encoder(image) 56 | return image_embeds 57 | 58 | elif mode=='text': 59 | # return text features 60 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 61 | return_dict = True, mode = 'text') 62 | return 
text_output.last_hidden_state 63 | 64 | elif mode=='multimodal': 65 | # return multimodel features 66 | image_embeds = self.visual_encoder(image) 67 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 68 | 69 | text.input_ids[:,0] = self.tokenizer.enc_token_id 70 | output = self.text_encoder(text.input_ids, 71 | attention_mask = text.attention_mask, 72 | encoder_hidden_states = image_embeds, 73 | encoder_attention_mask = image_atts, 74 | return_dict = True, 75 | ) 76 | return output.last_hidden_state 77 | 78 | class BLIP_Decoder(nn.Module): 79 | def __init__(self, 80 | med_config = 'configs/med_config.json', 81 | image_size = 384, 82 | vit = 'base', 83 | vit_grad_ckpt = False, 84 | vit_ckpt_layer = 0, 85 | prompt = 'a picture of ', 86 | ): 87 | """ 88 | Args: 89 | med_config (str): path for the mixture of encoder-decoder model's configuration file 90 | image_size (int): input image size 91 | vit (str): model size of vision transformer 92 | """ 93 | super().__init__() 94 | 95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 96 | self.tokenizer = init_tokenizer() 97 | med_config = BertConfig.from_json_file(med_config) 98 | med_config.encoder_width = vision_width 99 | self.text_decoder = BertLMHeadModel(config=med_config) 100 | 101 | self.prompt = prompt 102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 103 | 104 | 105 | def forward(self, image, caption): 106 | 107 | image_embeds = self.visual_encoder(image) 108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 109 | 110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 111 | 112 | text.input_ids[:,0] = self.tokenizer.bos_token_id 113 | 114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 115 | decoder_targets[:,:self.prompt_length] = -100 116 | 117 | 
decoder_output = self.text_decoder(text.input_ids, 118 | attention_mask = text.attention_mask, 119 | encoder_hidden_states = image_embeds, 120 | encoder_attention_mask = image_atts, 121 | labels = decoder_targets, 122 | return_dict = True, 123 | ) 124 | loss_lm = decoder_output.loss 125 | 126 | return loss_lm 127 | 128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 129 | image_embeds = self.visual_encoder(image) 130 | 131 | if not sample: 132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 133 | 134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 136 | 137 | prompt = [self.prompt] * image.size(0) 138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 139 | input_ids[:,0] = self.tokenizer.bos_token_id 140 | input_ids = input_ids[:, :-1] 141 | 142 | if sample: 143 | #nucleus sampling 144 | outputs = self.text_decoder.generate(input_ids=input_ids, 145 | max_length=max_length, 146 | min_length=min_length, 147 | do_sample=True, 148 | top_p=top_p, 149 | num_return_sequences=1, 150 | eos_token_id=self.tokenizer.sep_token_id, 151 | pad_token_id=self.tokenizer.pad_token_id, 152 | repetition_penalty=1.1, 153 | **model_kwargs) 154 | else: 155 | #beam search 156 | outputs = self.text_decoder.generate(input_ids=input_ids, 157 | max_length=max_length, 158 | min_length=min_length, 159 | num_beams=num_beams, 160 | eos_token_id=self.tokenizer.sep_token_id, 161 | pad_token_id=self.tokenizer.pad_token_id, 162 | repetition_penalty=repetition_penalty, 163 | **model_kwargs) 164 | 165 | captions = [] 166 | for output in outputs: 167 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 168 | captions.append(caption[len(self.prompt):]) 169 | return captions 170 | 171 | 172 | def blip_decoder(pretrained='',**kwargs): 
173 | model = BLIP_Decoder(**kwargs) 174 | if pretrained: 175 | model,msg = load_checkpoint(model,pretrained) 176 | assert(len(msg.missing_keys)==0) 177 | return model 178 | 179 | def blip_feature_extractor(pretrained='',**kwargs): 180 | model = BLIP_Base(**kwargs) 181 | if pretrained: 182 | model,msg = load_checkpoint(model,pretrained) 183 | assert(len(msg.missing_keys)==0) 184 | return model 185 | 186 | def init_tokenizer(): 187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 191 | return tokenizer 192 | 193 | 194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 195 | 196 | assert vit in ['base', 'large'], "vit parameter must be base or large" 197 | if vit=='base': 198 | vision_width = 768 199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 201 | drop_path_rate=0 or drop_path_rate 202 | ) 203 | elif vit=='large': 204 | vision_width = 1024 205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 207 | drop_path_rate=0.1 or drop_path_rate 208 | ) 209 | return visual_encoder, vision_width 210 | 211 | def is_url(url_or_filename): 212 | parsed = urlparse(url_or_filename) 213 | return parsed.scheme in ("http", "https") 214 | 215 | def load_checkpoint(model,url_or_filename): 216 | if is_url(url_or_filename): 217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 218 | checkpoint = torch.load(cached_file, map_location='cpu') 219 | elif os.path.isfile(url_or_filename): 220 | 
checkpoint = torch.load(url_or_filename, map_location='cpu') 221 | else: 222 | raise RuntimeError('checkpoint url or path is invalid') 223 | 224 | state_dict = checkpoint['model'] 225 | 226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 229 | model.visual_encoder_m) 230 | for key in model.state_dict().keys(): 231 | if key in state_dict.keys(): 232 | if state_dict[key].shape!=model.state_dict()[key].shape: 233 | del state_dict[key] 234 | 235 | msg = model.load_state_dict(state_dict,strict=False) 236 | print('load checkpoint from %s'%url_or_filename) 237 | return model,msg 238 | 239 | -------------------------------------------------------------------------------- /models/blip_itm.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | ): 19 | """ 20 | Args: 21 | med_config (str): path for the mixture of encoder-decoder model's configuration file 22 | image_size (int): input image size 23 | vit (str): model size of vision transformer 24 | """ 25 | super().__init__() 26 | 27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 28 | self.tokenizer = init_tokenizer() 29 | med_config = BertConfig.from_json_file(med_config) 30 | med_config.encoder_width = vision_width 31 | 
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 32 | 33 | text_width = self.text_encoder.config.hidden_size 34 | 35 | self.vision_proj = nn.Linear(vision_width, embed_dim) 36 | self.text_proj = nn.Linear(text_width, embed_dim) 37 | 38 | self.itm_head = nn.Linear(text_width, 2) 39 | 40 | 41 | def forward(self, image, caption, match_head='itm'): 42 | 43 | image_embeds = self.visual_encoder(image) 44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 45 | 46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 47 | return_tensors="pt").to(image.device) 48 | 49 | 50 | if match_head=='itm': 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = image_embeds, 54 | encoder_attention_mask = image_atts, 55 | return_dict = True, 56 | ) 57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 58 | return itm_output 59 | 60 | elif match_head=='itc': 61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 62 | return_dict = True, mode = 'text') 63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 65 | 66 | sim = image_feat @ text_feat.t() 67 | return sim 68 | 69 | 70 | def blip_itm(pretrained='',**kwargs): 71 | model = BLIP_ITM(**kwargs) 72 | if pretrained: 73 | model,msg = load_checkpoint(model,pretrained) 74 | assert(len(msg.missing_keys)==0) 75 | return model 76 | -------------------------------------------------------------------------------- /models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig 2 | from models.nlvr_encoder import BertModel 3 | from models.vit import interpolate_pos_embed 4 | from models.blip import create_vit, init_tokenizer, is_url 5 | 6 | from timm.models.hub 
import download_cached_file 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | class BLIP_NLVR(nn.Module): 15 | def __init__(self, 16 | med_config = 'configs/med_config.json', 17 | image_size = 480, 18 | vit = 'base', 19 | vit_grad_ckpt = False, 20 | vit_ckpt_layer = 0, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | self.cls_head = nn.Sequential( 37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 38 | nn.ReLU(), 39 | nn.Linear(self.text_encoder.config.hidden_size, 2) 40 | ) 41 | 42 | def forward(self, image, text, targets, train=True): 43 | 44 | image_embeds = self.visual_encoder(image) 45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 47 | 48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 49 | text.input_ids[:,0] = self.tokenizer.enc_token_id 50 | 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = [image0_embeds,image1_embeds], 54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 55 | image_atts[image0_embeds.size(0):]], 56 | return_dict = True, 57 | ) 58 | hidden_state = output.last_hidden_state[:,0,:] 59 | prediction = 
self.cls_head(hidden_state) 60 | 61 | if train: 62 | loss = F.cross_entropy(prediction, targets) 63 | return loss 64 | else: 65 | return prediction 66 | 67 | def blip_nlvr(pretrained='',**kwargs): 68 | model = BLIP_NLVR(**kwargs) 69 | if pretrained: 70 | model,msg = load_checkpoint(model,pretrained) 71 | print("missing keys:") 72 | print(msg.missing_keys) 73 | return model 74 | 75 | 76 | def load_checkpoint(model,url_or_filename): 77 | if is_url(url_or_filename): 78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 79 | checkpoint = torch.load(cached_file, map_location='cpu') 80 | elif os.path.isfile(url_or_filename): 81 | checkpoint = torch.load(url_or_filename, map_location='cpu') 82 | else: 83 | raise RuntimeError('checkpoint url or path is invalid') 84 | state_dict = checkpoint['model'] 85 | 86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 87 | 88 | for key in list(state_dict.keys()): 89 | if 'crossattention.self.' in key: 90 | new_key0 = key.replace('self','self0') 91 | new_key1 = key.replace('self','self1') 92 | state_dict[new_key0] = state_dict[key] 93 | state_dict[new_key1] = state_dict[key] 94 | elif 'crossattention.output.dense.' 
in key: 95 | new_key0 = key.replace('dense','dense0') 96 | new_key1 = key.replace('dense','dense1') 97 | state_dict[new_key0] = state_dict[key] 98 | state_dict[new_key1] = state_dict[key] 99 | 100 | msg = model.load_state_dict(state_dict,strict=False) 101 | print('load checkpoint from %s'%url_or_filename) 102 | return model,msg 103 | -------------------------------------------------------------------------------- /models/blip_retrieval.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_Retrieval(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | queue_size = 57600, 19 | momentum = 0.995, 20 | negative_all_rank = False, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | text_width = self.text_encoder.config.hidden_size 37 | 38 | self.vision_proj = nn.Linear(vision_width, embed_dim) 39 | self.text_proj = nn.Linear(text_width, embed_dim) 40 | 41 | self.itm_head = nn.Linear(text_width, 2) 42 | 43 | # create momentum encoders 44 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 45 | 
self.vision_proj_m = nn.Linear(vision_width, embed_dim) 46 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) 47 | self.text_proj_m = nn.Linear(text_width, embed_dim) 48 | 49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 50 | [self.vision_proj,self.vision_proj_m], 51 | [self.text_encoder,self.text_encoder_m], 52 | [self.text_proj,self.text_proj_m], 53 | ] 54 | self.copy_params() 55 | 56 | # create the queue 57 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 58 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 59 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) 60 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) 61 | 62 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 63 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 64 | 65 | self.queue_size = queue_size 66 | self.momentum = momentum 67 | self.temp = nn.Parameter(0.07*torch.ones([])) 68 | 69 | self.negative_all_rank = negative_all_rank 70 | 71 | 72 | def forward(self, image, caption, alpha, idx): 73 | with torch.no_grad(): 74 | self.temp.clamp_(0.001,0.5) 75 | 76 | image_embeds = self.visual_encoder(image) 77 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 78 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 79 | 80 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 81 | return_tensors="pt").to(image.device) 82 | 83 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 84 | return_dict = True, mode = 'text') 85 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 86 | 87 | ###============== Image-text Contrastive Learning ===================### 88 | idx = idx.view(-1,1) 89 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) 90 | pos_idx = torch.eq(idx, 
idx_all).float() 91 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) 92 | 93 | # get momentum features 94 | with torch.no_grad(): 95 | self._momentum_update() 96 | image_embeds_m = self.visual_encoder_m(image) 97 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 98 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 99 | 100 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 101 | return_dict = True, mode = 'text') 102 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 103 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 104 | 105 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp 106 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp 107 | 108 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 109 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 110 | 111 | sim_i2t = image_feat @ text_feat_m_all / self.temp 112 | sim_t2i = text_feat @ image_feat_m_all / self.temp 113 | 114 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 115 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 116 | 117 | loss_ita = (loss_i2t+loss_t2i)/2 118 | 119 | idxs = concat_all_gather(idx) 120 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) 121 | 122 | ###============== Image-text Matching ===================### 123 | encoder_input_ids = text.input_ids.clone() 124 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 125 | 126 | # forward the positve image-text pair 127 | bs = image.size(0) 128 | output_pos = self.text_encoder(encoder_input_ids, 129 | attention_mask = text.attention_mask, 130 | encoder_hidden_states = image_embeds, 131 | encoder_attention_mask = image_atts, 132 | return_dict = True, 133 | ) 134 | 135 | 136 | if 
self.negative_all_rank: 137 | # compute sample similarity 138 | with torch.no_grad(): 139 | mask = torch.eq(idx, idxs.t()) 140 | 141 | image_feat_world = concat_all_gather(image_feat) 142 | text_feat_world = concat_all_gather(text_feat) 143 | 144 | sim_i2t = image_feat @ text_feat_world.t() / self.temp 145 | sim_t2i = text_feat @ image_feat_world.t() / self.temp 146 | 147 | weights_i2t = F.softmax(sim_i2t,dim=1) 148 | weights_i2t.masked_fill_(mask, 0) 149 | 150 | weights_t2i = F.softmax(sim_t2i,dim=1) 151 | weights_t2i.masked_fill_(mask, 0) 152 | 153 | image_embeds_world = all_gather_with_grad(image_embeds) 154 | 155 | # select a negative image (from all ranks) for each text 156 | image_embeds_neg = [] 157 | for b in range(bs): 158 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 159 | image_embeds_neg.append(image_embeds_world[neg_idx]) 160 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 161 | 162 | # select a negative text (from all ranks) for each image 163 | input_ids_world = concat_all_gather(encoder_input_ids) 164 | att_mask_world = concat_all_gather(text.attention_mask) 165 | 166 | text_ids_neg = [] 167 | text_atts_neg = [] 168 | for b in range(bs): 169 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 170 | text_ids_neg.append(input_ids_world[neg_idx]) 171 | text_atts_neg.append(att_mask_world[neg_idx]) 172 | 173 | else: 174 | with torch.no_grad(): 175 | mask = torch.eq(idx, idx.t()) 176 | 177 | sim_i2t = image_feat @ text_feat.t() / self.temp 178 | sim_t2i = text_feat @ image_feat.t() / self.temp 179 | 180 | weights_i2t = F.softmax(sim_i2t,dim=1) 181 | weights_i2t.masked_fill_(mask, 0) 182 | 183 | weights_t2i = F.softmax(sim_t2i,dim=1) 184 | weights_t2i.masked_fill_(mask, 0) 185 | 186 | # select a negative image (from same rank) for each text 187 | image_embeds_neg = [] 188 | for b in range(bs): 189 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 190 | image_embeds_neg.append(image_embeds[neg_idx]) 191 | image_embeds_neg = 
torch.stack(image_embeds_neg,dim=0) 192 | 193 | # select a negative text (from same rank) for each image 194 | text_ids_neg = [] 195 | text_atts_neg = [] 196 | for b in range(bs): 197 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 198 | text_ids_neg.append(encoder_input_ids[neg_idx]) 199 | text_atts_neg.append(text.attention_mask[neg_idx]) 200 | 201 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 202 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 203 | 204 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 205 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 206 | 207 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 208 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 209 | 210 | output_neg = self.text_encoder(text_ids_all, 211 | attention_mask = text_atts_all, 212 | encoder_hidden_states = image_embeds_all, 213 | encoder_attention_mask = image_atts_all, 214 | return_dict = True, 215 | ) 216 | 217 | 218 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 219 | vl_output = self.itm_head(vl_embeddings) 220 | 221 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 222 | dim=0).to(image.device) 223 | loss_itm = F.cross_entropy(vl_output, itm_labels) 224 | 225 | return loss_ita, loss_itm 226 | 227 | 228 | @torch.no_grad() 229 | def copy_params(self): 230 | for model_pair in self.model_pairs: 231 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 232 | param_m.data.copy_(param.data) # initialize 233 | param_m.requires_grad = False # not update by gradient 234 | 235 | 236 | @torch.no_grad() 237 | def _momentum_update(self): 238 | for model_pair in self.model_pairs: 239 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 240 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 241 | 242 | 243 | @torch.no_grad() 244 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): 245 | # gather keys before updating queue 246 | image_feats = concat_all_gather(image_feat) 247 | text_feats = concat_all_gather(text_feat) 248 | 249 | 250 | batch_size = image_feats.shape[0] 251 | 252 | ptr = int(self.ptr_queue) 253 | assert self.queue_size % batch_size == 0 # for simplicity 254 | 255 | # replace the keys at ptr (dequeue and enqueue) 256 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 257 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 258 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 259 | ptr = (ptr + batch_size) % self.queue_size # move pointer 260 | 261 | self.ptr_queue[0] = ptr 262 | 263 | 264 | def blip_retrieval(pretrained='',**kwargs): 265 | model = BLIP_Retrieval(**kwargs) 266 | if pretrained: 267 | model,msg = load_checkpoint(model,pretrained) 268 | print("missing keys:") 269 | print(msg.missing_keys) 270 | return model 271 | 272 | 273 | @torch.no_grad() 274 | def concat_all_gather(tensor): 275 | """ 276 | Performs all_gather operation on the provided tensors. 277 | *** Warning ***: torch.distributed.all_gather has no gradient. 278 | """ 279 | tensors_gather = [torch.ones_like(tensor) 280 | for _ in range(torch.distributed.get_world_size())] 281 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 282 | 283 | output = torch.cat(tensors_gather, dim=0) 284 | return output 285 | 286 | 287 | class GatherLayer(torch.autograd.Function): 288 | """ 289 | Gather tensors from all workers with support for backward propagation: 290 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
291 | """ 292 | 293 | @staticmethod 294 | def forward(ctx, x): 295 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 296 | torch.distributed.all_gather(output, x) 297 | return tuple(output) 298 | 299 | @staticmethod 300 | def backward(ctx, *grads): 301 | all_gradients = torch.stack(grads) 302 | torch.distributed.all_reduce(all_gradients) 303 | return all_gradients[torch.distributed.get_rank()] 304 | 305 | 306 | def all_gather_with_grad(tensors): 307 | """ 308 | Performs all_gather operation on the provided tensors. 309 | Graph remains connected for backward grad computation. 310 | """ 311 | # Queue the gathered tensors 312 | world_size = torch.distributed.get_world_size() 313 | # There is no need for reduction in the single-proc case 314 | if world_size == 1: 315 | return tensors 316 | 317 | tensor_all = GatherLayer.apply(tensors) 318 | 319 | return torch.cat(tensor_all, dim=0) 320 | -------------------------------------------------------------------------------- /models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel, BertLMHeadModel 2 | from models.blip import create_vit, init_tokenizer, load_checkpoint 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import BertTokenizer 8 | import numpy as np 9 | 10 | class BLIP_VQA(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 480, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | ): 18 | """ 19 | Args: 20 | med_config (str): path for the mixture of encoder-decoder model's configuration file 21 | image_size (int): input image size 22 | vit (str): model size of vision transformer 23 | """ 24 | super().__init__() 25 | 26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 27 | self.tokenizer = 
init_tokenizer() 28 | 29 | encoder_config = BertConfig.from_json_file(med_config) 30 | encoder_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 32 | 33 | decoder_config = BertConfig.from_json_file(med_config) 34 | self.text_decoder = BertLMHeadModel(config=decoder_config) 35 | 36 | 37 | def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): 38 | 39 | image_embeds = self.visual_encoder(image) 40 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 41 | 42 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 43 | return_tensors="pt").to(image.device) 44 | question.input_ids[:,0] = self.tokenizer.enc_token_id 45 | 46 | if train: 47 | ''' 48 | n: number of answers for each question 49 | weights: weight for each answer 50 | ''' 51 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 52 | answer.input_ids[:,0] = self.tokenizer.bos_token_id 53 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 54 | 55 | question_output = self.text_encoder(question.input_ids, 56 | attention_mask = question.attention_mask, 57 | encoder_hidden_states = image_embeds, 58 | encoder_attention_mask = image_atts, 59 | return_dict = True) 60 | 61 | question_states = [] 62 | question_atts = [] 63 | for b, n in enumerate(n): 64 | question_states += [question_output.last_hidden_state[b]]*n 65 | question_atts += [question.attention_mask[b]]*n 66 | question_states = torch.stack(question_states,0) 67 | question_atts = torch.stack(question_atts,0) 68 | 69 | answer_output = self.text_decoder(answer.input_ids, 70 | attention_mask = answer.attention_mask, 71 | encoder_hidden_states = question_states, 72 | encoder_attention_mask = question_atts, 73 | labels = answer_targets, 74 | return_dict = True, 75 | reduction = 'none', 76 
| ) 77 | 78 | loss = weights * answer_output.loss 79 | loss = loss.sum()/image.size(0) 80 | 81 | return loss 82 | 83 | 84 | else: 85 | question_output = self.text_encoder(question.input_ids, 86 | attention_mask = question.attention_mask, 87 | encoder_hidden_states = image_embeds, 88 | encoder_attention_mask = image_atts, 89 | return_dict = True) 90 | 91 | if inference=='generate': 92 | num_beams = 3 93 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) 94 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) 95 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} 96 | 97 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) 98 | 99 | outputs = self.text_decoder.generate(input_ids=bos_ids, 100 | max_length=10, 101 | min_length=1, 102 | num_beams=num_beams, 103 | eos_token_id=self.tokenizer.sep_token_id, 104 | pad_token_id=self.tokenizer.pad_token_id, 105 | **model_kwargs) 106 | 107 | answers = [] 108 | for output in outputs: 109 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 110 | answers.append(answer) 111 | return answers 112 | 113 | elif inference=='rank': 114 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 115 | answer.input_ids, answer.attention_mask, k_test) 116 | return max_ids 117 | 118 | 119 | 120 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 121 | 122 | num_ques = question_states.size(0) 123 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token 124 | 125 | start_output = self.text_decoder(start_ids, 126 | encoder_hidden_states = question_states, 127 | encoder_attention_mask = question_atts, 128 | return_dict = True, 129 | reduction = 'none') 130 | logits = start_output.logits[:,0,:] # first token's logit 131 | 132 | # topk_probs: top-k probability 133 | # topk_ids: 
[num_question, k] 134 | answer_first_token = answer_ids[:,1] 135 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 136 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 137 | 138 | # answer input: [num_question*k, answer_len] 139 | input_ids = [] 140 | input_atts = [] 141 | for b, topk_id in enumerate(topk_ids): 142 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 143 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 144 | input_ids = torch.cat(input_ids,dim=0) 145 | input_atts = torch.cat(input_atts,dim=0) 146 | 147 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 148 | 149 | # repeat encoder's output for top-k answers 150 | question_states = tile(question_states, 0, k) 151 | question_atts = tile(question_atts, 0, k) 152 | 153 | output = self.text_decoder(input_ids, 154 | attention_mask = input_atts, 155 | encoder_hidden_states = question_states, 156 | encoder_attention_mask = question_atts, 157 | labels = targets_ids, 158 | return_dict = True, 159 | reduction = 'none') 160 | 161 | log_probs_sum = -output.loss 162 | log_probs_sum = log_probs_sum.view(num_ques,k) 163 | 164 | max_topk_ids = log_probs_sum.argmax(dim=1) 165 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] 166 | 167 | return max_ids 168 | 169 | 170 | def blip_vqa(pretrained='',**kwargs): 171 | model = BLIP_VQA(**kwargs) 172 | if pretrained: 173 | model,msg = load_checkpoint(model,pretrained) 174 | # assert(len(msg.missing_keys)==0) 175 | return model 176 | 177 | 178 | def tile(x, dim, n_tile): 179 | init_dim = x.size(dim) 180 | repeat_idx = [1] * x.dim() 181 | repeat_idx[dim] = n_tile 182 | x = x.repeat(*(repeat_idx)) 183 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 184 | return torch.index_select(x, dim, order_index.to(x.device)) 185 | 186 | 
-------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to 
be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | 
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = 
blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | 
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], 
t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size 
** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download the weights in ./checkpoints beforehand for fast inference 3 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth 4 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth 5 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | from PIL import Image 11 | import torch 12 | from torchvision import transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | import cog 15 | 16 | from models.blip import blip_decoder 17 | from models.blip_vqa import blip_vqa 18 | from models.blip_itm import blip_itm 19 | 20 | 21 | class Predictor(cog.Predictor): 22 | def setup(self): 23 | self.device = "cuda:0" 24 | 25 | self.models = { 26 | 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth', 27 | image_size=384, vit='base'), 28 | 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth', 29 | image_size=480, vit='base'), 30 | 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth', 31 | image_size=384, vit='base') 32 | } 33 | 34 | @cog.input( 35 | "image", 36 | type=Path, 37 | help="input image", 38 | ) 39 | @cog.input( 40 | "task", 41 | type=str, 42 | default='image_captioning', 43 | options=['image_captioning', 'visual_question_answering', 'image_text_matching'], 44 | help="Choose a task.", 45 | ) 46 | @cog.input( 47 | "question", 48 | type=str, 49 | default=None, 50 | help="Type question for the input image for visual question answering task.", 51 | ) 52 | @cog.input( 53 | "caption", 54 | type=str, 55 | 
default=None, 56 | help="Type caption for the input image for image text matching task.", 57 | ) 58 | def predict(self, image, task, question, caption): 59 | if task == 'visual_question_answering': 60 | assert question is not None, 'Please type a question for visual question answering task.' 61 | if task == 'image_text_matching': 62 | assert caption is not None, 'Please type a caption for mage text matching task.' 63 | 64 | im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device) 65 | model = self.models[task] 66 | model.eval() 67 | model = model.to(self.device) 68 | 69 | if task == 'image_captioning': 70 | with torch.no_grad(): 71 | caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5) 72 | return 'Caption: ' + caption[0] 73 | 74 | if task == 'visual_question_answering': 75 | with torch.no_grad(): 76 | answer = model(im, question, train=False, inference='generate') 77 | return 'Answer: ' + answer[0] 78 | 79 | # image_text_matching 80 | itm_output = model(im, caption, match_head='itm') 81 | itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1] 82 | itc_score = model(im, caption, match_head='itc') 83 | return f'The image and text is matched with a probability of {itm_score.item():.4f}.\n' \ 84 | f'The image feature and text feature has a cosine similarity of {itc_score.item():.4f}.' 
85 | 86 | 87 | def load_image(image, image_size, device): 88 | raw_image = Image.open(str(image)).convert('RGB') 89 | 90 | w, h = raw_image.size 91 | 92 | transform = transforms.Compose([ 93 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 96 | ]) 97 | image = transform(raw_image).unsqueeze(0).to(device) 98 | return image 99 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_pretrain import blip_pretrain 26 | import utils 27 | from utils import warmup_lr_schedule, step_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | 30 | def train(model, data_loader, optimizer, epoch, device, config): 31 | # train 32 | model.train() 33 | 34 | metric_logger = utils.MetricLogger(delimiter=" ") 35 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 36 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 37 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, 
fmt='{value:.4f}')) 38 | metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 39 | 40 | header = 'Train Epoch: [{}]'.format(epoch) 41 | print_freq = 50 42 | 43 | if config['laion_path']: 44 | data_loader.dataset.reload_laion(epoch) 45 | 46 | data_loader.sampler.set_epoch(epoch) 47 | 48 | for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 49 | 50 | if epoch==0: 51 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 52 | 53 | optimizer.zero_grad() 54 | 55 | image = image.to(device,non_blocking=True) 56 | 57 | # ramp up alpha in the first 2 epochs 58 | alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader))) 59 | 60 | loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha) 61 | loss = loss_ita + loss_itm + loss_lm 62 | 63 | loss.backward() 64 | optimizer.step() 65 | 66 | metric_logger.update(loss_ita=loss_ita.item()) 67 | metric_logger.update(loss_itm=loss_itm.item()) 68 | metric_logger.update(loss_lm=loss_lm.item()) 69 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 70 | 71 | 72 | # gather the stats from all processes 73 | metric_logger.synchronize_between_processes() 74 | print("Averaged stats:", metric_logger.global_avg()) 75 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 76 | 77 | 78 | def main(args, config): 79 | utils.init_distributed_mode(args) 80 | 81 | device = torch.device(args.device) 82 | 83 | # fix the seed for reproducibility 84 | seed = args.seed + utils.get_rank() 85 | torch.manual_seed(seed) 86 | np.random.seed(seed) 87 | random.seed(seed) 88 | cudnn.benchmark = True 89 | 90 | #### Dataset #### 91 | print("Creating dataset") 92 | datasets = [create_dataset('pretrain', config, min_scale=0.2)] 93 | print('number of training samples: %d'%len(datasets[0])) 94 | 95 | num_tasks = utils.get_world_size() 96 | global_rank = utils.get_rank() 97 
| samplers = create_sampler(datasets, [True], num_tasks, global_rank) 98 | 99 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] 100 | 101 | #### Model #### 102 | print("Creating model") 103 | model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 104 | vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size']) 105 | 106 | model = model.to(device) 107 | 108 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 109 | 110 | start_epoch = 0 111 | if args.checkpoint: 112 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 113 | state_dict = checkpoint['model'] 114 | model.load_state_dict(state_dict) 115 | 116 | optimizer.load_state_dict(checkpoint['optimizer']) 117 | start_epoch = checkpoint['epoch']+1 118 | print('resume checkpoint from %s'%args.checkpoint) 119 | 120 | model_without_ddp = model 121 | if args.distributed: 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 123 | model_without_ddp = model.module 124 | 125 | print("Start training") 126 | start_time = time.time() 127 | for epoch in range(start_epoch, config['max_epoch']): 128 | 129 | step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) 130 | 131 | train_stats = train(model, data_loader, optimizer, epoch, device, config) 132 | if utils.is_main_process(): 133 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 134 | 'epoch': epoch, 135 | } 136 | save_obj = { 137 | 'model': model_without_ddp.state_dict(), 138 | 'optimizer': optimizer.state_dict(), 139 | 'config': config, 140 | 'epoch': epoch, 141 | } 142 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 143 | 144 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 145 | 
f.write(json.dumps(log_stats) + "\n") 146 | 147 | dist.barrier() 148 | 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | print('Training time {}'.format(total_time_str)) 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('--config', default='./configs/pretrain.yaml') 157 | parser.add_argument('--output_dir', default='output/Pretrain') 158 | parser.add_argument('--checkpoint', default='') 159 | parser.add_argument('--evaluate', action='store_true') 160 | parser.add_argument('--device', default='cuda') 161 | parser.add_argument('--seed', default=42, type=int) 162 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 163 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 164 | parser.add_argument('--distributed', default=True, type=bool) 165 | args = parser.parse_args() 166 | 167 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 168 | 169 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 170 | 171 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 172 | 173 | main(args, config) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | transformers==4.15.0 3 | fairscale==0.4.4 4 | pycocoevalcap 5 | -------------------------------------------------------------------------------- /train_caption.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | from data.utils import save_result, coco_caption_eval 30 | 31 | def train(model, data_loader, optimizer, epoch, device): 32 | # train 33 | model.train() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 38 | header = 'Train Caption Epoch: [{}]'.format(epoch) 39 | print_freq = 50 40 | 41 | for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 42 | image = image.to(device) 43 | 44 | loss = model(image, caption) 45 | 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | metric_logger.update(loss=loss.item()) 51 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 52 | 53 | # gather the stats from all processes 54 | metric_logger.synchronize_between_processes() 55 | print("Averaged stats:", metric_logger.global_avg()) 56 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 57 | 58 | 59 | @torch.no_grad() 60 | def evaluate(model, data_loader, device, config): 61 | # evaluate 62 | model.eval() 63 | 64 | 
metric_logger = utils.MetricLogger(delimiter=" ") 65 | header = 'Caption generation:' 66 | print_freq = 10 67 | 68 | result = [] 69 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 70 | 71 | image = image.to(device) 72 | 73 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 74 | min_length=config['min_length']) 75 | 76 | for caption, img_id in zip(captions, image_id): 77 | result.append({"image_id": img_id.item(), "caption": caption}) 78 | 79 | return result 80 | 81 | 82 | def main(args, config): 83 | utils.init_distributed_mode(args) 84 | 85 | device = torch.device(args.device) 86 | 87 | # fix the seed for reproducibility 88 | seed = args.seed + utils.get_rank() 89 | torch.manual_seed(seed) 90 | np.random.seed(seed) 91 | random.seed(seed) 92 | cudnn.benchmark = True 93 | 94 | #### Dataset #### 95 | print("Creating captioning dataset") 96 | train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config) 97 | 98 | if args.distributed: 99 | num_tasks = utils.get_world_size() 100 | global_rank = utils.get_rank() 101 | samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank) 102 | else: 103 | samplers = [None, None, None] 104 | 105 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, 106 | batch_size=[config['batch_size']]*3,num_workers=[4,4,4], 107 | is_trains=[True, False, False], collate_fns=[None,None,None]) 108 | 109 | #### Model #### 110 | print("Creating model") 111 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 112 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 113 | prompt=config['prompt']) 114 | 115 | model = model.to(device) 116 | 117 | model_without_ddp = model 118 | if args.distributed: 119 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 120 | model_without_ddp = model.module 121 | 122 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 123 | 124 | best = 0 125 | best_epoch = 0 126 | 127 | print("Start training") 128 | start_time = time.time() 129 | for epoch in range(0, config['max_epoch']): 130 | if not args.evaluate: 131 | if args.distributed: 132 | train_loader.sampler.set_epoch(epoch) 133 | 134 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 135 | 136 | train_stats = train(model, train_loader, optimizer, epoch, device) 137 | 138 | val_result = evaluate(model_without_ddp, val_loader, device, config) 139 | val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id') 140 | 141 | test_result = evaluate(model_without_ddp, test_loader, device, config) 142 | test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id') 143 | 144 | if utils.is_main_process(): 145 | coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val') 146 | coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test') 147 | 148 | if args.evaluate: 149 | log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()}, 150 | **{f'test_{k}': v for k, v in coco_test.eval.items()}, 151 | } 152 | with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: 153 | f.write(json.dumps(log_stats) + "\n") 154 | else: 155 | save_obj = { 156 | 'model': model_without_ddp.state_dict(), 157 | 'optimizer': optimizer.state_dict(), 158 | 'config': config, 159 | 'epoch': epoch, 160 | } 161 | 162 | if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best: 163 | best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] 164 | best_epoch = epoch 165 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 166 | 167 | 
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 168 | **{f'val_{k}': v for k, v in coco_val.eval.items()}, 169 | **{f'test_{k}': v for k, v in coco_test.eval.items()}, 170 | 'epoch': epoch, 171 | 'best_epoch': best_epoch, 172 | } 173 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 174 | f.write(json.dumps(log_stats) + "\n") 175 | 176 | if args.evaluate: 177 | break 178 | dist.barrier() 179 | 180 | total_time = time.time() - start_time 181 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 182 | print('Training time {}'.format(total_time_str)) 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--config', default='./configs/caption_coco.yaml') 188 | parser.add_argument('--output_dir', default='output/Caption_coco') 189 | parser.add_argument('--evaluate', action='store_true') 190 | parser.add_argument('--device', default='cuda') 191 | parser.add_argument('--seed', default=42, type=int) 192 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 193 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 194 | parser.add_argument('--distributed', default=True, type=bool) 195 | args = parser.parse_args() 196 | 197 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 198 | 199 | args.result_dir = os.path.join(args.output_dir, 'result') 200 | 201 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 202 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 203 | 204 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 205 | 206 | main(args, config) -------------------------------------------------------------------------------- /train_nlvr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | import json 18 | import pickle 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.utils.data import DataLoader 24 | import torch.backends.cudnn as cudnn 25 | import torch.distributed as dist 26 | 27 | from models.blip_nlvr import blip_nlvr 28 | 29 | import utils 30 | from utils import cosine_lr_schedule, warmup_lr_schedule 31 | from data import create_dataset, create_sampler, create_loader 32 | 33 | def train(model, data_loader, optimizer, epoch, device, config): 34 | # train 35 | model.train() 36 | 37 | metric_logger = utils.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 39 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 40 | 41 | header = 'Train Epoch: [{}]'.format(epoch) 42 | print_freq = 50 43 | step_size = 10 44 | 45 | for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 46 | 47 | images = torch.cat([image0, image1], dim=0) 48 | images, targets = images.to(device), targets.to(device) 49 | 50 | loss = model(images, text, targets=targets, train=True) 51 | 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 57 | metric_logger.update(loss=loss.item()) 58 | 59 | # gather the stats from all processes 60 | metric_logger.synchronize_between_processes() 61 | print("Averaged stats:", metric_logger.global_avg()) 62 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in 
metric_logger.meters.items()} 63 | 64 | 65 | @torch.no_grad() 66 | def evaluate(model, data_loader, device, config): 67 | # test 68 | model.eval() 69 | 70 | metric_logger = utils.MetricLogger(delimiter=" ") 71 | 72 | header = 'Evaluation:' 73 | print_freq = 50 74 | 75 | for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header): 76 | images = torch.cat([image0, image1], dim=0) 77 | images, targets = images.to(device), targets.to(device) 78 | 79 | prediction = model(images, text, targets=targets, train=False) 80 | 81 | _, pred_class = prediction.max(1) 82 | accuracy = (targets==pred_class).sum() / targets.size(0) 83 | 84 | metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0)) 85 | 86 | # gather the stats from all processes 87 | metric_logger.synchronize_between_processes() 88 | 89 | print("Averaged stats:", metric_logger.global_avg()) 90 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 91 | 92 | 93 | 94 | def main(args, config): 95 | utils.init_distributed_mode(args) 96 | 97 | device = torch.device(args.device) 98 | 99 | # fix the seed for reproducibility 100 | seed = args.seed + utils.get_rank() 101 | torch.manual_seed(seed) 102 | np.random.seed(seed) 103 | random.seed(seed) 104 | cudnn.benchmark = True 105 | 106 | #### Dataset #### 107 | print("Creating dataset") 108 | datasets = create_dataset('nlvr', config) 109 | 110 | if args.distributed: 111 | num_tasks = utils.get_world_size() 112 | global_rank = utils.get_rank() 113 | samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank) 114 | else: 115 | samplers = [None, None, None] 116 | 117 | batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']] 118 | train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size, 119 | num_workers=[4,4,4],is_trains=[True,False,False], 120 | collate_fns=[None,None,None]) 121 | 122 | #### Model #### 
123 | print("Creating model") 124 | model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'], 125 | vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) 126 | 127 | model = model.to(device) 128 | 129 | model_without_ddp = model 130 | if args.distributed: 131 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 132 | model_without_ddp = model.module 133 | 134 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 135 | 136 | print("Start training") 137 | start_time = time.time() 138 | best = 0 139 | best_epoch = 0 140 | 141 | for epoch in range(0, config['max_epoch']): 142 | if not args.evaluate: 143 | if args.distributed: 144 | train_loader.sampler.set_epoch(epoch) 145 | 146 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 147 | 148 | train_stats = train(model, train_loader, optimizer, epoch, device, config) 149 | 150 | val_stats = evaluate(model, val_loader, device, config) 151 | test_stats = evaluate(model, test_loader, device, config) 152 | 153 | if utils.is_main_process(): 154 | if args.evaluate: 155 | log_stats = {**{f'val_{k}': v for k, v in val_stats.items()}, 156 | **{f'test_{k}': v for k, v in test_stats.items()}, 157 | } 158 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 159 | f.write(json.dumps(log_stats) + "\n") 160 | 161 | else: 162 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 163 | **{f'val_{k}': v for k, v in val_stats.items()}, 164 | **{f'test_{k}': v for k, v in test_stats.items()}, 165 | 'epoch': epoch, 166 | } 167 | 168 | if float(val_stats['acc'])>best: 169 | save_obj = { 170 | 'model': model_without_ddp.state_dict(), 171 | 'optimizer': optimizer.state_dict(), 172 | 'config': config, 173 | 'epoch': epoch, 174 | } 175 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 176 | 
best = float(val_stats['acc']) 177 | best_epoch = epoch 178 | 179 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 180 | f.write(json.dumps(log_stats) + "\n") 181 | if args.evaluate: 182 | break 183 | 184 | dist.barrier() 185 | 186 | if utils.is_main_process(): 187 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 188 | f.write("best epoch: %d"%best_epoch) 189 | 190 | total_time = time.time() - start_time 191 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 192 | print('Training time {}'.format(total_time_str)) 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument('--config', default='./configs/nlvr.yaml') 198 | parser.add_argument('--output_dir', default='output/NLVR') 199 | parser.add_argument('--evaluate', action='store_true') 200 | parser.add_argument('--device', default='cuda') 201 | parser.add_argument('--seed', default=42, type=int) 202 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 203 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 204 | parser.add_argument('--distributed', default=True, type=bool) 205 | args = parser.parse_args() 206 | 207 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 208 | 209 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 210 | 211 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 212 | 213 | main(args, config) -------------------------------------------------------------------------------- /train_retrieval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_retrieval import blip_retrieval 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | 30 | 31 | def train(model, data_loader, optimizer, epoch, device, config): 32 | # train 33 | model.train() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 38 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 39 | header = 'Train Epoch: [{}]'.format(epoch) 40 | print_freq = 50 41 | 42 | for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 43 | image = image.to(device,non_blocking=True) 44 | idx = idx.to(device,non_blocking=True) 45 | 46 | if epoch>0: 47 | alpha = config['alpha'] 48 | else: 49 | alpha = config['alpha']*min(1,i/len(data_loader)) 50 | 51 | loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx) 52 | loss = loss_ita + loss_itm 53 | 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | 58 | metric_logger.update(loss_itm=loss_itm.item()) 59 | metric_logger.update(loss_ita=loss_ita.item()) 60 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 61 | 62 | # gather the 
stats from all processes 63 | metric_logger.synchronize_between_processes() 64 | print("Averaged stats:", metric_logger.global_avg()) 65 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 66 | 67 | 68 | @torch.no_grad() 69 | def evaluation(model, data_loader, device, config): 70 | # test 71 | model.eval() 72 | 73 | metric_logger = utils.MetricLogger(delimiter=" ") 74 | header = 'Evaluation:' 75 | 76 | print('Computing features for evaluation...') 77 | start_time = time.time() 78 | 79 | texts = data_loader.dataset.text 80 | num_text = len(texts) 81 | text_bs = 256 82 | text_ids = [] 83 | text_embeds = [] 84 | text_atts = [] 85 | for i in range(0, num_text, text_bs): 86 | text = texts[i: min(num_text, i+text_bs)] 87 | text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device) 88 | text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') 89 | text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:])) 90 | text_embeds.append(text_embed) 91 | text_ids.append(text_input.input_ids) 92 | text_atts.append(text_input.attention_mask) 93 | 94 | text_embeds = torch.cat(text_embeds,dim=0) 95 | text_ids = torch.cat(text_ids,dim=0) 96 | text_atts = torch.cat(text_atts,dim=0) 97 | text_ids[:,0] = model.tokenizer.enc_token_id 98 | 99 | image_feats = [] 100 | image_embeds = [] 101 | for image, img_id in data_loader: 102 | image = image.to(device) 103 | image_feat = model.visual_encoder(image) 104 | image_embed = model.vision_proj(image_feat[:,0,:]) 105 | image_embed = F.normalize(image_embed,dim=-1) 106 | 107 | image_feats.append(image_feat.cpu()) 108 | image_embeds.append(image_embed) 109 | 110 | image_feats = torch.cat(image_feats,dim=0) 111 | image_embeds = torch.cat(image_embeds,dim=0) 112 | 113 | sims_matrix = image_embeds @ text_embeds.t() 114 | score_matrix_i2t = 
torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device) 115 | 116 | num_tasks = utils.get_world_size() 117 | rank = utils.get_rank() 118 | step = sims_matrix.size(0)//num_tasks + 1 119 | start = rank*step 120 | end = min(sims_matrix.size(0),start+step) 121 | 122 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 123 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 124 | 125 | encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device) 126 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device) 127 | output = model.text_encoder(text_ids[topk_idx], 128 | attention_mask = text_atts[topk_idx], 129 | encoder_hidden_states = encoder_output, 130 | encoder_attention_mask = encoder_att, 131 | return_dict = True, 132 | ) 133 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 134 | score_matrix_i2t[start+i,topk_idx] = score + topk_sim 135 | 136 | sims_matrix = sims_matrix.t() 137 | score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device) 138 | 139 | step = sims_matrix.size(0)//num_tasks + 1 140 | start = rank*step 141 | end = min(sims_matrix.size(0),start+step) 142 | 143 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 144 | 145 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 146 | encoder_output = image_feats[topk_idx].to(device) 147 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device) 148 | output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1), 149 | attention_mask = text_atts[start+i].repeat(config['k_test'],1), 150 | encoder_hidden_states = encoder_output, 151 | encoder_attention_mask = encoder_att, 152 | return_dict = True, 153 | ) 154 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 155 | score_matrix_t2i[start+i,topk_idx] = score + topk_sim 156 | 157 | if args.distributed: 158 | dist.barrier() 159 | 
# NOTE(review): this region of the repository dump is whitespace-collapsed (the
# original newlines/indentation are gone; "NNN |" markers are the old line numbers).
# It covers: the tail of train_retrieval.py (end of evaluation(), itm_eval(), main(),
# CLI entry point), all of train_vqa.py, and transform/randaugment.py up to MAX_LEVEL.
#
# train_retrieval.py
# - itm_eval(): recall@1/5/10 (in percent) for image->text and text->image retrieval.
#   Image->text: an image's rank is the best (lowest) position, in its descending
#   score row, of any of its ground-truth captions (img2txt[index]).
#   Text->image: the position of txt2img[index] in the text's descending score row.
#   r_mean is the average of the two per-direction means.
# - main(): saves checkpoint_best.pth whenever the validation r_mean improves.
#
# NOTE(review): train_retrieval.main() and train_vqa.main() both call dist.barrier()
#   unconditionally. utils.init_distributed_mode() returns without creating a process
#   group in its non-distributed path, so this call is expected to fail there.
#   Consider guarding it with `if args.distributed:`.
# NOTE(review): `--distributed` uses type=bool (any non-empty string parses as True).
#   The value is overwritten by utils.init_distributed_mode() anyway.
# NOTE(review): the config is read with yaml.load(open(...)) and written with
#   yaml.dump(..., open(...)). Neither file handle is explicitly closed.
#
# train_vqa.py
# - main(): with --evaluate the epoch loop breaks before any training, and only the
#   test split is evaluated. The `result_file` returned by save_result() is unused.
#
# transform/randaugment.py
# NOTE(review): solarize_func's docstring says "PIL.ImageOps.posterize", and
#   brightness_func's says "PIL.ImageEnhance.Contrast". By their bodies they implement
#   solarize (values >= thresh become 255 - value) and brightness (values scaled by
#   factor).
# NOTE(review): cutout_func / cutout_level_to_args are defined but not registered in
#   func_dict or arg_dict, so RandomAugment never samples them.
torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM) 160 | torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM) 161 | 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('Evaluation time {}'.format(total_time_str)) 165 | 166 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 167 | 168 | 169 | 170 | @torch.no_grad() 171 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): 172 | 173 | #Images->Text 174 | ranks = np.zeros(scores_i2t.shape[0]) 175 | for index,score in enumerate(scores_i2t): 176 | inds = np.argsort(score)[::-1] 177 | # Score 178 | rank = 1e20 179 | for i in img2txt[index]: 180 | tmp = np.where(inds == i)[0][0] 181 | if tmp < rank: 182 | rank = tmp 183 | ranks[index] = rank 184 | 185 | # Compute metrics 186 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 187 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 188 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 189 | 190 | #Text->Images 191 | ranks = np.zeros(scores_t2i.shape[0]) 192 | 193 | for index,score in enumerate(scores_t2i): 194 | inds = np.argsort(score)[::-1] 195 | ranks[index] = np.where(inds == txt2img[index])[0][0] 196 | 197 | # Compute metrics 198 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 199 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 200 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 201 | 202 | tr_mean = (tr1 + tr5 + tr10) / 3 203 | ir_mean = (ir1 + ir5 + ir10) / 3 204 | r_mean = (tr_mean + ir_mean) / 2 205 | 206 | eval_result = {'txt_r1': tr1, 207 | 'txt_r5': tr5, 208 | 'txt_r10': tr10, 209 | 'txt_r_mean': tr_mean, 210 | 'img_r1': ir1, 211 | 'img_r5': ir5, 212 | 'img_r10': ir10, 213 | 'img_r_mean': ir_mean, 214 | 'r_mean': r_mean} 215 | return eval_result 216 | 217 | 218 | def main(args, config): 219 | utils.init_distributed_mode(args) 220 | 221 | device = 
torch.device(args.device) 222 | 223 | # fix the seed for reproducibility 224 | seed = args.seed + utils.get_rank() 225 | torch.manual_seed(seed) 226 | np.random.seed(seed) 227 | random.seed(seed) 228 | cudnn.benchmark = True 229 | 230 | #### Dataset #### 231 | print("Creating retrieval dataset") 232 | train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config) 233 | 234 | if args.distributed: 235 | num_tasks = utils.get_world_size() 236 | global_rank = utils.get_rank() 237 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None] 238 | else: 239 | samplers = [None, None, None] 240 | 241 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, 242 | batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2, 243 | num_workers=[4,4,4], 244 | is_trains=[True, False, False], 245 | collate_fns=[None,None,None]) 246 | 247 | 248 | #### Model #### 249 | print("Creating model") 250 | model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 251 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 252 | queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank']) 253 | 254 | model = model.to(device) 255 | 256 | model_without_ddp = model 257 | if args.distributed: 258 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 259 | model_without_ddp = model.module 260 | 261 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 262 | 263 | best = 0 264 | best_epoch = 0 265 | 266 | print("Start training") 267 | start_time = time.time() 268 | 269 | for epoch in range(0, config['max_epoch']): 270 | if not args.evaluate: 271 | if args.distributed: 272 | train_loader.sampler.set_epoch(epoch) 273 | 274 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], 
config['init_lr'], config['min_lr']) 275 | 276 | train_stats = train(model, train_loader, optimizer, epoch, device, config) 277 | 278 | score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config) 279 | score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config) 280 | 281 | if utils.is_main_process(): 282 | 283 | val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt) 284 | print(val_result) 285 | 286 | if val_result['r_mean']>best: 287 | save_obj = { 288 | 'model': model_without_ddp.state_dict(), 289 | 'optimizer': optimizer.state_dict(), 290 | 'config': config, 291 | 'epoch': epoch, 292 | } 293 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 294 | best = val_result['r_mean'] 295 | best_epoch = epoch 296 | 297 | test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt) 298 | print(test_result) 299 | 300 | if args.evaluate: 301 | log_stats = {**{f'val_{k}': v for k, v in val_result.items()}, 302 | **{f'test_{k}': v for k, v in test_result.items()}, 303 | } 304 | with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: 305 | f.write(json.dumps(log_stats) + "\n") 306 | else: 307 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 308 | **{f'val_{k}': v for k, v in val_result.items()}, 309 | **{f'test_{k}': v for k, v in test_result.items()}, 310 | 'epoch': epoch, 311 | 'best_epoch': best_epoch, 312 | } 313 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 314 | f.write(json.dumps(log_stats) + "\n") 315 | 316 | if args.evaluate: 317 | break 318 | 319 | dist.barrier() 320 | torch.cuda.empty_cache() 321 | 322 | total_time = time.time() - start_time 323 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 324 | print('Training time {}'.format(total_time_str)) 325 | 326 | 327 | if __name__ == '__main__': 328 | parser = 
argparse.ArgumentParser() 329 | parser.add_argument('--config', default='./configs/retrieval_flickr.yaml') 330 | parser.add_argument('--output_dir', default='output/Retrieval_flickr') 331 | parser.add_argument('--evaluate', action='store_true') 332 | parser.add_argument('--device', default='cuda') 333 | parser.add_argument('--seed', default=42, type=int) 334 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 335 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 336 | parser.add_argument('--distributed', default=True, type=bool) 337 | args = parser.parse_args() 338 | 339 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 340 | 341 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 342 | 343 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 344 | 345 | main(args, config) -------------------------------------------------------------------------------- /train_vqa.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.utils.data import DataLoader 22 | import torch.backends.cudnn as cudnn 23 | import torch.distributed as dist 24 | 25 | from models.blip_vqa import blip_vqa 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | from data.vqa_dataset import vqa_collate_fn 30 | from data.utils import save_result 31 | 32 | 33 | def train(model, data_loader, optimizer, epoch, device): 34 | # train 35 | model.train() 36 | 37 | metric_logger = utils.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 39 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 40 | 41 | header = 'Train Epoch: [{}]'.format(epoch) 42 | print_freq = 50 43 | 44 | for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True) 46 | 47 | loss = model(image, question, answer, train=True, n=n, weights=weights) 48 | 49 | optimizer.zero_grad() 50 | loss.backward() 51 | optimizer.step() 52 | 53 | metric_logger.update(loss=loss.item()) 54 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 55 | 56 | # gather the stats from all processes 57 | metric_logger.synchronize_between_processes() 58 | print("Averaged stats:", metric_logger.global_avg()) 59 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in 
metric_logger.meters.items()} 60 | 61 | 62 | @torch.no_grad() 63 | def evaluation(model, data_loader, device, config) : 64 | # test 65 | model.eval() 66 | 67 | metric_logger = utils.MetricLogger(delimiter=" ") 68 | header = 'Generate VQA test result:' 69 | print_freq = 50 70 | 71 | result = [] 72 | 73 | if config['inference']=='rank': 74 | answer_list = data_loader.dataset.answer_list 75 | answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device) 76 | answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id 77 | 78 | for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 79 | image = image.to(device,non_blocking=True) 80 | 81 | if config['inference']=='generate': 82 | answers = model(image, question, train=False, inference='generate') 83 | 84 | for answer, ques_id in zip(answers, question_id): 85 | ques_id = int(ques_id.item()) 86 | result.append({"question_id":ques_id, "answer":answer}) 87 | 88 | elif config['inference']=='rank': 89 | answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test']) 90 | 91 | for ques_id, answer_id in zip(question_id, answer_ids): 92 | result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]}) 93 | 94 | return result 95 | 96 | 97 | def main(args, config): 98 | utils.init_distributed_mode(args) 99 | 100 | device = torch.device(args.device) 101 | 102 | # fix the seed for reproducibility 103 | seed = args.seed + utils.get_rank() 104 | torch.manual_seed(seed) 105 | np.random.seed(seed) 106 | random.seed(seed) 107 | cudnn.benchmark = True 108 | 109 | #### Dataset #### 110 | print("Creating vqa datasets") 111 | datasets = create_dataset('vqa', config) 112 | 113 | if args.distributed: 114 | num_tasks = utils.get_world_size() 115 | global_rank = utils.get_rank() 116 | samplers = create_sampler(datasets, [True, False], num_tasks, global_rank) 117 | else: 118 | samplers 
= [None, None] 119 | 120 | train_loader, test_loader = create_loader(datasets,samplers, 121 | batch_size=[config['batch_size_train'],config['batch_size_test']], 122 | num_workers=[4,4],is_trains=[True, False], 123 | collate_fns=[vqa_collate_fn,None]) 124 | #### Model #### 125 | print("Creating model") 126 | model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'], 127 | vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) 128 | 129 | model = model.to(device) 130 | 131 | model_without_ddp = model 132 | if args.distributed: 133 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 134 | model_without_ddp = model.module 135 | 136 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 137 | 138 | best = 0 139 | best_epoch = 0 140 | 141 | print("Start training") 142 | start_time = time.time() 143 | for epoch in range(0, config['max_epoch']): 144 | if not args.evaluate: 145 | if args.distributed: 146 | train_loader.sampler.set_epoch(epoch) 147 | 148 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 149 | 150 | train_stats = train(model, train_loader, optimizer, epoch, device) 151 | 152 | else: 153 | break 154 | 155 | if utils.is_main_process(): 156 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 157 | 'epoch': epoch, 158 | } 159 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 160 | f.write(json.dumps(log_stats) + "\n") 161 | 162 | save_obj = { 163 | 'model': model_without_ddp.state_dict(), 164 | 'optimizer': optimizer.state_dict(), 165 | 'config': config, 166 | 'epoch': epoch, 167 | } 168 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 169 | 170 | dist.barrier() 171 | 172 | vqa_result = evaluation(model_without_ddp, test_loader, device, config) 173 | result_file = save_result(vqa_result, 
args.result_dir, 'vqa_result') 174 | 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print('Training time {}'.format(total_time_str)) 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--config', default='./configs/vqa.yaml') 184 | parser.add_argument('--output_dir', default='output/VQA') 185 | parser.add_argument('--evaluate', action='store_true') 186 | parser.add_argument('--device', default='cuda') 187 | parser.add_argument('--seed', default=42, type=int) 188 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 189 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 190 | parser.add_argument('--distributed', default=True, type=bool) 191 | args = parser.parse_args() 192 | 193 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 194 | 195 | args.result_dir = os.path.join(args.output_dir, 'result') 196 | 197 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 198 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 199 | 200 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 201 | 202 | main(args, config) -------------------------------------------------------------------------------- /transform/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24
| low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation 
according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out 
= cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def 
level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | 
replace_value = (128, 128, 128)

# Maps each op name to a converter from the integer magnitude ``level`` to the
# extra positional args of the matching ``func_dict`` entry.
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):
    """Apply ``N`` randomly sampled augmentation ops at magnitude ``M``.

    On every call, ``N`` op names are drawn with replacement
    (``np.random.choice`` default) from ``self.augs``. Each drawn op is then
    applied with probability 0.5.

    Args:
        N: number of ops sampled per call.
        M: magnitude passed to the per-op ``arg_dict`` converter.
        isPIL: if true, the input is converted with ``np.array`` first.
        augs: optional list of op names (keys of ``func_dict``/``arg_dict``).
            When empty or ``None``, every key of ``arg_dict`` is used.

    The lookup-table ops index a table with the image (``table[img]``), so the
    image is presumably expected to be an integer (uint8) HxWxC array -- TODO
    confirm with callers.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        # ``augs=None`` replaces the former mutable default ``[]``. Both are
        # falsy, so the fallback to every registered op is unchanged.
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Return ``N`` ``(op_name, apply_prob, level)`` triples."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            # Skip the op with probability (1 - prob).
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img


if __name__ == '__main__':
    a = RandomAugment()
    # The ops index lookup tables with the image, so the demo must use an
    # integer image; a float ``np.random.randn`` image raises IndexError there.
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    a(img)

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import math
import numpy as np
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate of every param group along a half cosine.

    lr = min_lr + (init_lr - min_lr) * 0.5 * (1 + cos(pi * epoch / max_epoch)),
    which is ``init_lr`` at epoch 0 and ``min_lr`` at ``epoch == max_epoch``.
    """
    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warm up the learning rate linearly from ``init_lr``, capped at ``max_lr``."""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate geometrically (``decay_rate ** epoch``), floored at ``min_lr``."""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)  # last ``window_size`` raw values
        self.total = 0.0  # weighted sum over all updates
        self.count = 0  # total weight over all updates
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` with weight ``n``."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count``/``total`` across processes (no-op when not distributed).

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic progress logging."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update each named meter; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allows ``logger.loss`` style access to meters.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def global_avg(self):
        """Return ``name: global_avg`` for every meter, joined by the delimiter."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {:.4f}".format(name, meter.global_avg)
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield from ``iterable`` and print progress every ``print_freq`` items.

        ``iterable`` must support ``len()``. The last item is always logged.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # ``max(..., 1)`` avoids ZeroDivisionError for an empty iterable; the
        # value is unchanged for any non-empty one.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))


class AttrDict(dict):
    """``dict`` whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def compute_acc(logits, label, reduction='mean'):
    """Top-1 accuracy of ``argmax(logits, dim=1)`` against ``label``.

    Args:
        reduction: ``'none'`` returns the detached per-sample 0/1 float tensor;
            ``'mean'`` returns the mean as a Python float.

    Raises:
        ValueError: for any other ``reduction`` (previously ``None`` was
            returned silently).
    """
    ret = (torch.argmax(logits, dim=1) == label).float()
    if reduction == 'none':
        return ret.detach()
    elif reduction == 'mean':
        return ret.mean().item()
    raise ValueError("reduction must be 'none' or 'mean', got {!r}".format(reduction))


def compute_n_params(model, return_str=True):
    """Count the elements of all ``model.parameters()``.

    Returns a string like ``'12.3M'``/``'45.6K'`` when ``return_str`` is true,
    otherwise the raw integer count.
    """
    tot = 0
    for p in model.parameters():
        w = 1
        for x in p.shape:
            w *= x
        tot += w
    if return_str:
        if tot >= 1e6:
            return '{:.1f}M'.format(tot / 1e6)
        else:
            return '{:.1f}K'.format(tot / 1e3)
    else:
        return tot


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    (``print(..., force=True)`` still prints everywhere).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """True only if torch.distributed is available and a process group exists."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """``torch.save`` on rank 0 only."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Configure ``args`` for (optional) NCCL distributed training.

    Reads rank/GPU from torchrun-style env vars (RANK, WORLD_SIZE, LOCAL_RANK)
    or from SLURM_PROCID. Otherwise it sets ``args.distributed = False`` and
    returns without creating a process group.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE: world_size is not read from the SLURM environment here; it
        # keeps the value already on ``args`` (e.g. from --world_size).
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}, world {}): {}'.format(
        args.rank, args.world_size, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

# --------------------------------------------------------------------------------