├── README.md ├── assets ├── .gitkeep └── SugaFormer.png ├── configs ├── .gitkeep └── vaw │ ├── eval_fs.sh │ ├── eval_zs.sh │ ├── train_fs.sh │ └── train_zs.sh ├── data └── .gitkeep ├── datasets ├── __init__.py ├── transforms.py ├── vaw.py └── vaw_eval.py ├── engine.py ├── main.py ├── models ├── __init__.py ├── backbone.py ├── lavis │ ├── __init__.py │ ├── common │ │ ├── config.py │ │ ├── dist_utils.py │ │ ├── gradcam.py │ │ ├── logger.py │ │ ├── optims.py │ │ ├── registry.py │ │ └── utils.py │ ├── configs │ │ └── default.yaml │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blip2_models │ │ │ ├── Qformer.py │ │ │ ├── __init__.py │ │ │ ├── blip2.py │ │ │ └── blip2_qformer.py │ │ ├── blip_models │ │ │ ├── __init__.py │ │ │ ├── blip.py │ │ │ └── blip_outputs.py │ │ ├── clip_vit.py │ │ ├── eva_vit.py │ │ ├── med.py │ │ └── vit.py │ ├── processors │ │ ├── __init__.py │ │ ├── base_processor.py │ │ ├── blip_processors.py │ │ ├── clip_processors.py │ │ └── randaugment.py │ └── runners │ │ ├── __init__.py │ │ ├── runner_base.py │ │ └── runner_iter.py ├── position_encoding.py ├── sugaformer.py └── transformer.py ├── requirements.txt ├── tools ├── launch.py └── run_dist_launch.sh └── util ├── __init__.py └── misc.py /README.md: -------------------------------------------------------------------------------- 1 |

2 | Super-class guided Transformer for Zero-Shot Attribute Classification 3 | 4 | Sehyung Kim*, Chanhyeong Yang*, Jihwan Park, Taehoon Song, Hyunwoo J. Kim†. 5 | AAAI 2025 6 | 7 | 8 | 9 | 10 | 11 | 12 |

13 | 14 | ---- 15 | 16 | ![SugaFormer](assets/SugaFormer.png) 17 | 18 | This is the official implementation of AAAI 2025 paper "Super-class guided Transformer for Zero-Shot Attribute Classification" 19 | 20 | ---- 21 | 22 | ## Environment Setting 23 | ```bash 24 | git clone https://github.com/mlvlab/SugaFormer.git 25 | cd SugaFormer 26 | conda create -n sugaformer python==3.9 27 | conda activate sugaformer 28 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | ---- 33 | 34 | ## Dataset Preparation 35 | 36 | To run experiments for VAW, you need both the images from the Visual Genome dataset and the annotation files. Follow the steps below: 37 | 38 | 1. Download the Visual Genome images from the [link](https://homes.cs.washington.edu/~ranjay/visualgenome/index.html). 39 | 2. Download the annotation files for VAW experiments from the [link](https://drive.google.com/drive/folders/1qW3HkMcdLHnsCDXn00TyFD4rAlErZRRW?usp=drive_link). 40 | 41 | 42 | ### Organize the Data 43 | After downloading the Visual Genome images and annotation files, organize them into the following directory structure: 44 | 45 | ```bash 46 | 47 | data/ 48 | └── vaw/ 49 | ├── images/ 50 | │ ├── VG_100K/ 51 | │ └── VG_100K_2/ 52 | │ 53 | └── annotations/ 54 | ├── train.json 55 | ├── test.json 56 | ├── ... 57 | 58 | ``` 59 | 60 | ## Training 61 | 62 | ### VAW Fully-Supervised 63 | Train the model in the fully-supervised setting: 64 | ```bash 65 | ./configs/vaw/train_fs.sh 66 | ``` 67 | 68 | ### VAW Zero-Shot (base2novel) 69 | Train the model in the zero-shot setting: 70 | ```bash 71 | ./configs/vaw/train_zs.sh 72 | ``` 73 | ## Evaluation 74 | 75 | ### VAW Fully-Supervised 76 | Evaluate the model in the fully-supervised setting: 77 | ```bash 78 | ./configs/vaw/eval_fs.sh 79 | ``` 80 | ### VAW Zero-Shot (base2novel) 81 | Evaluate the model in the zero-shot setting: 82 | ```bash 83 | ./configs/vaw/eval_zs.sh 84 | ``` 85 | 86 | ## Acknowledgements 87 | This repository is built upon the following works: 88 | 89 | * [DETR (Facebook Research)](https://github.com/facebookresearch/detr): The codebase we built upon and the foundation for our base model. 90 | 91 | * [LAVIS (Salesforce)](https://github.com/salesforce/LAVIS): Pre-trained Vision-Language Models (BLIP2) that we utilized for feature extraction and knowledge transfer. 92 | 93 | ## Contact 94 | If you have any questions, please create an issue on this repository or contact at shkim129@korea.ac.kr. 95 | 96 | ## Citation 97 | If you find our work interesting, please consider giving a ⭐ and citation. 
98 | ```bibtex 99 | @article{kim2025super, 100 | title={Super-class guided Transformer for Zero-Shot Attribute Classification}, 101 | author={Kim, Sehyung and Yang, Chanhyeong and Park, Jihwan and Song, Taehoon and Kim, Hyunwoo J}, 102 | journal={arXiv preprint arXiv:2501.05728}, 103 | year={2025} 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /assets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/assets/.gitkeep -------------------------------------------------------------------------------- /assets/SugaFormer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/assets/SugaFormer.png -------------------------------------------------------------------------------- /configs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/configs/.gitkeep -------------------------------------------------------------------------------- /configs/vaw/eval_fs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | export GPUS_PER_NODE=4 5 | 6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \ 7 | python -u main.py \ 8 | --dataset_file vaw \ 9 | --mode supervised \ 10 | --eval \ 11 | --zrse \ 12 | --dec_layers 3 \ 13 | --batch_size 4 \ 14 | --pretrained exps/vaw/supervised/checkpoint.pth \ 15 | --output_dir exps/vaw/supervised/ 16 | 17 | -------------------------------------------------------------------------------- /configs/vaw/eval_zs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | export GPUS_PER_NODE=4 5 | 6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \ 7 | python -u main.py \ 8 | --dataset_file vaw \ 9 | --mode zero_shot \ 10 | --eval \ 11 | --zrse \ 12 | --dec_layers 3 \ 13 | --batch_size 4 \ 14 | --pretrained exps/vaw/zero_shot/checkpoint.pth \ 15 | --output_dir exps/vaw/zero_shot/ 16 | -------------------------------------------------------------------------------- /configs/vaw/train_fs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | export GPUS_PER_NODE=4 5 | 6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \ 7 | python -u main.py \ 8 | --dataset_file vaw \ 9 | --mode supervised \ 10 | --scr_coef 2 \ 11 | --att_loss_coef 1 \ 12 | --dec_layers 3 \ 13 | --epochs 15 \ 14 | --lr_drop 13 \ 15 | --batch_size 4 \ 16 | --output_dir exps/vaw/supervised/ -------------------------------------------------------------------------------- /configs/vaw/train_zs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | export GPUS_PER_NODE=4 5 | 6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \ 7 | python -u main.py \ 8 | --dataset_file vaw \ 9 | --mode zero_shot \ 10 | --scr_coef 2 \ 11 | --att_loss_coef 1 \ 12 | --dec_layers 3 \ 13 | --epochs 9 \ 14 | --batch_size 4 \ 15 | --output_dir exps/vaw/zero_shot/ 16 | 17 | 
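(The four launch scripts above all run `main.py` through `tools/run_dist_launch.sh` on 4 GPUs. For quick debugging, a minimal single-GPU sketch is shown below; this is an assumption rather than a script shipped in `configs/`, and it presumes that `init_distributed_mode` falls back to non-distributed mode when the usual RANK/WORLD_SIZE variables are absent, as the bundled `lavis` dist utilities do. It mirrors the arguments of `configs/vaw/train_zs.sh`.)

```bash
# Hypothetical single-GPU debug run (not an official config); mirrors
# configs/vaw/train_zs.sh but skips tools/run_dist_launch.sh.
export CUDA_VISIBLE_DEVICES=0

python -u main.py \
    --dataset_file vaw \
    --mode zero_shot \
    --scr_coef 2 \
    --att_loss_coef 1 \
    --dec_layers 3 \
    --epochs 9 \
    --batch_size 4 \
    --output_dir exps/vaw/zero_shot_debug/
```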
-------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/data/.gitkeep -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vaw import build as build_vaw 2 | def build_dataset(image_set, args): 3 | 4 | if args.dataset_file == 'vaw': 5 | return build_vaw(image_set, args) 6 | 7 | raise ValueError(f'dataset {args.dataset_file} not supported') 8 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as F 2 | import torchvision.transforms as T 3 | import numpy as np 4 | import random 5 | import torch 6 | import PIL 7 | import cv2 8 | 9 | def crop(image, target, region): 10 | cropped_image = F.crop(image, *region) 11 | 12 | target = target.copy() 13 | i, j, h, w = region 14 | 15 | target["size"] = torch.tensor([h, w]) 16 | 17 | fields = ["labels", "area", "iscrowd"] 18 | if "pos_att_classes" in target.keys(): 19 | fields.append("pos_att_classes") 20 | if "neg_att_classes" in target.keys(): 21 | fields.append("neg_att_classes") 22 | if "boxes" in target: 23 | boxes = target["boxes"] 24 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 25 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 26 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 27 | cropped_boxes = cropped_boxes.clamp(min=0) 28 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 29 | target["boxes"] = cropped_boxes.reshape(-1, 4) 30 | target["area"] = area 31 | fields.append("boxes") 32 | 33 | if "masks" in target: 34 | # FIXME should we update the area here if there are no boxes? 
35 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 36 | fields.append("masks") 37 | 38 | # remove elements for which the boxes or masks that have zero area 39 | if "boxes" in target or "masks" in target: 40 | # favor boxes selection when defining which elements to keep 41 | # this is compatible with previous implementation 42 | if "boxes" in target: 43 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 44 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 45 | else: 46 | keep = target['masks'].flatten(1).any(1) 47 | 48 | for field in fields: 49 | target[field] = target[field][keep] 50 | 51 | return cropped_image, target 52 | 53 | 54 | def hflip(image, target): 55 | flipped_image = F.hflip(image) 56 | 57 | w, h = image.size 58 | 59 | target = target.copy() 60 | if "boxes" in target: 61 | boxes = target["boxes"] 62 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 63 | target["boxes"] = boxes 64 | 65 | if "masks" in target: 66 | target['masks'] = target['masks'].flip(-1) 67 | 68 | return flipped_image, target 69 | 70 | 71 | def resize(image, target, size, max_size=None): 72 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 73 | w, h = image_size 74 | if max_size is not None: 75 | min_original_size = float(min((w, h))) 76 | max_original_size = float(max((w, h))) 77 | if max_original_size / min_original_size * size > max_size: 78 | size = int(round(max_size * min_original_size / max_original_size)) 79 | 80 | if (w <= h and w == size) or (h <= w and h == size): 81 | return (h, w) 82 | 83 | if w < h: 84 | ow = size 85 | oh = int(size * h / w) 86 | else: 87 | oh = size 88 | ow = int(size * w / h) 89 | 90 | return (oh, ow) 91 | 92 | def get_size(image_size, size, max_size=None): 93 | if isinstance(size, (list, tuple)): 94 | return size[::-1] 95 | else: 96 | return get_size_with_aspect_ratio(image_size, size, max_size) 97 | 98 | size = get_size(image.size, size, max_size) 99 | rescaled_image = F.resize(image, size) 100 | 101 | if target is None: 102 | return rescaled_image, None 103 | 104 | return rescaled_image, target 105 | 106 | 107 | def pad(image, target, padding): 108 | # assumes that we only pad on the bottom right corners 109 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 110 | if target is None: 111 | return padded_image, None 112 | target = target.copy() 113 | # should we do something wrt the original size? 
114 | target["size"] = torch.tensor(padded_image[::-1]) 115 | if "masks" in target: 116 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 117 | return padded_image, target 118 | 119 | 120 | class RandomCrop(object): 121 | def __init__(self, size): 122 | self.size = size 123 | 124 | def __call__(self, img, target): 125 | region = T.RandomCrop.get_params(img, self.size) 126 | return crop(img, target, region) 127 | 128 | 129 | class RandomSizeCrop(object): 130 | def __init__(self, min_size: int, max_size: int): 131 | self.min_size = min_size 132 | self.max_size = max_size 133 | 134 | def __call__(self, img: PIL.Image.Image, target: dict): 135 | w = random.randint(self.min_size, min(img.width, self.max_size)) 136 | h = random.randint(self.min_size, min(img.height, self.max_size)) 137 | region = T.RandomCrop.get_params(img, [h, w]) 138 | return crop(img, target, region) 139 | 140 | 141 | class CenterCrop(object): 142 | def __init__(self, size): 143 | self.size = size 144 | 145 | def __call__(self, img, target): 146 | image_width, image_height = img.size 147 | crop_height, crop_width = self.size 148 | crop_top = int(round((image_height - crop_height) / 2.)) 149 | crop_left = int(round((image_width - crop_width) / 2.)) 150 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 151 | 152 | 153 | class RandomHorizontalFlip(object): 154 | def __init__(self, p=0.5): 155 | self.p = p 156 | 157 | def __call__(self, img, target): 158 | if random.random() < self.p: 159 | return hflip(img, target) 160 | return img, target 161 | 162 | 163 | class RandomResize(object): 164 | def __init__(self, sizes, max_size=None): 165 | assert isinstance(sizes, (list, tuple)) 166 | self.sizes = sizes 167 | self.max_size = max_size 168 | 169 | def __call__(self, img, target=None): 170 | size = random.choice(self.sizes) 171 | return resize(img, target, size, self.max_size) 172 | 173 | 174 | class RandomPad(object): 175 | def __init__(self, max_pad): 176 | self.max_pad = max_pad 177 | 178 | def __call__(self, img, target): 179 | pad_x = random.randint(0, self.max_pad) 180 | pad_y = random.randint(0, self.max_pad) 181 | return pad(img, target, (pad_x, pad_y)) 182 | 183 | 184 | class RandomSelect(object): 185 | def __init__(self, transforms1, transforms2, p=0.5): 186 | self.transforms1 = transforms1 187 | self.transforms2 = transforms2 188 | self.p = p 189 | 190 | def __call__(self, img, target): 191 | if random.random() < self.p: 192 | return self.transforms1(img, target) 193 | return self.transforms2(img, target) 194 | 195 | 196 | class ToTensor(object): 197 | def __call__(self, img): 198 | return F.to_tensor(img) 199 | 200 | 201 | class RandomErasing(object): 202 | 203 | def __init__(self, *args, **kwargs): 204 | self.eraser = T.RandomErasing(*args, **kwargs) 205 | 206 | def __call__(self, img, target): 207 | return self.eraser(img), target 208 | 209 | 210 | class Normalize(object): 211 | def __init__(self, mean, std): 212 | self.mean = mean 213 | self.std = std 214 | 215 | def __call__(self, image, target=None): 216 | image = F.normalize(image, mean=self.mean, std=self.std) 217 | if target is None: 218 | return image 219 | return image, target 220 | 221 | 222 | class Compose(object): 223 | def __init__(self, transforms): 224 | self.transforms = transforms 225 | 226 | def __call__(self, image, target=None): 227 | for t in self.transforms: 228 | if isinstance(t, T.Normalize): # Apply only to the image 229 | image = t(image) 230 | else: 231 | image, target = 
t(image, target) 232 | return image, target 233 | 234 | 235 | class ColorJitter(object): 236 | def __init__(self, brightness=0, contrast=0, saturatio=0, hue=0): 237 | self.color_jitter = T.ColorJitter(brightness, contrast, saturatio, hue) 238 | 239 | def __call__(self, img, target): 240 | return self.color_jitter(img), target 241 | 242 | 243 | def polygon_to_mask(polygon, height, width): 244 | mask = np.zeros((height, width), dtype=np.uint8) 245 | polygon = np.array(polygon, dtype=np.int32).reshape((-1, 2)) 246 | cv2.fillPoly(mask, [polygon], 1) # Fill the polygon area with 1 247 | return mask 248 | -------------------------------------------------------------------------------- /datasets/vaw.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import datasets.transforms as T 3 | from pathlib import Path 4 | import torch.utils.data 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | import json 9 | 10 | class VAW_Dataloader(torch.utils.data.Dataset): 11 | def __init__(self, img_set, img_folder, anno_file, attribute_index, transforms, args=None): 12 | self.att_ids = list(self.load_json(attribute_index).values()) 13 | self.annotations = self.load_json(anno_file) 14 | self.img_set = img_set 15 | self.img_folder = img_folder 16 | self.transforms = transforms 17 | self.sc_token_path = args.use_scr 18 | 19 | def load_json(self, file_path): 20 | with open(file_path, 'r') as f: 21 | return json.load(f) 22 | 23 | def num_attributes(self): 24 | return len(self.att_ids) 25 | 26 | def __len__(self): 27 | return len(self.annotations) 28 | 29 | def __getitem__(self, idx): 30 | img_anno = self.annotations[idx] 31 | img_id = str(img_anno['image_id']) 32 | object_names = img_anno['object_name'] 33 | file_dir = self.get_file_path(img_anno['file_name']) 34 | img = Image.open(self.img_folder / file_dir).convert('RGB') 35 | boxes, orig_size, cropped_masks, keep = self.process_boxes_and_mask(img_anno, img) 36 | pos_att_classes, neg_att_classes = self.create_attribute_classes(img_anno) 37 | pos_att_classes = pos_att_classes[keep] 38 | neg_att_classes = neg_att_classes[keep] 39 | if self.img_set == 'train': 40 | sc_mask_output = self.get_sc_mask_output(img_id) 41 | sc_mask_output = sc_mask_output[keep] 42 | target = self.create_target_dict(img_anno, boxes, pos_att_classes, neg_att_classes, orig_size, object_names, sc_mask_output) 43 | elif self.img_set == 'test': 44 | target = self.create_target_dict(img_anno, boxes, pos_att_classes, neg_att_classes, orig_size, object_names) 45 | transformed_img, target = self.apply_transforms(img, target) 46 | crop_imgs = self.crop_and_normalize_boxes(img, boxes) 47 | return transformed_img, crop_imgs, cropped_masks, target 48 | 49 | def get_file_path(self, file_name): 50 | return file_name.split('/')[-2] + '/' + file_name.split('/')[-1] 51 | 52 | def process_boxes_and_mask(self, img_anno, img): 53 | boxes = torch.as_tensor(img_anno['boxes'], dtype=torch.float32).reshape(-1, 4) 54 | w, h = img.size 55 | orig_size = torch.as_tensor([h, w]) 56 | boxes[:, 2:] += boxes[:, :2] 57 | boxes[:, 0::2].clamp_(min=0, max=w) 58 | boxes[:, 1::2].clamp_(min=0, max=h) 59 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 60 | boxes = boxes[keep] 61 | mask_path = img_anno['mask'] 62 | masks = torch.from_numpy(np.load(mask_path)) 63 | masks = masks[keep] 64 | cropped_masks = self.crop_masks(masks, boxes, orig_size) 65 | return boxes, orig_size, cropped_masks, keep 66 | 67 | def 
create_attribute_classes(self, img_anno): 68 | pos_att_classes = torch.zeros((len(img_anno['boxes']), self.num_attributes()), dtype=torch.float32) 69 | neg_att_classes = torch.zeros((len(img_anno['boxes']), self.num_attributes()), dtype=torch.float32) 70 | for b, pos_id in zip(pos_att_classes, img_anno['pos_att_id']): 71 | b[pos_id] = 1 72 | for b, neg_id in zip(neg_att_classes, img_anno['neg_att_id']): 73 | b[neg_id] = 1 74 | return pos_att_classes, neg_att_classes 75 | 76 | def get_sc_mask_output(self, img_id): 77 | sc_mask_output = torch.load(self.sc_token_path + img_id + '.pt', map_location='cpu') 78 | return sc_mask_output 79 | 80 | def create_target_dict(self, img_anno, boxes, pos_att_classes, neg_att_classes, orig_size, object_names, sc_mask_output=None): 81 | target = { 82 | 'boxes': boxes, 83 | 'pos_att_classes': pos_att_classes, 84 | 'neg_att_classes': neg_att_classes, 85 | 'img_id': img_anno['image_id'], 86 | 'orig_size': orig_size, 87 | 'obj_names': object_names, 88 | } 89 | if sc_mask_output is not None: 90 | target['sc_token_output'] = sc_mask_output 91 | return target 92 | 93 | def apply_transforms(self, img, target): 94 | if self.transforms: 95 | transformed_img, _ = self.transforms(img, target) 96 | return transformed_img, target 97 | 98 | def crop_and_normalize_boxes(self, img, boxes): 99 | crop_imgs = [] 100 | for box in boxes: 101 | box = box.int() 102 | cropped_img = img.crop((box[0].item(), box[1].item(), box[2].item(), box[3].item())) 103 | cropped_img = T.ToTensor()(cropped_img) 104 | cropped_img, _ = T.resize(cropped_img, None, size=(224, 224)) 105 | cropped_img = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(cropped_img) 106 | crop_imgs.append(cropped_img) 107 | return torch.stack(crop_imgs) 108 | 109 | def crop_masks(self, masks, boxes, orig_size): 110 | resized_masks = [] 111 | for mask, box in zip(masks, boxes): 112 | crop_resized_mask = self.mask_crop_resize(mask.unsqueeze(0), orig_size, box) 113 | resized_masks.append(crop_resized_mask) 114 | return torch.cat(resized_masks) 115 | 116 | def mask_crop_resize(self, mask, orig_size, box, img_size=224): 117 | orig_h, orig_w = orig_size[0].item(), orig_size[1].item() 118 | box[0::2].clamp_(min=0, max=orig_w) 119 | box[1::2].clamp_(min=0, max=orig_h) 120 | cropped_mask = mask[:, :orig_h, :orig_w][:, int(box[1]):int(box[3]), int(box[0]):int(box[2])] 121 | if True in cropped_mask: 122 | resized_mask = F.interpolate(cropped_mask.unsqueeze(0).type(torch.float), size=(img_size, img_size), mode='bicubic', align_corners=False) 123 | resized_mask = (resized_mask > 0.5).float() 124 | else: 125 | resized_mask = torch.ones((1, 1, img_size, img_size)) 126 | return resized_mask 127 | 128 | def make_vaw_transforms(image_set, img_size=224): 129 | 130 | def transform(img, target): 131 | img = T.ToTensor()(img) 132 | img, _ = T.resize(img, None, size=(img_size, img_size)) 133 | img = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img) 134 | return img, target 135 | 136 | if image_set in ['train', 'test']: 137 | return transform 138 | 139 | def build(image_set, args): 140 | root = Path(args.vaw_path) 141 | assert root.exists(), f'Provided VAW path {root} does not exist' 142 | PATHS = { 143 | 'train': (root / 'images', root / 'annotations' / 'train.json'), 144 | 'val': (root / 'images', root / 'annotations' / 'test.json'), 145 | 'test': (root / 'images', root / 'annotations' / 'test.json') 146 | } 147 | 148 | attribute_index = root / 'annotations' / 'attribute_index.json' 149 | 150 | img_folder, anno_file = 
PATHS[image_set] 151 | assert img_folder.exists(), f"Image folder {img_folder} does not exist" 152 | assert anno_file.exists(), f"Annotation file {anno_file} does not exist" 153 | 154 | transforms = make_vaw_transforms(image_set) 155 | dataset = VAW_Dataloader(image_set, img_folder, anno_file, attribute_index, transforms, args) 156 | 157 | return dataset 158 | -------------------------------------------------------------------------------- /datasets/vaw_eval.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import average_precision_score 2 | import numpy as np 3 | import torch 4 | import json 5 | 6 | K = 15 7 | def top_K_values(array): 8 | """Keeps only topK largest values in array. 9 | """ 10 | indexes = np.argpartition(array, -K, axis=-1)[-K:] 11 | A = set(indexes) 12 | B = set(list(range(array.shape[0]))) 13 | B -= A 14 | array[list(B)]=0 15 | return array 16 | 17 | class Evaluator(object): 18 | def __init__( 19 | self, 20 | fpath_attr2idx, 21 | fpath_attr_headtail, 22 | threshold=0.5, 23 | exclude_atts=[] 24 | ): 25 | """Initializes evaluator for attribute prediction on VAW dataset. 26 | Args: 27 | - fpath_attr2idx: path to attribute class index file. 28 | - fpath_attr_headtail: path to attribute head/mid/tail categorization file. 29 | - threshold: positive/negative threshold (for Accuracy metric). 30 | - exclude_atts: any attribute classes to be excluded from evaluation. 31 | """ 32 | 33 | # Read file that maps from id to attribute name. 34 | with open(fpath_attr2idx, 'r') as f: 35 | self.attr2idx = json.load(f) 36 | self.idx2attr = {v:k for k, v in self.attr2idx.items()} 37 | 38 | # Read file that shows whether attribute is head/mid/tail. 39 | with open(fpath_attr_headtail, 'r') as f: 40 | self.attribute_head_tail = json.load(f) 41 | 42 | self.n_class = len(self.attr2idx) 43 | self.exclude_atts = exclude_atts 44 | self.threshold = threshold 45 | 46 | # Cache metric score for each class. 47 | self.score = {} 48 | self.score_topk = {} 49 | 50 | def _clear_cache(self): 51 | self.score = {} 52 | self.score_topk = {} 53 | 54 | def get_attr_head_tail(self, attr): 55 | """Finds whether attribute is in head/medium/tail group. 56 | """ 57 | for group, L in self.attribute_head_tail.items(): 58 | if attr in L: 59 | return group 60 | assert False, f"Can't find head/medium/tail group for {attr}" 61 | 62 | def evaluate( 63 | self, 64 | pred, 65 | gt_label, 66 | threshold_type='threshold' 67 | ): 68 | """Evaluates a prediction matrix against groundtruth label. 69 | Args: 70 | - pred: prediction matrix [n_instance, n_class]. 71 | pred[i,j] is the j-th attribute score of instance i-th. 72 | These scores should be from 0 -> 1. 73 | - gt_label: groundtruth label matrix [n_instances, n_class]. 74 | gt_label[i,j] = 1 if instance i is positively labeled with 75 | attribute j, = 0 if it is negatively labeled, and = 2 if 76 | it is unlabeled. 77 | - threshold_type: 'threshold' or 'topk'. 78 | Determines positive vs. negative prediction. 79 | """ 80 | self.pred = pred 81 | self.gt_label = gt_label 82 | self.n_instance = self.gt_label.shape[0] 83 | 84 | # For topK metrics, we keep a version of the prediction matrix that sets 85 | # non-topK elements as 0 and topK elements as 1. 
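# (Concretely: K = 15 at the top of this file, top_K_values() zeroes every
#  entry of a row except its 15 largest scores, and the "P_topk > 0" step
#  below binarizes those survivors, giving the prediction matrix used when
#  threshold_type == 'topk'.)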
86 | P_topk = self.pred.detach().cpu().numpy().copy() 87 | P_topk = np.apply_along_axis(top_K_values, 1, P_topk) 88 | P_topk[P_topk > 0] = 1 89 | self.pred_topk = P_topk 90 | all_groups = ['all', 'head', 'medium', 'tail'] 91 | groups_overall = { 92 | k: GroupClassMetric(metric_type='overall') 93 | for k in all_groups 94 | } 95 | groups_per_class = { 96 | k: GroupClassMetric(metric_type='per-class') 97 | for k in all_groups 98 | } 99 | class_metric_dict = {} 100 | for i_class in range(self.n_class): 101 | attr = self.idx2attr[i_class] 102 | if attr in self.exclude_atts: 103 | continue 104 | 105 | class_metric = self.get_score_class(i_class, threshold_type=threshold_type) 106 | class_metric_dict[i_class] = class_metric 107 | 108 | # Add to 'all' group. 109 | groups_overall['all'].add_class(class_metric) 110 | groups_per_class['all'].add_class(class_metric) 111 | 112 | # Add to head/medium/tail group. 113 | imbalance_group = self.get_attr_head_tail(attr) 114 | groups_overall[imbalance_group].add_class(class_metric) 115 | groups_per_class[imbalance_group].add_class(class_metric) 116 | 117 | 118 | 119 | # Aggregate final scores. 120 | # For overall, we're interested in F1. 121 | # For per-class, we're interested in mean AP, mean recall, mean balanced accuracy. 122 | scores_overall = {} 123 | for group_name, group in groups_overall.items(): 124 | scores_overall[group_name] = { 125 | 'f1': group.get_f1(), 126 | 'precision': group.get_precision(), 127 | 'recall': group.get_recall(), 128 | 'tnr': group.get_tnr(), 129 | } 130 | scores_per_class = {} 131 | for group_name, group in groups_per_class.items(): 132 | scores_per_class[group_name] = { 133 | 'ap': group.get_ap(), 134 | 'f1': group.get_f1(), 135 | 'precision': group.get_precision(), 136 | 'recall': group.get_recall(), 137 | 'bacc': group.get_bacc() 138 | } 139 | 140 | return scores_per_class 141 | 142 | def get_score_class(self, i_class, threshold_type='threshold'): 143 | """Computes all metrics for a given class. 144 | Args: 145 | - i_class: class index. 146 | - threshold_type: 'topk' or 'threshold'. This determines how a 147 | prediction is positive or negative. 148 | """ 149 | if threshold_type == 'threshold': 150 | score = self.score 151 | else: 152 | score = self.score_topk 153 | if i_class in score: 154 | return score[i_class] 155 | 156 | if threshold_type == 'threshold': 157 | pred = self.pred[:,i_class] 158 | 159 | else: 160 | pred = self.pred_topk[:,i_class] 161 | gt_label = self.gt_label[:,i_class] 162 | 163 | # Find instances that are explicitly labeled (either positive or negative). 164 | mask_labeled = (gt_label < 2) 165 | if mask_labeled.sum() == 0: 166 | # None of the instances have label for this class. 167 | # assert False, f"0 labeled instances for attribute {self.idx2attr[i_class]}" 168 | pass 169 | else: 170 | # Select ony the labeled ones. 171 | pred = pred[mask_labeled] 172 | gt_label = gt_label[mask_labeled] 173 | 174 | if threshold_type == 'threshold': 175 | # Only computes AP when threshold_type is 'threshold'. This is because when 176 | # threshold_type is 'topk', pred is a binary matrix. 177 | ap = average_precision_score(gt_label, pred) 178 | 179 | # Make pred into binary matrix. 180 | pred[pred > self.threshold] = 1 181 | pred[pred <= self.threshold] = 0 182 | 183 | class_metric = SingleClassMetric(pred, gt_label) 184 | if threshold_type == 'threshold': 185 | class_metric.ap = ap 186 | 187 | # Cache results. 
188 | score[i_class] = class_metric 189 | 190 | return class_metric 191 | 192 | 193 | class GroupClassMetric(object): 194 | def __init__(self, metric_type): 195 | """This class computes all metrics for a group of attributes. 196 | Args: 197 | - metric_type: 'overall' or 'per-class'. 198 | """ 199 | self.metric_type = metric_type 200 | 201 | if metric_type == 'overall': 202 | # Keep track of all stats. 203 | self.true_pos = 0 204 | self.false_pos = 0 205 | self.true_neg = 0 206 | self.false_neg = 0 207 | self.n_pos = 0 208 | self.n_neg = 0 209 | else: 210 | self.metric = { 211 | name: [] 212 | for name in ['recall', 'tnr', 'acc', 'bacc', 'precision', 'f1', 'ap'] 213 | } 214 | 215 | def add_class(self, class_metric): 216 | """Adds computed metrics of a class into this group. 217 | """ 218 | if self.metric_type == 'overall': 219 | self.true_pos += class_metric.true_pos 220 | self.false_pos += class_metric.false_pos 221 | self.true_neg += class_metric.true_neg 222 | self.false_neg += class_metric.false_neg 223 | self.n_pos += class_metric.n_pos 224 | self.n_neg += class_metric.n_neg 225 | else: 226 | self.metric['recall'].append(class_metric.get_recall()) 227 | self.metric['tnr'].append(class_metric.get_tnr()) 228 | self.metric['acc'].append(class_metric.get_acc()) 229 | self.metric['bacc'].append(class_metric.get_bacc()) 230 | self.metric['precision'].append(class_metric.get_precision()) 231 | self.metric['f1'].append(class_metric.get_f1()) 232 | self.metric['ap'].append(class_metric.ap) 233 | 234 | def get_recall(self): 235 | """Computes recall. 236 | """ 237 | if self.metric_type == 'overall': 238 | n_pos_pred = self.true_pos + self.false_pos 239 | if n_pos_pred == 0: 240 | # Model makes 0 positive prediction. 241 | # This is a special case: we fall back to precision = 1 and recall = 0. 242 | return 0 243 | 244 | if self.n_pos > 0: 245 | return self.true_pos / self.n_pos 246 | return -1 247 | else: 248 | if -1 not in self.metric['recall']: 249 | return np.mean(self.metric['recall']) 250 | return -1 251 | 252 | def get_tnr(self): 253 | """Computes true negative rate. 254 | """ 255 | if self.metric_type == 'overall': 256 | if self.n_neg > 0: 257 | return self.true_neg / self.n_neg 258 | return -1 259 | else: 260 | if -1 not in self.metric['tnr']: 261 | return np.mean(self.metric['tnr']) 262 | return -1 263 | 264 | def get_acc(self): 265 | """Computes accuracy. 266 | """ 267 | if self.metric_type == 'overall': 268 | if self.n_pos + self.n_neg > 0: 269 | return (self.true_pos + self.true_neg) / (self.n_pos + self.n_neg) 270 | return -1 271 | else: 272 | if -1 not in self.metric['acc']: 273 | return np.mean(self.metric['acc']) 274 | return -1 275 | 276 | def get_bacc(self): 277 | """Computes balanced accuracy. 278 | """ 279 | if self.metric_type == 'overall': 280 | recall = self.get_recall() 281 | tnr = self.get_tnr() 282 | if recall == -1 or tnr == -1: 283 | return -1 284 | return (recall + tnr) / 2.0 285 | else: 286 | if -1 not in self.metric['bacc']: 287 | return np.mean(self.metric['bacc']) 288 | return -1 289 | 290 | def get_precision(self): 291 | """Computes precision. 292 | """ 293 | if self.metric_type == 'overall': 294 | n_pos_pred = self.true_pos + self.false_pos 295 | if n_pos_pred == 0: 296 | # Model makes 0 positive prediction. 297 | # This is a special case: we fall back to precision = 1 and recall = 0. 
298 | return 1 299 | return self.true_pos / n_pos_pred 300 | else: 301 | if -1 not in self.metric['precision']: 302 | return np.mean(self.metric['precision']) 303 | return -1 304 | 305 | def get_f1(self): 306 | """Computes F1. 307 | """ 308 | if self.metric_type == 'overall': 309 | recall = self.get_recall() 310 | precision = self.get_precision() 311 | if precision + recall > 0: 312 | return 2 * precision * recall / (precision + recall) 313 | return 0 314 | else: 315 | if -1 not in self.metric['f1']: 316 | return np.mean(self.metric['f1']) 317 | return -1 318 | 319 | def get_ap(self): 320 | """Computes mAP. 321 | """ 322 | assert self.metric_type == 'per-class' 323 | return np.mean(self.metric['ap']) 324 | 325 | 326 | class SingleClassMetric(object): 327 | def __init__(self, pred, gt_label): 328 | """This class computes all metrics for a single attribute. 329 | Args: 330 | - pred: np.array of shape [n_instance] -> binary prediction. 331 | - gt_label: np.array of shape [n_instance] -> groundtruth binary label. 332 | """ 333 | if pred is None or gt_label is None: 334 | self.true_pos = 0 335 | self.false_pos = 0 336 | self.true_neg = 0 337 | self.false_neg = 0 338 | self.n_pos = 0 339 | self.n_neg = 0 340 | self.ap = -1 341 | return 342 | 343 | self.true_pos = ((gt_label == 1) & (pred == 1)).sum() 344 | self.false_pos = ((gt_label == 0) & (pred == 1)).sum() 345 | self.true_neg = ((gt_label == 0) & (pred == 0)).sum() 346 | self.false_neg = ((gt_label == 1) & (pred == 0)).sum() 347 | 348 | # Number of groundtruth positives & negatives. 349 | self.n_pos = self.true_pos + self.false_neg 350 | self.n_neg = self.false_pos + self.true_neg 351 | 352 | # AP score. 353 | self.ap = -1 354 | 355 | def get_recall(self): 356 | """Computes recall. 357 | """ 358 | n_pos_pred = self.true_pos + self.false_pos 359 | if n_pos_pred == 0: 360 | # Model makes 0 positive prediction. 361 | # This is a special case: we fall back to precision = 1 and recall = 0. 362 | return 0 363 | 364 | if self.n_pos > 0: 365 | return self.true_pos / self.n_pos 366 | return -1 367 | 368 | def get_tnr(self): 369 | """Computes true negative rate. 370 | """ 371 | if self.n_neg > 0: 372 | return self.true_neg / self.n_neg 373 | return -1 374 | 375 | def get_acc(self): 376 | """Computes accuracy. 377 | """ 378 | if self.n_pos + self.n_neg > 0: 379 | return (self.true_pos + self.true_neg) / (self.n_pos + self.n_neg) 380 | return -1 381 | 382 | def get_bacc(self): 383 | """Computes balanced accuracy. 384 | """ 385 | recall = self.get_recall() 386 | tnr = self.get_tnr() 387 | if recall == -1 or tnr == -1: 388 | return -1 389 | return (recall + tnr) / 2.0 390 | 391 | def get_precision(self): 392 | """Computes precision. 393 | """ 394 | n_pos_pred = self.true_pos + self.false_pos 395 | if n_pos_pred == 0: 396 | # Model makes 0 positive prediction. 397 | # This is a special case: we fall back to precision = 1 and recall = 0. 398 | return 1 399 | return self.true_pos / n_pos_pred 400 | 401 | def get_f1(self): 402 | """Computes F1. 403 | """ 404 | recall = self.get_recall() 405 | precision = self.get_precision() 406 | 407 | if precision + recall > 0: 408 | return 2 * precision * recall / (precision + recall) 409 | return 0 410 | 411 | 412 | def preprocess_pos_neg(targets): 413 | """ 414 | Preprocess positive and negative attribute indices to create ground truth labels. 415 | 416 | Args: 417 | targets (list of dict): List of target dictionaries, each containing: 418 | - 'pos_att_classes' (torch.Tensor): Positive attribute indices (binary format). 
419 | - 'neg_att_classes' (torch.Tensor): Negative attribute indices (binary format). 420 | 421 | Returns: 422 | list: List of ground truth labels where: 423 | 1 indicates positive, 424 | 0 indicates negative, 425 | 2 indicates neutral/unlabeled. 426 | """ 427 | pos_atts = torch.cat([target['pos_att_classes'] for target in targets], dim=0) 428 | neg_atts = torch.cat([target['neg_att_classes'] for target in targets], dim=0) 429 | 430 | gts = [] 431 | for pos_att, neg_att in zip(pos_atts, neg_atts): 432 | gt = 2 * np.ones(pos_att.unsqueeze(0).shape, dtype=np.int64) 433 | gt[:, pos_att.cpu() == 1] = np.int64(1) 434 | gt[:, neg_att.cpu() == 1] = np.int64(0) 435 | gts.extend(gt) 436 | 437 | return gts -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | from datasets.vaw_eval import Evaluator, preprocess_pos_neg 2 | from torch.cuda.amp import autocast 3 | from typing import Iterable 4 | import util.misc as utils 5 | import numpy as np 6 | import torch 7 | import math 8 | import json 9 | import sys 10 | 11 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0, args=None): 12 | model.train() 13 | criterion.train() 14 | metric_logger = utils.MetricLogger(delimiter=" ") 15 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 16 | header = 'Epoch: [{}]'.format(epoch) 17 | print_freq = 10 18 | for samples, crop_samples, mask_samples, targets in metric_logger.log_every(data_loader, print_freq, header): 19 | samples = samples.tensors.to(device) 20 | crop_samples = [crop_sample.to(device) for crop_sample in crop_samples] 21 | crop_masks = [crop_mask.to(device) for crop_mask in mask_samples] 22 | targets = [{k: v.to(device) if k != 'obj_names' and type(v) != int else v for k, v in t.items()} for t in targets] 23 | inputs = [ 24 | { 25 | "samples": samples, 26 | "crop_samples": crop_samples, 27 | "crop_masks": crop_masks 28 | } 29 | ] 30 | with autocast(): 31 | outputs = model(inputs, targets, args) 32 | 33 | loss_dict = criterion(outputs, targets) 34 | weight_dict = criterion.weight_dict 35 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 36 | loss_dict_reduced = utils.reduce_dict(loss_dict) 37 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} 38 | loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} 39 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 40 | loss_value = losses_reduced_scaled.item() 41 | 42 | if not math.isfinite(loss_value): 43 | print("Loss is {}, stopping training".format(loss_value)) 44 | print(loss_dict_reduced) 45 | sys.exit(1) 46 | 47 | optimizer.zero_grad() 48 | losses.backward() 49 | if max_norm > 0: 50 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 51 | optimizer.step() 52 | 53 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 54 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 55 | 56 | metric_logger.synchronize_between_processes() 57 | print("Averaged stats:", metric_logger) 58 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 59 | 60 | @torch.no_grad() 61 | def evaluate_att(model, data_loader, device, args=None): 62 | metric_logger = 
utils.MetricLogger(delimiter=" ") 63 | print_freq = 10 64 | model.eval() 65 | header = 'Test:' 66 | preds = [] 67 | gts = [] 68 | for samples, crop_samples, crop_masks, targets in metric_logger.log_every(data_loader, print_freq, header): 69 | samples = samples.tensors.to(device) 70 | crop_samples = [crop_sample.to(device) for crop_sample in crop_samples] 71 | crop_masks = [crop_mask.to(device) for crop_mask in crop_masks] 72 | inputs = [{"samples": samples, "crop_samples": crop_samples, "crop_masks": crop_masks}] 73 | targets = [{k: v.to(device) if k != 'obj_names' and type(v) != int else v for k, v in t.items()} for t in targets] 74 | with autocast(): 75 | outputs = model(inputs, targets, args)['pred_logits'] 76 | outputs = outputs.sigmoid() 77 | preds.extend(outputs.detach().cpu().numpy()) 78 | gt = preprocess_pos_neg(targets) 79 | gts.extend(gt) 80 | 81 | metric_logger.synchronize_between_processes() 82 | preds = torch.cat(utils.all_gather(torch.from_numpy(np.array(preds)))) 83 | annos = torch.cat(utils.all_gather(torch.from_numpy(np.array(gts)))) 84 | evaluator = Evaluator(args.fpath_attribute_index, args.fpath_head_tail) 85 | scores_per_class = evaluator.evaluate(preds, annos) 86 | CATEGORIES = ['all', 'head', 'medium', 'tail'] 87 | stats = {f'mAP_{category}': scores_per_class[category]['ap'] for category in CATEGORIES} 88 | if args.mode == 'zero_shot': 89 | stats.update(compute_zero_shot_mAP(evaluator, args)) 90 | 91 | return stats 92 | 93 | def compute_zero_shot_mAP(evaluator, args): 94 | base_novel_dict = json.load(open(args.base_novel_dict, 'r')) 95 | base_class = [v for _, v in base_novel_dict['base'].items()] 96 | novel_class = [v for _, v in base_novel_dict['novel'].items()] 97 | base_mAP = sum(evaluator.get_score_class(i_class).ap for i_class in base_class) / len(base_class) 98 | novel_mAP = sum(evaluator.get_score_class(i_class).ap for i_class in novel_class) / len(novel_class) 99 | return {'mAP_base': base_mAP, 'mAP_novel': novel_mAP} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, DistributedSampler 2 | from engine import evaluate_att, train_one_epoch 3 | from datasets import build_dataset 4 | from models import build_model 5 | import util.misc as utils 6 | from pathlib import Path 7 | import numpy as np 8 | import argparse 9 | import datetime 10 | import random 11 | import json 12 | import time 13 | import torch 14 | import os 15 | 16 | def get_args_parser(): 17 | parser = argparse.ArgumentParser('Set SugaFormer', add_help=False) 18 | parser.add_argument('--lr', default=1e-4, type=float) 19 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 20 | parser.add_argument('--batch_size', default=2, type=int) 21 | parser.add_argument('--weight_decay', default=1e-4, type=float) 22 | parser.add_argument('--epochs', default=10, type=int) 23 | parser.add_argument('--lr_drop', default=8, type=int) 24 | parser.add_argument('--clip_max_norm', default=0.1, type=float, help="gradient clipping max norm") 25 | parser.add_argument('--pretrained', type=str, default='') 26 | parser.add_argument('--gamma_neg', default=4, type=int, help="gamma_neg for Assymloss") 27 | parser.add_argument('--clip', default=0.05, type=float) 28 | parser.add_argument('--position_embedding', default='learned', type=str, choices=('sine', 'learned'), 29 | help="Type of positional embedding to use on top of the image features") 30 | 
parser.add_argument('--freeze_backbone', action='store_true', help="freezing backbone flag") 31 | 32 | parser.add_argument('--dataset_file', default='', type=str) 33 | parser.add_argument('--vaw_path', default='data/vaw', type=str) 34 | parser.add_argument('--output_dir', default='', help="path where to save, empty for no saving") 35 | 36 | parser.add_argument('--device', default='cuda', help="device to use for training / testing") 37 | parser.add_argument('--seed', default=42, type=int) 38 | parser.add_argument('--resume', default='', help="resume from checkpoint") 39 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help="start epoch") 40 | parser.add_argument('--eval', action='store_true') 41 | parser.add_argument('--num_workers', default=4, type=int) 42 | parser.add_argument('--world_size', default=1, type=int, help="number of distributed processes") 43 | parser.add_argument('--dist_url', default='env://', help="url used to set up distributed training") 44 | 45 | 46 | parser.add_argument('--enc_layers', default=0, type=int, help="Number of encoding layers in the transformer") 47 | parser.add_argument('--dec_layers', default=3, type=int, help="Number of decoding layers in the transformer") 48 | parser.add_argument('--dim_feedforward', default=2048, type=int, help="Intermediate size of the feedforward layers in the transformer blocks") 49 | parser.add_argument('--hidden_dim', default=256, type=int, help="Size of the embeddings (dimension of the transformer)") 50 | parser.add_argument('--image_hidden_dim', default=1408, type=int, help="Visual Backbone output dimension") 51 | parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer") 52 | parser.add_argument('--nheads', default=8, type=int, help="Number of attention heads inside the transformer's attentions") 53 | parser.add_argument('--pre_norm', action='store_true') 54 | 55 | parser.add_argument('--num_obj_classes', default=2260, type=int, help="Number of object classes") 56 | parser.add_argument('--num_att_classes', default=620, type=int, help="Number of attribute classes") 57 | parser.add_argument('--att_loss_coef', default=1, type=float, help="attribute loss coefficient") 58 | parser.add_argument('--scr_coef', default=2, type=float, help="scr loss coefficient") 59 | 60 | parser.add_argument('--fpath_attribute_index', type=str, default='data/vaw/annotations/attribute_index.json') 61 | parser.add_argument('--fpath_head_tail', type=str, default='data/vaw/annotations/head_tail.json') 62 | parser.add_argument('--sc_feats', default='data/vaw/annotations/sc_embedding.pt', type=str, help="super-class text embedding") 63 | parser.add_argument('--att_feats', default='data/vaw/annotations/att_embedding.pt', type=str, help="attribute text embedding") 64 | parser.add_argument('--hierarchy', default='data/vaw/annotations/hierarchy.json', type=str, help="super-class to attribute hierarchy") 65 | parser.add_argument('--att_class_weight', default='data/vaw/annotations/att_class_weight.pt', type=str, help="weight for attribute classes") 66 | parser.add_argument('--base_novel_dict', default="data/vaw/annotations/base_novel_dict.json", type=str, help="base, novel index in attribute classes") 67 | 68 | parser.add_argument('--mode', default='', type=str, help="", choices=('zero_shot', 'supervised')) 69 | parser.add_argument('--use_scr', default='data/vaw/annotations/scr_tokens/', type=str, help="[MASK] feature for scr loss") 70 | parser.add_argument('--zrse', action='store_true', help="use 
zero-shot retrieval-based score enhancement") 71 | parser.add_argument('--zrse_scale', default=2, type=float) 72 | parser.add_argument('--ztopk',default=2, type=int) 73 | 74 | return parser 75 | 76 | 77 | 78 | def main(args): 79 | utils.init_distributed_mode(args) 80 | device = torch.device(args.device) 81 | seed = args.seed + utils.get_rank() 82 | torch.manual_seed(seed) 83 | np.random.seed(seed) 84 | random.seed(seed) 85 | torch.backends.cudnn.benchmark = False 86 | os.environ["PYTHONHASHSEED"] = str(seed) 87 | 88 | model, criterion = build_model(args) 89 | model.to(device) 90 | model_without_ddp = model 91 | 92 | if args.distributed: 93 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 94 | model_without_ddp = model.module 95 | 96 | param_dicts = [ 97 | {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]}, 98 | { 99 | "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad], 100 | "lr": args.lr_backbone, 101 | }, 102 | ] 103 | 104 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 105 | weight_decay=args.weight_decay) 106 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 107 | 108 | dataset_train = build_dataset(image_set='train', args=args) 109 | dataset_test = build_dataset(image_set='test', args=args) 110 | 111 | if args.distributed: 112 | sampler_train = DistributedSampler(dataset_train) 113 | sampler_test = DistributedSampler(dataset_test, shuffle=False) 114 | 115 | else: 116 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 117 | sampler_test = torch.utils.data.SequentialSampler(dataset_test) 118 | 119 | batch_sampler_train = torch.utils.data.BatchSampler( 120 | sampler_train, args.batch_size, drop_last=True) 121 | 122 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 123 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 124 | data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=sampler_test, 125 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) 126 | 127 | output_dir = Path(args.output_dir) 128 | if args.resume: 129 | if args.resume.startswith("https"): 130 | checkpoint = torch.hub.load_state_dict_from_url( 131 | args.resume, map_location='cpu', check_hash=True) 132 | else: 133 | checkpoint = torch.load(args.resume, map_location=("cpu")) 134 | 135 | model_without_ddp.load_state_dict(checkpoint['model']) 136 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 137 | optimizer.load_state_dict(checkpoint['optimizer']) 138 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 139 | args.start_epoch = checkpoint['epoch'] + 1 140 | 141 | if args.pretrained: 142 | if args.eval: 143 | checkpoint = torch.load(args.pretrained, map_location="cpu") 144 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 145 | 146 | test_stats = evaluate_att(model, data_loader_test, device, args) 147 | log_stats = {**{f'test_{k}': v for k, v in test_stats.items()}} 148 | 149 | if args.output_dir and utils.is_main_process(): 150 | with (output_dir / "log.txt").open("a") as f: 151 | f.write(json.dumps(log_stats) + "\n") 152 | return 153 | 154 | print("Start training") 155 | start_time = time.time() 156 | for epoch in range(args.start_epoch, args.epochs): 157 | 158 | if args.distributed: 159 | sampler_train.set_epoch(epoch) 160 | 161 | 
train_one_epoch( 162 | model, criterion, data_loader_train, optimizer, device, epoch, 163 | args.clip_max_norm, args) 164 | 165 | lr_scheduler.step() 166 | if args.output_dir: 167 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 168 | if (epoch + 1) % 5 == 0: 169 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') 170 | for checkpoint_path in checkpoint_paths: 171 | utils.save_on_master({ 172 | 'model': model_without_ddp.state_dict(), 173 | 'optimizer': optimizer.state_dict(), 174 | 'lr_scheduler': lr_scheduler.state_dict(), 175 | 'epoch': epoch, 176 | 'args': args, 177 | }, checkpoint_path) 178 | 179 | total_time = time.time() - start_time 180 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 181 | print('Training time {}'.format(total_time_str)) 182 | 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser('SugaFormer training and evaluation script', parents=[get_args_parser()]) 186 | args = parser.parse_args() 187 | if args.output_dir: 188 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 189 | main(args) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .sugaformer import build 3 | 4 | def build_model(args): 5 | return build(args) 6 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | from util.misc import NestedTensor, nested_tensor_from_tensor_list 2 | from .lavis.models import load_model_and_preprocess 3 | from .position_encoding import build_position_encoding 4 | from typing import List 5 | from torch import nn 6 | import torch 7 | import math 8 | 9 | class Joiner(nn.Sequential): 10 | def __init__(self, backbone, position_embedding): 11 | super().__init__(backbone, position_embedding) 12 | 13 | def forward(self, tensor_list: NestedTensor): 14 | xs_list = self[0](tensor_list.tensors) 15 | out: List[NestedTensor] = [] 16 | pos = [] 17 | for xs in xs_list: 18 | xs = xs[:,1:,:] 19 | B, D, C = xs.shape 20 | D_s = int(math.sqrt(D)) 21 | x = xs.permute(0,2,1).view(B,C,D_s,D_s) 22 | x = nested_tensor_from_tensor_list([x_s for x_s in x]) 23 | out.append(x) 24 | pos.append(self[1](x).to(x.tensors.dtype)) 25 | return out, pos 26 | 27 | def build_blip_backbone(args): 28 | position_embedding = build_position_encoding(args) 29 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 30 | model = load_model_and_preprocess(name="blip2", model_type="pretrain", device=device) 31 | visual_encoder = model.visual_encoder 32 | visual_encoder = Joiner(visual_encoder, position_embedding) 33 | return visual_encoder, model 34 | -------------------------------------------------------------------------------- /models/lavis/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from .models import BaseModel, Blip2Base, Blip2Qformer 9 | from .common.registry import registry 10 | from .processors import BaseProcessor 11 | from omegaconf import OmegaConf 12 | import sys 13 | import os 14 | 15 | # Exported symbols for the package 16 | __all__ = ["registry", "BaseModel", "Blip2Base", "Blip2Qformer", "BaseProcessor"] 17 | 18 | root_dir = os.path.dirname(os.path.abspath(__file__)) 19 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 20 | 21 | registry.register_path("library_root", root_dir) 22 | repo_root = os.path.join(root_dir, "..") 23 | registry.register_path("repo_root", repo_root) 24 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 25 | registry.register_path("cache_root", cache_root) 26 | 27 | registry.register("MAX_INT", sys.maxsize) 28 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 29 | -------------------------------------------------------------------------------- /models/lavis/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | 
torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | @functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /models/lavis/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /models/lavis/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from lavis.common import dist_utils 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 
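    A minimal usage sketch (editor's illustration, not part of the upstream docstring):

        >>> sv = SmoothedValue(window_size=3)
        >>> for v in (1.0, 2.0, 3.0, 4.0):
        ...     sv.update(v)
        >>> sv.median       # median over the 3 most recent values
        3.0
        >>> sv.global_avg   # average over all 4 updates
        2.5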
22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 40 | """ 41 | if not dist_utils.is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def global_avg(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) 113 | return self.delimiter.join(loss_str) 114 | 115 | def synchronize_between_processes(self): 116 | for meter in self.meters.values(): 117 | meter.synchronize_between_processes() 118 | 119 | def add_meter(self, name, meter): 120 | self.meters[name] = meter 121 | 122 | def log_every(self, iterable, print_freq, header=None): 123 | i = 0 124 | if not header: 125 | header = "" 126 | start_time = time.time() 127 | end = time.time() 128 | iter_time = SmoothedValue(fmt="{avg:.4f}") 129 | data_time = SmoothedValue(fmt="{avg:.4f}") 130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 131 | log_msg = [ 132 | header, 133 | "[{0" + space_fmt + "}/{1}]", 134 | "eta: {eta}", 135 | "{meters}", 136 | "time: {time}", 137 | "data: {data}", 138 | ] 139 | if torch.cuda.is_available(): 140 | log_msg.append("max mem: {memory:.0f}") 141 | log_msg = self.delimiter.join(log_msg) 142 | MB = 1024.0 * 1024.0 143 | for obj in iterable: 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | if i % print_freq == 0 or i == len(iterable) - 1: 148 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 149 | eta_string = 
str(datetime.timedelta(seconds=int(eta_seconds))) 150 | if torch.cuda.is_available(): 151 | print( 152 | log_msg.format( 153 | i, 154 | len(iterable), 155 | eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), 158 | data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB, 160 | ) 161 | ) 162 | else: 163 | print( 164 | log_msg.format( 165 | i, 166 | len(iterable), 167 | eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), 170 | data=str(data_time), 171 | ) 172 | ) 173 | i += 1 174 | end = time.time() 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print( 178 | "{} Total time: {} ({:.4f} s / it)".format( 179 | header, total_time_str, total_time / len(iterable) 180 | ) 181 | ) 182 | 183 | 184 | class AttrDict(dict): 185 | def __init__(self, *args, **kwargs): 186 | super(AttrDict, self).__init__(*args, **kwargs) 187 | self.__dict__ = self 188 | 189 | 190 | def setup_logger(): 191 | logging.basicConfig( 192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 193 | format="%(asctime)s [%(levelname)s] %(message)s", 194 | handlers=[logging.StreamHandler()], 195 | ) 196 | -------------------------------------------------------------------------------- /models/lavis/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from lavis.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | optimizer, 61 | max_epoch, 62 | min_lr, 63 | init_lr, 64 | warmup_steps=0, 65 | warmup_start_lr=-1, 66 | **kwargs 67 | ): 68 | self.optimizer = optimizer 69 | 70 | self.max_epoch = max_epoch 71 | self.min_lr = min_lr 72 | 73 | self.init_lr = init_lr 74 | self.warmup_steps = warmup_steps 75 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 76 | 77 | def step(self, cur_epoch, cur_step): 78 | # assuming the warmup iters less than one epoch 79 | if cur_epoch == 0: 80 | warmup_lr_schedule( 81 | step=cur_step, 82 | optimizer=self.optimizer, 83 | max_step=self.warmup_steps, 84 | init_lr=self.warmup_start_lr, 85 | max_lr=self.init_lr, 86 | ) 87 | 
else: 88 | cosine_lr_schedule( 89 | epoch=cur_epoch, 90 | optimizer=self.optimizer, 91 | max_epoch=self.max_epoch, 92 | init_lr=self.init_lr, 93 | min_lr=self.min_lr, 94 | ) 95 | 96 | 97 | @registry.register_lr_scheduler("constant_lr") 98 | class ConstantLRScheduler: 99 | def __init__(self, optimizer, init_lr, warmup_start_lr=-1, warmup_steps=0, **kwargs): 100 | self.optimizer = optimizer 101 | self.lr = init_lr 102 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 103 | self.warmup_steps = warmup_steps 104 | 105 | def step(self, cur_epoch, cur_step): 106 | if cur_epoch == 0: 107 | warmup_lr_schedule( 108 | step=cur_step, 109 | optimizer=self.optimizer, 110 | max_step=self.warmup_steps, 111 | init_lr=self.warmup_start_lr, 112 | max_lr=self.lr, 113 | ) 114 | else: 115 | for param_group in self.optimizer.param_groups: 116 | param_group["lr"] = self.lr 117 | 118 | 119 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 120 | """Decay the learning rate""" 121 | lr = (init_lr - min_lr) * 0.5 * ( 122 | 1.0 + math.cos(math.pi * epoch / max_epoch) 123 | ) + min_lr 124 | for param_group in optimizer.param_groups: 125 | param_group["lr"] = lr 126 | 127 | 128 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 129 | """Warmup the learning rate""" 130 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 131 | for param_group in optimizer.param_groups: 132 | param_group["lr"] = lr 133 | 134 | 135 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 136 | """Decay the learning rate""" 137 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 138 | for param_group in optimizer.param_groups: 139 | param_group["lr"] = lr 140 | -------------------------------------------------------------------------------- /models/lavis/common/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | # @classmethod 22 | # def register_builder(cls, name): 23 | # r"""Register a dataset builder to registry with key 'name' 24 | 25 | # Args: 26 | # name: Key with which the builder will be registered. 
27 | 28 | # Usage: 29 | 30 | # from lavis.common.registry import registry 31 | # from lavis.datasets.base_dataset_builder import BaseDatasetBuilder 32 | # """ 33 | 34 | # def wrap(builder_cls): 35 | # from LAVIS.lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder 36 | 37 | # assert issubclass( 38 | # builder_cls, BaseDatasetBuilder 39 | # ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 40 | # builder_cls 41 | # ) 42 | # if name in cls.mapping["builder_name_mapping"]: 43 | # raise KeyError( 44 | # "Name '{}' already registered for {}.".format( 45 | # name, cls.mapping["builder_name_mapping"][name] 46 | # ) 47 | # ) 48 | # cls.mapping["builder_name_mapping"][name] = builder_cls 49 | # return builder_cls 50 | 51 | # return wrap 52 | 53 | @classmethod 54 | def register_task(cls, name): 55 | r"""Register a task to registry with key 'name' 56 | 57 | Args: 58 | name: Key with which the task will be registered. 59 | 60 | Usage: 61 | 62 | from lavis.common.registry import registry 63 | """ 64 | 65 | def wrap(task_cls): 66 | from models.lavis.tasks.base_task import BaseTask 67 | 68 | assert issubclass( 69 | task_cls, BaseTask 70 | ), "All tasks must inherit BaseTask class" 71 | if name in cls.mapping["task_name_mapping"]: 72 | raise KeyError( 73 | "Name '{}' already registered for {}.".format( 74 | name, cls.mapping["task_name_mapping"][name] 75 | ) 76 | ) 77 | cls.mapping["task_name_mapping"][name] = task_cls 78 | return task_cls 79 | 80 | return wrap 81 | 82 | @classmethod 83 | def register_model(cls, name): 84 | r"""Register a task to registry with key 'name' 85 | 86 | Args: 87 | name: Key with which the task will be registered. 88 | 89 | Usage: 90 | 91 | from lavis.common.registry import registry 92 | """ 93 | 94 | def wrap(model_cls): 95 | from models.lavis.models import BaseModel 96 | assert issubclass( 97 | model_cls, BaseModel 98 | ), "All models must inherit BaseModel class" 99 | if name in cls.mapping["model_name_mapping"]: 100 | raise KeyError( 101 | "Name '{}' already registered for {}.".format( 102 | name, cls.mapping["model_name_mapping"][name] 103 | ) 104 | ) 105 | cls.mapping["model_name_mapping"][name] = model_cls 106 | return model_cls 107 | 108 | return wrap 109 | 110 | @classmethod 111 | def register_processor(cls, name): 112 | r"""Register a processor to registry with key 'name' 113 | 114 | Args: 115 | name: Key with which the task will be registered. 116 | 117 | Usage: 118 | 119 | from lavis.common.registry import registry 120 | """ 121 | 122 | def wrap(processor_cls): 123 | from models.lavis.processors import BaseProcessor 124 | assert issubclass( 125 | processor_cls, BaseProcessor 126 | ), "All processors must inherit BaseProcessor class" 127 | if name in cls.mapping["processor_name_mapping"]: 128 | raise KeyError( 129 | "Name '{}' already registered for {}.".format( 130 | name, cls.mapping["processor_name_mapping"][name] 131 | ) 132 | ) 133 | cls.mapping["processor_name_mapping"][name] = processor_cls 134 | return processor_cls 135 | 136 | return wrap 137 | 138 | @classmethod 139 | def register_lr_scheduler(cls, name): 140 | r"""Register a model to registry with key 'name' 141 | 142 | Args: 143 | name: Key with which the task will be registered. 
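        Example (editor's sketch; "my_scheduler" is a hypothetical key, not one
        registered in this repository):

            @registry.register_lr_scheduler("my_scheduler")
            class MyScheduler:
                def __init__(self, optimizer, **kwargs):
                    self.optimizer = optimizer

                def step(self, cur_epoch, cur_step):
                    pass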
144 | 145 | Usage: 146 | 147 | from lavis.common.registry import registry 148 | """ 149 | 150 | def wrap(lr_sched_cls): 151 | if name in cls.mapping["lr_scheduler_name_mapping"]: 152 | raise KeyError( 153 | "Name '{}' already registered for {}.".format( 154 | name, cls.mapping["lr_scheduler_name_mapping"][name] 155 | ) 156 | ) 157 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls 158 | return lr_sched_cls 159 | 160 | return wrap 161 | 162 | @classmethod 163 | def register_runner(cls, name): 164 | r"""Register a model to registry with key 'name' 165 | 166 | Args: 167 | name: Key with which the task will be registered. 168 | 169 | Usage: 170 | 171 | from lavis.common.registry import registry 172 | """ 173 | 174 | def wrap(runner_cls): 175 | if name in cls.mapping["runner_name_mapping"]: 176 | raise KeyError( 177 | "Name '{}' already registered for {}.".format( 178 | name, cls.mapping["runner_name_mapping"][name] 179 | ) 180 | ) 181 | cls.mapping["runner_name_mapping"][name] = runner_cls 182 | return runner_cls 183 | 184 | return wrap 185 | 186 | @classmethod 187 | def register_path(cls, name, path): 188 | r"""Register a path to registry with key 'name' 189 | 190 | Args: 191 | name: Key with which the path will be registered. 192 | 193 | Usage: 194 | 195 | from lavis.common.registry import registry 196 | """ 197 | assert isinstance(path, str), "All path must be str." 198 | if name in cls.mapping["paths"]: 199 | raise KeyError("Name '{}' already registered.".format(name)) 200 | cls.mapping["paths"][name] = path 201 | 202 | @classmethod 203 | def register(cls, name, obj): 204 | r"""Register an item to registry with key 'name' 205 | 206 | Args: 207 | name: Key with which the item will be registered. 208 | 209 | Usage:: 210 | 211 | from lavis.common.registry import registry 212 | 213 | registry.register("config", {}) 214 | """ 215 | path = name.split(".") 216 | current = cls.mapping["state"] 217 | 218 | for part in path[:-1]: 219 | if part not in current: 220 | current[part] = {} 221 | current = current[part] 222 | 223 | current[path[-1]] = obj 224 | 225 | # @classmethod 226 | # def get_trainer_class(cls, name): 227 | # return cls.mapping["trainer_name_mapping"].get(name, None) 228 | 229 | @classmethod 230 | def get_builder_class(cls, name): 231 | return cls.mapping["builder_name_mapping"].get(name, None) 232 | 233 | @classmethod 234 | def get_model_class(cls, name): 235 | return cls.mapping["model_name_mapping"].get(name, None) 236 | 237 | @classmethod 238 | def get_task_class(cls, name): 239 | return cls.mapping["task_name_mapping"].get(name, None) 240 | 241 | @classmethod 242 | def get_processor_class(cls, name): 243 | return cls.mapping["processor_name_mapping"].get(name, None) 244 | 245 | @classmethod 246 | def get_lr_scheduler_class(cls, name): 247 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 248 | 249 | @classmethod 250 | def get_runner_class(cls, name): 251 | return cls.mapping["runner_name_mapping"].get(name, None) 252 | 253 | @classmethod 254 | def list_runners(cls): 255 | return sorted(cls.mapping["runner_name_mapping"].keys()) 256 | 257 | @classmethod 258 | def list_models(cls): 259 | return sorted(cls.mapping["model_name_mapping"].keys()) 260 | 261 | @classmethod 262 | def list_tasks(cls): 263 | return sorted(cls.mapping["task_name_mapping"].keys()) 264 | 265 | @classmethod 266 | def list_processors(cls): 267 | return sorted(cls.mapping["processor_name_mapping"].keys()) 268 | 269 | @classmethod 270 | def list_lr_schedulers(cls): 271 | return 
sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 272 | 273 | @classmethod 274 | def list_datasets(cls): 275 | return sorted(cls.mapping["builder_name_mapping"].keys()) 276 | 277 | @classmethod 278 | def get_path(cls, name): 279 | return cls.mapping["paths"].get(name, None) 280 | 281 | @classmethod 282 | def get(cls, name, default=None, no_warning=False): 283 | r"""Get an item from registry with key 'name' 284 | 285 | Args: 286 | name (string): Key whose value needs to be retrieved. 287 | default: If passed and key is not in registry, default value will 288 | be returned with a warning. Default: None 289 | no_warning (bool): If passed as True, warning when key doesn't exist 290 | will not be generated. Useful for MMF's 291 | internal operations. Default: False 292 | """ 293 | original_name = name 294 | name = name.split(".") 295 | value = cls.mapping["state"] 296 | for subname in name: 297 | value = value.get(subname, default) 298 | if value is default: 299 | break 300 | 301 | if ( 302 | "writer" in cls.mapping["state"] 303 | and value == default 304 | and no_warning is False 305 | ): 306 | cls.mapping["state"]["writer"].warning( 307 | "Key {} is not present in registry, returning default value " 308 | "of {}".format(original_name, default) 309 | ) 310 | return value 311 | 312 | @classmethod 313 | def unregister(cls, name): 314 | r"""Remove an item from registry with key 'name' 315 | 316 | Args: 317 | name: Key which needs to be removed. 318 | Usage:: 319 | 320 | from mmf.common.registry import registry 321 | 322 | config = registry.unregister("config") 323 | """ 324 | return cls.mapping["state"].pop(name, None) 325 | 326 | 327 | registry = Registry() 328 | -------------------------------------------------------------------------------- /models/lavis/common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import io 9 | import json 10 | import logging 11 | import os 12 | import pickle 13 | import re 14 | import shutil 15 | import tarfile 16 | import urllib 17 | import urllib.error 18 | import urllib.request 19 | from typing import Optional 20 | from urllib.parse import urlparse 21 | 22 | import numpy as np 23 | import pandas as pd 24 | import yaml 25 | from iopath.common.download import download 26 | from iopath.common.file_io import file_lock, g_pathmgr 27 | from lavis.common.dist_utils import download_cached_file 28 | from lavis.common.registry import registry 29 | from torch.utils.model_zoo import tqdm 30 | from torchvision.datasets.utils import ( 31 | check_integrity, 32 | download_file_from_google_drive, 33 | extract_archive, 34 | ) 35 | 36 | 37 | def now(): 38 | from datetime import datetime 39 | 40 | return datetime.now().strftime("%Y%m%d%H%M")[:-1] 41 | 42 | 43 | def is_url(url_or_filename): 44 | parsed = urlparse(url_or_filename) 45 | return parsed.scheme in ("http", "https") 46 | 47 | 48 | def get_cache_path(rel_path): 49 | return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) 50 | 51 | 52 | def get_abs_path(rel_path): 53 | return os.path.join(registry.get_path("library_root"), rel_path) 54 | 55 | 56 | def load_json(filename): 57 | with open(filename, "r") as f: 58 | return json.load(f) 59 | 60 | 61 | # The following are adapted from torchvision and vissl 62 | # torchvision: https://github.com/pytorch/vision 63 | # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py 64 | 65 | 66 | def makedir(dir_path): 67 | """ 68 | Create the directory if it does not exist. 
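    Returns True on success, False otherwise. Example (editor's sketch):
    makedir("data/vaw/annotations")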
69 | """ 70 | is_success = False 71 | try: 72 | if not g_pathmgr.exists(dir_path): 73 | g_pathmgr.mkdirs(dir_path) 74 | is_success = True 75 | except BaseException: 76 | print(f"Error creating directory: {dir_path}") 77 | return is_success 78 | 79 | 80 | def get_redirected_url(url: str): 81 | """ 82 | Given a URL, returns the URL it redirects to or the 83 | original URL in case of no indirection 84 | """ 85 | import requests 86 | 87 | with requests.Session() as session: 88 | with session.get(url, stream=True, allow_redirects=True) as response: 89 | if response.history: 90 | return response.url 91 | else: 92 | return url 93 | 94 | 95 | def to_google_drive_download_url(view_url: str) -> str: 96 | """ 97 | Utility function to transform a view URL of google drive 98 | to a download URL for google drive 99 | Example input: 100 | https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view 101 | Example output: 102 | https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp 103 | """ 104 | splits = view_url.split("/") 105 | assert splits[-1] == "view" 106 | file_id = splits[-2] 107 | return f"https://drive.google.com/uc?export=download&id={file_id}" 108 | 109 | 110 | def download_google_drive_url(url: str, output_path: str, output_file_name: str): 111 | """ 112 | Download a file from google drive 113 | Downloading an URL from google drive requires confirmation when 114 | the file of the size is too big (google drive notifies that 115 | anti-viral checks cannot be performed on such files) 116 | """ 117 | import requests 118 | 119 | with requests.Session() as session: 120 | 121 | # First get the confirmation token and append it to the URL 122 | with session.get(url, stream=True, allow_redirects=True) as response: 123 | for k, v in response.cookies.items(): 124 | if k.startswith("download_warning"): 125 | url = url + "&confirm=" + v 126 | 127 | # Then download the content of the file 128 | with session.get(url, stream=True, verify=True) as response: 129 | makedir(output_path) 130 | path = os.path.join(output_path, output_file_name) 131 | total_size = int(response.headers.get("Content-length", 0)) 132 | with open(path, "wb") as file: 133 | from tqdm import tqdm 134 | 135 | with tqdm(total=total_size) as progress_bar: 136 | for block in response.iter_content( 137 | chunk_size=io.DEFAULT_BUFFER_SIZE 138 | ): 139 | file.write(block) 140 | progress_bar.update(len(block)) 141 | 142 | 143 | def _get_google_drive_file_id(url: str) -> Optional[str]: 144 | parts = urlparse(url) 145 | 146 | if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: 147 | return None 148 | 149 | match = re.match(r"/file/d/(?P[^/]*)", parts.path) 150 | if match is None: 151 | return None 152 | 153 | return match.group("id") 154 | 155 | 156 | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: 157 | with open(filename, "wb") as fh: 158 | with urllib.request.urlopen( 159 | urllib.request.Request(url, headers={"User-Agent": "vissl"}) 160 | ) as response: 161 | with tqdm(total=response.length) as pbar: 162 | for chunk in iter(lambda: response.read(chunk_size), ""): 163 | if not chunk: 164 | break 165 | pbar.update(chunk_size) 166 | fh.write(chunk) 167 | 168 | 169 | def download_url( 170 | url: str, 171 | root: str, 172 | filename: Optional[str] = None, 173 | md5: Optional[str] = None, 174 | ) -> None: 175 | """Download a file from a url and place it in root. 
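    A hedged usage sketch (the URL below is a placeholder, not an asset of this repo):

        download_url(
            "https://example.com/annotations.zip",
            root="data/vaw/annotations",
        )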
176 | Args: 177 | url (str): URL to download file from 178 | root (str): Directory to place downloaded file in 179 | filename (str, optional): Name to save the file under. 180 | If None, use the basename of the URL. 181 | md5 (str, optional): MD5 checksum of the download. If None, do not check 182 | """ 183 | root = os.path.expanduser(root) 184 | if not filename: 185 | filename = os.path.basename(url) 186 | fpath = os.path.join(root, filename) 187 | 188 | makedir(root) 189 | 190 | # check if file is already present locally 191 | if check_integrity(fpath, md5): 192 | print("Using downloaded and verified file: " + fpath) 193 | return 194 | 195 | # expand redirect chain if needed 196 | url = get_redirected_url(url) 197 | 198 | # check if file is located on Google Drive 199 | file_id = _get_google_drive_file_id(url) 200 | if file_id is not None: 201 | return download_file_from_google_drive(file_id, root, filename, md5) 202 | 203 | # download the file 204 | try: 205 | print("Downloading " + url + " to " + fpath) 206 | _urlretrieve(url, fpath) 207 | except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] 208 | if url[:5] == "https": 209 | url = url.replace("https:", "http:") 210 | print( 211 | "Failed download. Trying https -> http instead." 212 | " Downloading " + url + " to " + fpath 213 | ) 214 | _urlretrieve(url, fpath) 215 | else: 216 | raise e 217 | 218 | # check integrity of downloaded file 219 | if not check_integrity(fpath, md5): 220 | raise RuntimeError("File not found or corrupted.") 221 | 222 | 223 | def download_and_extract_archive( 224 | url: str, 225 | download_root: str, 226 | extract_root: Optional[str] = None, 227 | filename: Optional[str] = None, 228 | md5: Optional[str] = None, 229 | remove_finished: bool = False, 230 | ) -> None: 231 | download_root = os.path.expanduser(download_root) 232 | if extract_root is None: 233 | extract_root = download_root 234 | if not filename: 235 | filename = os.path.basename(url) 236 | 237 | download_url(url, download_root, filename, md5) 238 | 239 | archive = os.path.join(download_root, filename) 240 | print("Extracting {} to {}".format(archive, extract_root)) 241 | extract_archive(archive, extract_root, remove_finished) 242 | 243 | 244 | def cache_url(url: str, cache_dir: str) -> str: 245 | """ 246 | This implementation downloads the remote resource and caches it locally. 247 | The resource will only be downloaded if not previously requested. 248 | """ 249 | parsed_url = urlparse(url) 250 | dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) 251 | makedir(dirname) 252 | filename = url.split("/")[-1] 253 | cached = os.path.join(dirname, filename) 254 | with file_lock(cached): 255 | if not os.path.isfile(cached): 256 | logging.info(f"Downloading {url} to {cached} ...") 257 | cached = download(url, dirname, filename=filename) 258 | logging.info(f"URL {url} cached in {cached}") 259 | return cached 260 | 261 | 262 | # TODO (prigoyal): convert this into RAII-style API 263 | def create_file_symlink(file1, file2): 264 | """ 265 | Simply create the symlinks for a given file1 to file2. 266 | Useful during model checkpointing to symlinks to the 267 | latest successful checkpoint. 268 | """ 269 | try: 270 | if g_pathmgr.exists(file2): 271 | g_pathmgr.rm(file2) 272 | g_pathmgr.symlink(file1, file2) 273 | except Exception as e: 274 | logging.info(f"Could NOT create symlink. 
Error: {e}") 275 | 276 | 277 | def save_file(data, filename, append_to_json=True, verbose=True): 278 | """ 279 | Common i/o utility to handle saving data to various file formats. 280 | Supported: 281 | .pkl, .pickle, .npy, .json 282 | Specifically for .json, users have the option to either append (default) 283 | or rewrite by passing in Boolean value to append_to_json. 284 | """ 285 | if verbose: 286 | logging.info(f"Saving data to file: {filename}") 287 | file_ext = os.path.splitext(filename)[1] 288 | if file_ext in [".pkl", ".pickle"]: 289 | with g_pathmgr.open(filename, "wb") as fopen: 290 | pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) 291 | elif file_ext == ".npy": 292 | with g_pathmgr.open(filename, "wb") as fopen: 293 | np.save(fopen, data) 294 | elif file_ext == ".json": 295 | if append_to_json: 296 | with g_pathmgr.open(filename, "a") as fopen: 297 | fopen.write(json.dumps(data, sort_keys=True) + "\n") 298 | fopen.flush() 299 | else: 300 | with g_pathmgr.open(filename, "w") as fopen: 301 | fopen.write(json.dumps(data, sort_keys=True) + "\n") 302 | fopen.flush() 303 | elif file_ext == ".yaml": 304 | with g_pathmgr.open(filename, "w") as fopen: 305 | dump = yaml.dump(data) 306 | fopen.write(dump) 307 | fopen.flush() 308 | else: 309 | raise Exception(f"Saving {file_ext} is not supported yet") 310 | 311 | if verbose: 312 | logging.info(f"Saved data to file: {filename}") 313 | 314 | 315 | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): 316 | """ 317 | Common i/o utility to handle loading data from various file formats. 318 | Supported: 319 | .pkl, .pickle, .npy, .json 320 | For the npy files, we support reading the files in mmap_mode. 321 | If the mmap_mode of reading is not successful, we load data without the 322 | mmap_mode. 323 | """ 324 | if verbose: 325 | logging.info(f"Loading data from file: {filename}") 326 | 327 | file_ext = os.path.splitext(filename)[1] 328 | if file_ext == ".txt": 329 | with g_pathmgr.open(filename, "r") as fopen: 330 | data = fopen.readlines() 331 | elif file_ext in [".pkl", ".pickle"]: 332 | with g_pathmgr.open(filename, "rb") as fopen: 333 | data = pickle.load(fopen, encoding="latin1") 334 | elif file_ext == ".npy": 335 | if mmap_mode: 336 | try: 337 | with g_pathmgr.open(filename, "rb") as fopen: 338 | data = np.load( 339 | fopen, 340 | allow_pickle=allow_pickle, 341 | encoding="latin1", 342 | mmap_mode=mmap_mode, 343 | ) 344 | except ValueError as e: 345 | logging.info( 346 | f"Could not mmap {filename}: {e}. Trying without g_pathmgr" 347 | ) 348 | data = np.load( 349 | filename, 350 | allow_pickle=allow_pickle, 351 | encoding="latin1", 352 | mmap_mode=mmap_mode, 353 | ) 354 | logging.info("Successfully loaded without g_pathmgr") 355 | except Exception: 356 | logging.info("Could not mmap without g_pathmgr. 
Trying without mmap") 357 | with g_pathmgr.open(filename, "rb") as fopen: 358 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") 359 | else: 360 | with g_pathmgr.open(filename, "rb") as fopen: 361 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") 362 | elif file_ext == ".json": 363 | with g_pathmgr.open(filename, "r") as fopen: 364 | data = json.load(fopen) 365 | elif file_ext == ".yaml": 366 | with g_pathmgr.open(filename, "r") as fopen: 367 | data = yaml.load(fopen, Loader=yaml.FullLoader) 368 | elif file_ext == ".csv": 369 | with g_pathmgr.open(filename, "r") as fopen: 370 | data = pd.read_csv(fopen) 371 | else: 372 | raise Exception(f"Reading from {file_ext} is not supported yet") 373 | return data 374 | 375 | 376 | def abspath(resource_path: str): 377 | """ 378 | Make a path absolute, but take into account prefixes like 379 | "http://" or "manifold://" 380 | """ 381 | regex = re.compile(r"^\w+://") 382 | if regex.match(resource_path) is None: 383 | return os.path.abspath(resource_path) 384 | else: 385 | return resource_path 386 | 387 | 388 | def makedir(dir_path): 389 | """ 390 | Create the directory if it does not exist. 391 | """ 392 | is_success = False 393 | try: 394 | if not g_pathmgr.exists(dir_path): 395 | g_pathmgr.mkdirs(dir_path) 396 | is_success = True 397 | except BaseException: 398 | logging.info(f"Error creating directory: {dir_path}") 399 | return is_success 400 | 401 | 402 | def is_url(input_url): 403 | """ 404 | Check if an input string is a url. look for http(s):// and ignoring the case 405 | """ 406 | is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None 407 | return is_url 408 | 409 | 410 | def download_and_untar(url): 411 | cached_file = download_cached_file( 412 | url, check_hash=False, progress=True 413 | ) 414 | # get path to untarred directory 415 | untarred_dir = os.path.basename(url).split(".")[0] 416 | parent_dir = os.path.dirname(cached_file) 417 | 418 | full_dir = os.path.join(parent_dir, untarred_dir) 419 | 420 | if not os.path.exists(full_dir): 421 | with tarfile.open(cached_file) as tar: 422 | tar.extractall(parent_dir) 423 | 424 | return full_dir 425 | 426 | def cleanup_dir(dir): 427 | """ 428 | Utility for deleting a directory. Useful for cleaning the storage space 429 | that contains various training artifacts like checkpoints, data etc. 430 | """ 431 | if os.path.exists(dir): 432 | logging.info(f"Deleting directory: {dir}") 433 | shutil.rmtree(dir) 434 | logging.info(f"Deleted contents of directory: {dir}") 435 | 436 | 437 | def get_file_size(filename): 438 | """ 439 | Given a file, get the size of file in MB 440 | """ 441 | size_in_mb = os.path.getsize(filename) / float(1024**2) 442 | return size_in_mb 443 | 444 | def is_serializable(value): 445 | """ 446 | This function checks if the provided value can be serialized into a JSON string. 447 | """ 448 | try: 449 | json.dumps(value) 450 | return True 451 | except (TypeError, OverflowError): 452 | return False 453 | 454 | def is_convertible_to_int(value): 455 | return bool(re.match(r'^-?\d+$', str(value))) -------------------------------------------------------------------------------- /models/lavis/configs/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | env: 7 | # For default users 8 | # cache_root: "cache" 9 | # For internal use with persistent storage 10 | cache_root: "/export/home/.cache/lavis" -------------------------------------------------------------------------------- /models/lavis/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from ..common.registry import registry 9 | from .base_model import BaseModel 10 | from .blip2_models.blip2 import Blip2Base 11 | from .blip2_models.blip2_qformer import Blip2Qformer 12 | from ..processors.base_processor import BaseProcessor 13 | from omegaconf import OmegaConf 14 | import logging 15 | import torch 16 | 17 | __all__ = ["registry", "BaseModel", "Blip2Base", "Blip2Qformer", "BaseProcessor"] 18 | 19 | 20 | # __all__ = [ 21 | # "load_model", 22 | # "AlbefClassification", 23 | # "AlbefFeatureExtractor", 24 | # "AlbefNLVR", 25 | # "AlbefVQA", 26 | # "AlbefPretrain", 27 | # "AlbefRetrieval", 28 | # "AlproQA", 29 | # "AlproRetrieval", 30 | # "BaseModel", 31 | # "BlipBase", 32 | # "BlipFeatureExtractor", 33 | # "BlipCaption", 34 | # "BlipClassification", 35 | # "BlipITM", 36 | # "BlipNLVR", 37 | # "BlipPretrain", 38 | # "BlipRetrieval", 39 | # "BlipVQA", 40 | # "Blip2Qformer", 41 | # "Blip2Base", 42 | # "Blip2ITM", 43 | # "Blip2OPT", 44 | # "Blip2T5", 45 | # "PNPVQA", 46 | # "Img2PromptVQA", 47 | # "PNPUnifiedQAv2FiD", 48 | # "CLIP", 49 | # "VisionTransformerEncoder", 50 | # "XBertLMHeadDecoder", 51 | # "GPTDialogue", 52 | # ] 53 | 54 | 55 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 56 | """ 57 | Load supported models. 58 | 59 | To list all available models and types in registry: 60 | >>> from lavis.models import model_zoo 61 | >>> print(model_zoo) 62 | 63 | Args: 64 | name (str): name of the model. 65 | model_type (str): type of the model. 66 | is_eval (bool): whether the model is in eval mode. Default: False. 67 | device (str): device to use. Default: "cpu". 68 | checkpoint (str): path or to checkpoint. Default: None. 69 | Note that expecting the checkpoint to have the same keys in state_dict as the model. 70 | 71 | Returns: 72 | model (torch.nn.Module): model. 73 | """ 74 | 75 | model = registry.get_model_class(name).from_pretrained(model_type=model_type, force_download=True) 76 | 77 | if checkpoint is not None: 78 | model.load_checkpoint(checkpoint) 79 | 80 | if is_eval: 81 | model.eval() 82 | 83 | if device == "cpu": 84 | model = model.float() 85 | 86 | return model.to(device) 87 | 88 | 89 | def load_preprocess(config): 90 | """ 91 | Load preprocessor configs and construct preprocessors. 92 | 93 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. 94 | 95 | Args: 96 | config (dict): preprocessor configs. 97 | 98 | Returns: 99 | vis_processors (dict): preprocessors for visual inputs. 100 | txt_processors (dict): preprocessors for text inputs. 101 | 102 | Key is "train" or "eval" for processors used in training and evaluation respectively. 
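    Example (editor's sketch; assumes `cfg` was loaded from a model config YAML
    that contains a `preprocess` section, as done in load_model_and_preprocess below,
    and that `raw_image` is a PIL image):

        >>> vis_processors, txt_processors = load_preprocess(cfg.preprocess)
        >>> image = vis_processors["eval"](raw_image)
        >>> caption = txt_processors["eval"]("a photo of a dog")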
103 | """ 104 | 105 | def _build_proc_from_cfg(cfg): 106 | return ( 107 | registry.get_processor_class(cfg.name).from_config(cfg) 108 | if cfg is not None 109 | else BaseProcessor() 110 | ) 111 | 112 | vis_processors = dict() 113 | txt_processors = dict() 114 | 115 | vis_proc_cfg = config.get("vis_processor") 116 | txt_proc_cfg = config.get("text_processor") 117 | 118 | if vis_proc_cfg is not None: 119 | vis_train_cfg = vis_proc_cfg.get("train") 120 | vis_eval_cfg = vis_proc_cfg.get("eval") 121 | else: 122 | vis_train_cfg = None 123 | vis_eval_cfg = None 124 | 125 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) 126 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) 127 | 128 | if txt_proc_cfg is not None: 129 | txt_train_cfg = txt_proc_cfg.get("train") 130 | txt_eval_cfg = txt_proc_cfg.get("eval") 131 | else: 132 | txt_train_cfg = None 133 | txt_eval_cfg = None 134 | 135 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) 136 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) 137 | 138 | return vis_processors, txt_processors 139 | 140 | 141 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): 142 | """ 143 | Load model and its related preprocessors. 144 | 145 | List all available models and types in registry: 146 | >>> from lavis.models import model_zoo 147 | >>> print(model_zoo) 148 | 149 | Args: 150 | name (str): name of the model. 151 | model_type (str): type of the model. 152 | is_eval (bool): whether the model is in eval mode. Default: False. 153 | device (str): device to use. Default: "cpu". 154 | 155 | Returns: 156 | model (torch.nn.Module): model. 157 | vis_processors (dict): preprocessors for visual inputs. 158 | txt_processors (dict): preprocessors for text inputs. 159 | """ 160 | model_cls = registry.get_model_class(name) 161 | 162 | # load model 163 | model = model_cls.from_pretrained(model_type=model_type) 164 | 165 | if is_eval: 166 | model.eval() 167 | 168 | # load preprocess 169 | cfg = OmegaConf.load(model_cls.default_config_path(model_type)) 170 | if cfg is not None: 171 | preprocess_cfg = cfg.preprocess 172 | vis_processors, txt_processors = load_preprocess(preprocess_cfg) 173 | else: 174 | vis_processors, txt_processors = None, None 175 | logging.info( 176 | f"""No default preprocess for model {name} ({model_type}). 177 | This can happen if the model is not finetuned on downstream datasets, 178 | or it is not intended for direct use without finetuning. 179 | """ 180 | ) 181 | 182 | if device == "cpu" or device == torch.device("cpu"): 183 | model = model.float() 184 | 185 | return model.to(device) 186 | 187 | 188 | class ModelZoo: 189 | """ 190 | A utility class to create string representation of available model architectures and types. 
191 | 192 | >>> from lavis.models import model_zoo 193 | >>> # list all available models 194 | >>> print(model_zoo) 195 | >>> # show total number of models 196 | >>> print(len(model_zoo)) 197 | """ 198 | 199 | def __init__(self) -> None: 200 | self.model_zoo = { 201 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 202 | for k, v in registry.mapping["model_name_mapping"].items() 203 | } 204 | 205 | def __str__(self) -> str: 206 | return ( 207 | "=" * 50 208 | + "\n" 209 | + f"{'Architectures':<30} {'Types'}\n" 210 | + "=" * 50 211 | + "\n" 212 | + "\n".join( 213 | [ 214 | f"{name:<30} {', '.join(types)}" 215 | for name, types in self.model_zoo.items() 216 | ] 217 | ) 218 | ) 219 | 220 | def __iter__(self): 221 | return iter(self.model_zoo.items()) 222 | 223 | def __len__(self): 224 | return sum([len(v) for v in self.model_zoo.values()]) 225 | 226 | 227 | model_zoo = ModelZoo() 228 | -------------------------------------------------------------------------------- /models/lavis/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from ..common.dist_utils import download_cached_file, is_dist_avail_and_initialized 9 | from ..common.utils import get_abs_path, is_url 10 | from omegaconf import OmegaConf 11 | import torch.nn as nn 12 | import numpy as np 13 | import logging 14 | import torch 15 | import os 16 | 17 | 18 | class BaseModel(nn.Module): 19 | """Base class for models.""" 20 | 21 | def __init__(self): 22 | super().__init__() 23 | 24 | @property 25 | def device(self): 26 | return list(self.parameters())[0].device 27 | 28 | def load_checkpoint(self, url_or_filename): 29 | """ 30 | Load from a finetuned checkpoint. 31 | 32 | This should expect no mismatch in the model keys and the checkpoint keys. 33 | """ 34 | 35 | if is_url(url_or_filename): 36 | cached_file = download_cached_file( 37 | url_or_filename, check_hash=False, progress=True 38 | ) 39 | checkpoint = torch.load(cached_file, map_location="cpu") 40 | elif os.path.isfile(url_or_filename): 41 | checkpoint = torch.load(url_or_filename, map_location="cpu") 42 | else: 43 | raise RuntimeError("checkpoint url or path is invalid") 44 | 45 | if "model" in checkpoint.keys(): 46 | state_dict = checkpoint["model"] 47 | else: 48 | state_dict = checkpoint 49 | 50 | msg = self.load_state_dict(state_dict, strict=False) 51 | 52 | logging.info("Missing keys {}".format(msg.missing_keys)) 53 | logging.info("load checkpoint from %s" % url_or_filename) 54 | 55 | return msg 56 | 57 | @classmethod 58 | def from_pretrained(cls, model_type): 59 | """ 60 | Build a pretrained model from default configuration file, specified by model_type. 61 | 62 | Args: 63 | - model_type (str): model type, specifying architecture and checkpoints. 64 | 65 | Returns: 66 | - model (nn.Module): pretrained or finetuned model, depending on the configuration. 
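        Example (editor's sketch; "pretrain" must be a key of the subclass's
        PRETRAINED_MODEL_CONFIG_DICT):

            >>> model = Blip2Qformer.from_pretrained(model_type="pretrain")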
67 | """ 68 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model 69 | model = cls.from_config(model_cfg) 70 | 71 | return model 72 | 73 | @classmethod 74 | def default_config_path(cls, model_type): 75 | assert ( 76 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT 77 | ), "Unknown model type {}".format(model_type) 78 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) 79 | 80 | def load_checkpoint_from_config(self, cfg, **kwargs): 81 | """ 82 | Load checkpoint as specified in the config file. 83 | 84 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. 85 | When loading the pretrained model, each task-specific architecture may define their 86 | own load_from_pretrained() method. 87 | """ 88 | load_finetuned = cfg.get("load_finetuned", True) 89 | if load_finetuned: 90 | finetune_path = cfg.get("finetuned", None) 91 | assert ( 92 | finetune_path is not None 93 | ), "Found load_finetuned is True, but finetune_path is None." 94 | self.load_checkpoint(url_or_filename=finetune_path) 95 | else: 96 | load_pretrained = cfg.get("load_pretrained", True) 97 | if load_pretrained: 98 | # load pre-trained weights 99 | pretrain_path = cfg.get("pretrained", None) 100 | assert "Found load_finetuned is False, but pretrain_path is None." 101 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) 102 | 103 | def before_training(self, **kwargs): 104 | pass 105 | 106 | def get_optimizer_params(self, weight_decay, lr_scale=1): 107 | p_wd, p_non_wd = [], [] 108 | for n, p in self.named_parameters(): 109 | if not p.requires_grad: 110 | continue # frozen weights 111 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: 112 | p_non_wd.append(p) 113 | else: 114 | p_wd.append(p) 115 | optim_params = [ 116 | {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale}, 117 | {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale}, 118 | ] 119 | return optim_params 120 | 121 | def before_evaluation(self, **kwargs): 122 | pass 123 | 124 | def show_n_params(self, return_str=True): 125 | tot = 0 126 | for p in self.parameters(): 127 | w = 1 128 | for x in p.shape: 129 | w *= x 130 | tot += w 131 | if return_str: 132 | if tot >= 1e6: 133 | return "{:.1f}M".format(tot / 1e6) 134 | else: 135 | return "{:.1f}K".format(tot / 1e3) 136 | else: 137 | return tot 138 | 139 | 140 | class BaseEncoder(nn.Module): 141 | """ 142 | Base class for primitive encoders, such as ViT, TimeSformer, etc. 
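    Subclasses implement forward_features(); a minimal sketch (editor's
    illustration only, `self.backbone` is hypothetical):

        class MyEncoder(BaseEncoder):
            def forward_features(self, samples, **kwargs):
                return self.backbone(samples["image"])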
143 | """ 144 | 145 | def __init__(self): 146 | super().__init__() 147 | 148 | def forward_features(self, samples, **kwargs): 149 | raise NotImplementedError 150 | 151 | @property 152 | def device(self): 153 | return list(self.parameters())[0].device 154 | 155 | 156 | class SharedQueueMixin: 157 | @torch.no_grad() 158 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): 159 | # gather keys before updating queue 160 | image_feats = concat_all_gather(image_feat) 161 | text_feats = concat_all_gather(text_feat) 162 | 163 | batch_size = image_feats.shape[0] 164 | 165 | ptr = int(self.queue_ptr) 166 | assert self.queue_size % batch_size == 0 # for simplicity 167 | 168 | # replace the keys at ptr (dequeue and enqueue) 169 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T 170 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T 171 | 172 | if idxs is not None: 173 | idxs = concat_all_gather(idxs) 174 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T 175 | 176 | ptr = (ptr + batch_size) % self.queue_size # move pointer 177 | self.queue_ptr[0] = ptr 178 | 179 | 180 | class MomentumDistilationMixin: 181 | @torch.no_grad() 182 | def copy_params(self): 183 | for model_pair in self.model_pairs: 184 | for param, param_m in zip( 185 | model_pair[0].parameters(), model_pair[1].parameters() 186 | ): 187 | param_m.data.copy_(param.data) # initialize 188 | param_m.requires_grad = False # not update by gradient 189 | 190 | @torch.no_grad() 191 | def _momentum_update(self): 192 | for model_pair in self.model_pairs: 193 | for param, param_m in zip( 194 | model_pair[0].parameters(), model_pair[1].parameters() 195 | ): 196 | param_m.data = param_m.data * self.momentum + param.data * ( 197 | 1.0 - self.momentum 198 | ) 199 | 200 | 201 | class GatherLayer(torch.autograd.Function): 202 | """ 203 | Gather tensors from all workers with support for backward propagation: 204 | This implementation does not cut the gradients as torch.distributed.all_gather does. 205 | """ 206 | 207 | @staticmethod 208 | def forward(ctx, x): 209 | output = [ 210 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 211 | ] 212 | torch.distributed.all_gather(output, x) 213 | return tuple(output) 214 | 215 | @staticmethod 216 | def backward(ctx, *grads): 217 | all_gradients = torch.stack(grads) 218 | torch.distributed.all_reduce(all_gradients) 219 | return all_gradients[torch.distributed.get_rank()] 220 | 221 | 222 | def all_gather_with_grad(tensors): 223 | """ 224 | Performs all_gather operation on the provided tensors. 225 | Graph remains connected for backward grad computation. 226 | """ 227 | # Queue the gathered tensors 228 | world_size = torch.distributed.get_world_size() 229 | # There is no need for reduction in the single-proc case 230 | if world_size == 1: 231 | return tensors 232 | 233 | # tensor_all = GatherLayer.apply(tensors) 234 | tensor_all = GatherLayer.apply(tensors) 235 | 236 | return torch.cat(tensor_all, dim=0) 237 | 238 | 239 | @torch.no_grad() 240 | def concat_all_gather(tensor): 241 | """ 242 | Performs all_gather operation on the provided tensors. 243 | *** Warning ***: torch.distributed.all_gather has no gradient. 
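    If gradients must flow through the gathered tensors, use all_gather_with_grad()
    defined above; this helper is meant for no-grad bookkeeping such as feature queues.
    Example (editor's sketch):

        >>> feats = concat_all_gather(image_feat)  # (world_size * batch, dim), no gradient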
244 | """ 245 | # if use distributed training 246 | if not is_dist_avail_and_initialized(): 247 | return tensor 248 | 249 | tensors_gather = [ 250 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 251 | ] 252 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 253 | 254 | output = torch.cat(tensors_gather, dim=0) 255 | return output 256 | 257 | 258 | def tile(x, dim, n_tile): 259 | init_dim = x.size(dim) 260 | repeat_idx = [1] * x.dim() 261 | repeat_idx[dim] = n_tile 262 | x = x.repeat(*(repeat_idx)) 263 | order_index = torch.LongTensor( 264 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 265 | ) 266 | return torch.index_select(x, dim, order_index.to(x.device)) 267 | -------------------------------------------------------------------------------- /models/lavis/models/blip2_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/models/lavis/models/blip2_models/__init__.py -------------------------------------------------------------------------------- /models/lavis/models/blip2_models/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | from ..blip2_models.Qformer import BertConfig, BertLMHeadModel 8 | from ...common.dist_utils import download_cached_file 9 | import models.lavis.common.dist_utils as dist_utils 10 | from ...models.clip_vit import create_clip_vit_L 11 | from ...common.logger import MetricLogger 12 | from ..eva_vit import create_eva_vit_g 13 | from transformers import BertTokenizer 14 | from ..base_model import BaseModel 15 | from ...common.utils import is_url 16 | import torch.distributed as dist 17 | import torch.nn.functional as F 18 | import torch.nn as nn 19 | import contextlib 20 | import datetime 21 | import logging 22 | import torch 23 | import time 24 | import os 25 | 26 | class Blip2Base(BaseModel): 27 | @classmethod 28 | def init_tokenizer(cls): 29 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 30 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 31 | return tokenizer 32 | 33 | def maybe_autocast(self, dtype=torch.float16): 34 | # if on cpu, don't use autocast 35 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 36 | enable_autocast = self.device != torch.device("cpu") 37 | 38 | if enable_autocast: 39 | return torch.cuda.amp.autocast(dtype=dtype) 40 | else: 41 | return contextlib.nullcontext() 42 | 43 | @classmethod 44 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): 45 | encoder_config = BertConfig.from_pretrained("bert-base-uncased") 46 | encoder_config.encoder_width = vision_width 47 | # insert cross-attention layer every other block 48 | encoder_config.add_cross_attention = True 49 | encoder_config.cross_attention_freq = cross_attention_freq 50 | encoder_config.query_length = num_query_token 51 | Qformer = BertLMHeadModel.from_pretrained( 52 | "bert-base-uncased", config=encoder_config 53 | ) 54 | query_tokens = nn.Parameter( 55 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 56 | ) 57 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 58 | return 
Qformer, query_tokens 59 | 60 | @classmethod 61 | def init_vision_encoder( 62 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision 63 | ): 64 | assert model_name in [ 65 | "eva_clip_g", 66 | "clip_L", 67 | ], "vit model must be eva_clip_g or clip_L" 68 | if model_name == "eva_clip_g": 69 | visual_encoder = create_eva_vit_g( 70 | img_size, drop_path_rate, use_grad_checkpoint, precision 71 | ) 72 | elif model_name == "clip_L": 73 | visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision) 74 | ln_vision = LayerNorm(visual_encoder.num_features) 75 | return visual_encoder, ln_vision 76 | 77 | def load_from_pretrained(self, url_or_filename): 78 | if is_url(url_or_filename): 79 | cached_file = download_cached_file( 80 | url_or_filename, check_hash=False, progress=True 81 | ) 82 | checkpoint = torch.load(cached_file, map_location="cpu") 83 | elif os.path.isfile(url_or_filename): 84 | checkpoint = torch.load(url_or_filename, map_location="cpu") 85 | else: 86 | raise RuntimeError("checkpoint url or path is invalid") 87 | 88 | state_dict = checkpoint["model"] 89 | 90 | msg = self.load_state_dict(state_dict, strict=False) 91 | 92 | # logging.info("Missing keys {}".format(msg.missing_keys)) 93 | logging.info("load checkpoint from %s" % url_or_filename) 94 | 95 | return msg 96 | 97 | 98 | def disabled_train(self, mode=True): 99 | """Overwrite model.train with this function to make sure train/eval mode 100 | does not change anymore.""" 101 | return self 102 | 103 | 104 | class LayerNorm(nn.LayerNorm): 105 | """Subclass torch's LayerNorm to handle fp16.""" 106 | 107 | def forward(self, x: torch.Tensor): 108 | orig_type = x.dtype 109 | ret = super().forward(x.type(torch.float32)) 110 | return ret.type(orig_type) 111 | 112 | 113 | def compute_sim_matrix(model, data_loader, **kwargs): 114 | k_test = kwargs.pop("k_test") 115 | 116 | metric_logger = MetricLogger(delimiter=" ") 117 | header = "Evaluation:" 118 | 119 | logging.info("Computing features for evaluation...") 120 | start_time = time.time() 121 | 122 | texts = data_loader.dataset.text 123 | num_text = len(texts) 124 | text_bs = 256 125 | text_ids = [] 126 | text_embeds = [] 127 | text_atts = [] 128 | for i in range(0, num_text, text_bs): 129 | text = texts[i : min(num_text, i + text_bs)] 130 | text_input = model.tokenizer( 131 | text, 132 | padding="max_length", 133 | truncation=True, 134 | max_length=35, 135 | return_tensors="pt", 136 | ).to(model.device) 137 | text_feat = model.forward_text(text_input) 138 | text_embed = F.normalize(model.text_proj(text_feat)) 139 | text_embeds.append(text_embed) 140 | text_ids.append(text_input.input_ids) 141 | text_atts.append(text_input.attention_mask) 142 | 143 | text_embeds = torch.cat(text_embeds, dim=0) 144 | text_ids = torch.cat(text_ids, dim=0) 145 | text_atts = torch.cat(text_atts, dim=0) 146 | 147 | vit_feats = [] 148 | image_embeds = [] 149 | for samples in data_loader: 150 | image = samples["image"] 151 | 152 | image = image.to(model.device) 153 | image_feat, vit_feat = model.forward_image(image) 154 | image_embed = model.vision_proj(image_feat) 155 | image_embed = F.normalize(image_embed, dim=-1) 156 | 157 | vit_feats.append(vit_feat.cpu()) 158 | image_embeds.append(image_embed) 159 | 160 | vit_feats = torch.cat(vit_feats, dim=0) 161 | image_embeds = torch.cat(image_embeds, dim=0) 162 | 163 | sims_matrix = [] 164 | for image_embed in image_embeds: 165 | sim_q2t = image_embed @ text_embeds.t() 166 | sim_i2t, _ = sim_q2t.max(0) 167 | sims_matrix.append(sim_i2t) 
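    # Editor's note: each row of sims_matrix holds image-to-text similarities for one image
    # (max-pooled over the query tokens). The rows are stacked below and re-ranked with the
    # ITM head, sharded across ranks so every process only scores its own slice.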
168 | sims_matrix = torch.stack(sims_matrix, dim=0) 169 | 170 | score_matrix_i2t = torch.full( 171 | (len(data_loader.dataset.image), len(texts)), -100.0 172 | ).to(model.device) 173 | 174 | num_tasks = dist_utils.get_world_size() 175 | rank = dist_utils.get_rank() 176 | step = sims_matrix.size(0) // num_tasks + 1 177 | start = rank * step 178 | end = min(sims_matrix.size(0), start + step) 179 | 180 | for i, sims in enumerate( 181 | metric_logger.log_every(sims_matrix[start:end], 50, header) 182 | ): 183 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 184 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) 185 | score = model.compute_itm( 186 | image_inputs=image_inputs, 187 | text_ids=text_ids[topk_idx], 188 | text_atts=text_atts[topk_idx], 189 | ).float() 190 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim 191 | 192 | sims_matrix = sims_matrix.t() 193 | score_matrix_t2i = torch.full( 194 | (len(texts), len(data_loader.dataset.image)), -100.0 195 | ).to(model.device) 196 | 197 | step = sims_matrix.size(0) // num_tasks + 1 198 | start = rank * step 199 | end = min(sims_matrix.size(0), start + step) 200 | 201 | for i, sims in enumerate( 202 | metric_logger.log_every(sims_matrix[start:end], 50, header) 203 | ): 204 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 205 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device) 206 | score = model.compute_itm( 207 | image_inputs=image_inputs, 208 | text_ids=text_ids[start + i].repeat(k_test, 1), 209 | text_atts=text_atts[start + i].repeat(k_test, 1), 210 | ).float() 211 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim 212 | 213 | if dist_utils.is_dist_avail_and_initialized(): 214 | dist.barrier() 215 | torch.distributed.all_reduce( 216 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM 217 | ) 218 | torch.distributed.all_reduce( 219 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM 220 | ) 221 | 222 | total_time = time.time() - start_time 223 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 224 | logging.info("Evaluation time {}".format(total_time_str)) 225 | 226 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 227 | -------------------------------------------------------------------------------- /models/lavis/models/blip_models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from typing import List 9 | from torch import nn 10 | import logging 11 | 12 | 13 | def tie_encoder_decoder_weights( 14 | encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str 15 | ): 16 | uninitialized_encoder_weights: List[str] = [] 17 | if decoder.__class__ != encoder.__class__: 18 | logging.info( 19 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 
20 | ) 21 | 22 | def tie_encoder_to_decoder_recursively( 23 | decoder_pointer: nn.Module, 24 | encoder_pointer: nn.Module, 25 | module_name: str, 26 | uninitialized_encoder_weights: List[str], 27 | skip_key: str, 28 | depth=0, 29 | ): 30 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 31 | encoder_pointer, nn.Module 32 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 33 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 34 | assert hasattr(encoder_pointer, "weight") 35 | encoder_pointer.weight = decoder_pointer.weight 36 | if hasattr(decoder_pointer, "bias"): 37 | assert hasattr(encoder_pointer, "bias") 38 | encoder_pointer.bias = decoder_pointer.bias 39 | print(module_name + " is tied") 40 | return 41 | 42 | encoder_modules = encoder_pointer._modules 43 | decoder_modules = decoder_pointer._modules 44 | if len(decoder_modules) > 0: 45 | assert ( 46 | len(encoder_modules) > 0 47 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 48 | 49 | all_encoder_weights = set( 50 | [module_name + "/" + sub_name for sub_name in encoder_modules.keys()] 51 | ) 52 | encoder_layer_pos = 0 53 | for name, module in decoder_modules.items(): 54 | if name.isdigit(): 55 | encoder_name = str(int(name) + encoder_layer_pos) 56 | decoder_name = name 57 | if not isinstance( 58 | decoder_modules[decoder_name], 59 | type(encoder_modules[encoder_name]), 60 | ) and len(encoder_modules) != len(decoder_modules): 61 | # this can happen if the name corresponds to the position in a list module list of layers 62 | # in this case the decoder has added a cross-attention that the encoder does not have 63 | # thus skip this step and subtract one layer pos from encoder 64 | encoder_layer_pos -= 1 65 | continue 66 | elif name not in encoder_modules: 67 | continue 68 | elif depth > 500: 69 | raise ValueError( 70 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 71 | ) 72 | else: 73 | decoder_name = encoder_name = name 74 | tie_encoder_to_decoder_recursively( 75 | decoder_modules[decoder_name], 76 | encoder_modules[encoder_name], 77 | module_name + "/" + name, 78 | uninitialized_encoder_weights, 79 | skip_key, 80 | depth=depth + 1, 81 | ) 82 | all_encoder_weights.remove(module_name + "/" + encoder_name) 83 | 84 | uninitialized_encoder_weights += list(all_encoder_weights) 85 | 86 | # tie weights recursively 87 | tie_encoder_to_decoder_recursively( 88 | decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key 89 | ) 90 | -------------------------------------------------------------------------------- /models/lavis/models/blip_models/blip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | from ...common.dist_utils import download_cached_file 10 | from ..vit import interpolate_pos_embed 11 | from transformers import BertTokenizer 12 | from ...common.utils import is_url 13 | from ..base_model import BaseModel 14 | from packaging import version 15 | import transformers 16 | import logging 17 | import torch 18 | import os 19 | 20 | class BlipBase(BaseModel): 21 | def __init__(self): 22 | super().__init__() 23 | transformers_version = version.parse(transformers.__version__) 24 | assert transformers_version < version.parse("4.27"), "BLIP models are not compatible with transformers>=4.27, run pip install transformers==4.25 to downgrade" 25 | 26 | @classmethod 27 | def init_tokenizer(cls): 28 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 30 | tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]}) 31 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 32 | return tokenizer 33 | 34 | def load_from_pretrained(self, url_or_filename): 35 | if is_url(url_or_filename): 36 | cached_file = download_cached_file( 37 | url_or_filename, check_hash=False, progress=True 38 | ) 39 | checkpoint = torch.load(cached_file, map_location="cpu") 40 | elif os.path.isfile(url_or_filename): 41 | checkpoint = torch.load(url_or_filename, map_location="cpu") 42 | else: 43 | raise RuntimeError("checkpoint url or path is invalid") 44 | 45 | state_dict = checkpoint["model"] 46 | 47 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( 48 | state_dict["visual_encoder.pos_embed"], self.visual_encoder 49 | ) 50 | if "visual_encoder_m.pos_embed" in self.state_dict().keys(): 51 | state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( 52 | state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m 53 | ) 54 | 55 | for key in self.state_dict().keys(): 56 | if key in state_dict.keys(): 57 | if state_dict[key].shape != self.state_dict()[key].shape: 58 | del state_dict[key] 59 | 60 | msg = self.load_state_dict(state_dict, strict=False) 61 | 62 | logging.info("Missing keys {}".format(msg.missing_keys)) 63 | logging.info("load checkpoint from %s" % url_or_filename) 64 | 65 | return msg 66 | -------------------------------------------------------------------------------- /models/lavis/models/blip_models/blip_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional 10 | 11 | import torch 12 | from transformers.modeling_outputs import ( 13 | ModelOutput, 14 | BaseModelOutputWithPoolingAndCrossAttentions, 15 | CausalLMOutputWithCrossAttentions, 16 | ) 17 | 18 | 19 | @dataclass 20 | class BlipSimilarity(ModelOutput): 21 | sim_i2t: torch.FloatTensor = None 22 | sim_t2i: torch.FloatTensor = None 23 | 24 | sim_i2t_m: Optional[torch.FloatTensor] = None 25 | sim_t2i_m: Optional[torch.FloatTensor] = None 26 | 27 | sim_i2t_targets: Optional[torch.FloatTensor] = None 28 | sim_t2i_targets: Optional[torch.FloatTensor] = None 29 | 30 | 31 | @dataclass 32 | class BlipIntermediateOutput(ModelOutput): 33 | """ 34 | Data class for intermediate outputs of BLIP models. 35 | 36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). 37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). 38 | 39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). 40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 41 | 42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. 43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. 44 | 45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. 46 | decoder_labels (torch.LongTensor): labels for the captioning loss. 47 | 48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) 50 | 51 | """ 52 | 53 | # uni-modal features 54 | image_embeds: torch.FloatTensor = None 55 | text_embeds: Optional[torch.FloatTensor] = None 56 | 57 | image_embeds_m: Optional[torch.FloatTensor] = None 58 | text_embeds_m: Optional[torch.FloatTensor] = None 59 | 60 | # intermediate outputs of multimodal encoder 61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 63 | 64 | itm_logits: Optional[torch.FloatTensor] = None 65 | itm_labels: Optional[torch.LongTensor] = None 66 | 67 | # intermediate outputs of multimodal decoder 68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 69 | decoder_labels: Optional[torch.LongTensor] = None 70 | 71 | 72 | @dataclass 73 | class BlipOutput(ModelOutput): 74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 
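    # In BLIP-style pre-training the aggregate `loss` below is typically the sum
    # of the individual terms (loss = loss_itc + loss_itm + loss_lm); models that
    # skip a term leave the corresponding field as None.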
75 | sims: Optional[BlipSimilarity] = None 76 | 77 | intermediate_output: BlipIntermediateOutput = None 78 | 79 | loss: Optional[torch.FloatTensor] = None 80 | 81 | loss_itc: Optional[torch.FloatTensor] = None 82 | 83 | loss_itm: Optional[torch.FloatTensor] = None 84 | 85 | loss_lm: Optional[torch.FloatTensor] = None 86 | 87 | 88 | @dataclass 89 | class BlipOutputWithLogits(BlipOutput): 90 | logits: torch.FloatTensor = None 91 | logits_m: torch.FloatTensor = None 92 | 93 | 94 | @dataclass 95 | class BlipOutputFeatures(ModelOutput): 96 | """ 97 | Data class of features from BlipFeatureExtractor. 98 | 99 | Args: 100 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional 101 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional 102 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional 103 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional 104 | 105 | The first embedding or feature is for the [CLS] token. 106 | 107 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 108 | """ 109 | 110 | image_embeds: Optional[torch.FloatTensor] = None 111 | image_embeds_proj: Optional[torch.FloatTensor] = None 112 | 113 | text_embeds: Optional[torch.FloatTensor] = None 114 | text_embeds_proj: Optional[torch.FloatTensor] = None 115 | 116 | multimodal_embeds: Optional[torch.FloatTensor] = None 117 | -------------------------------------------------------------------------------- /models/lavis/models/clip_vit.py: -------------------------------------------------------------------------------- 1 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 2 | from ..common.dist_utils import download_cached_file 3 | from .eva_vit import convert_weights_to_fp16 4 | from collections import OrderedDict 5 | import torch.nn.functional as F 6 | from itertools import repeat 7 | import collections.abc 8 | from torch import nn 9 | import torch 10 | import math 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1): 16 | super().__init__() 17 | 18 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu1 = nn.ReLU(inplace=True) 22 | 23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.relu2 = nn.ReLU(inplace=True) 26 | 27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 28 | 29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 31 | self.relu3 = nn.ReLU(inplace=True) 32 | 33 | self.downsample = None 34 | self.stride = stride 35 | 36 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 38 | self.downsample = nn.Sequential(OrderedDict([ 39 | ("-1", nn.AvgPool2d(stride)), 40 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 41 | ("1", nn.BatchNorm2d(planes * self.expansion)) 42 | ])) 43 | 44 | def forward(self, x: torch.Tensor): 45 | identity = x 46 | 47 | out = self.relu1(self.bn1(self.conv1(x))) 48 | out = self.relu2(self.bn2(self.conv2(out))) 49 | out = self.avgpool(out) 50 | out = self.bn3(self.conv3(out)) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu3(out) 57 | return out 58 | 59 | 60 | class AttentionPool2d(nn.Module): 61 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 62 | super().__init__() 63 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 64 | self.k_proj = nn.Linear(embed_dim, embed_dim) 65 | self.q_proj = nn.Linear(embed_dim, embed_dim) 66 | self.v_proj = nn.Linear(embed_dim, embed_dim) 67 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 68 | self.num_heads = num_heads 69 | 70 | def forward(self, x): 71 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 72 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 73 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 74 | x, _ = F.multi_head_attention_forward( 75 | query=x, key=x, value=x, 76 | embed_dim_to_check=x.shape[-1], 77 | num_heads=self.num_heads, 78 | q_proj_weight=self.q_proj.weight, 79 | k_proj_weight=self.k_proj.weight, 80 | v_proj_weight=self.v_proj.weight, 81 | in_proj_weight=None, 82 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 83 | bias_k=None, 84 | bias_v=None, 85 | add_zero_attn=False, 86 | dropout_p=0, 87 | out_proj_weight=self.c_proj.weight, 88 | out_proj_bias=self.c_proj.bias, 89 | use_separate_proj_weight=True, 90 | training=self.training, 91 | need_weights=False 92 | ) 93 | 94 | return x[0] 95 | 96 | 97 | class LayerNorm(nn.LayerNorm): 98 | """Subclass torch's LayerNorm to handle fp16.""" 99 | 100 | def forward(self, x: torch.Tensor): 101 | orig_type = x.dtype 102 | ret = super().forward(x.type(torch.float32)) 103 | return ret.type(orig_type) 104 | 105 | 106 | class QuickGELU(nn.Module): 107 | def forward(self, x: torch.Tensor): 108 | return x * torch.sigmoid(1.702 * x) 109 | 110 | 111 | class ResidualAttentionBlock(nn.Module): 112 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): 113 | super().__init__() 114 | 115 | self.attn = nn.MultiheadAttention(d_model, 
n_head) 116 | self.ln_1 = LayerNorm(d_model) 117 | self.mlp = nn.Sequential(OrderedDict([ 118 | ("c_fc", nn.Linear(d_model, d_model * 4)), 119 | ("gelu", QuickGELU()), 120 | ("c_proj", nn.Linear(d_model * 4, d_model)) 121 | ])) 122 | self.ln_2 = LayerNorm(d_model) 123 | self.attn_mask = attn_mask 124 | 125 | if use_grad_checkpointing: 126 | self.attn = checkpoint_wrapper(self.attn) 127 | self.mlp = checkpoint_wrapper(self.mlp) 128 | 129 | def attention(self, x: torch.Tensor): 130 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 131 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 132 | 133 | def forward(self, x: torch.Tensor): 134 | x = x + self.attention(self.ln_1(x)) 135 | x = x + self.mlp(self.ln_2(x)) 136 | return x 137 | 138 | 139 | class Transformer(nn.Module): 140 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): 141 | super().__init__() 142 | self.width = width 143 | self.layers = layers 144 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) 145 | 146 | def forward(self, x: torch.Tensor): 147 | return self.resblocks(x) 148 | 149 | 150 | class VisionTransformer(nn.Module): 151 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): 152 | super().__init__() 153 | self.input_resolution = input_resolution 154 | self.num_features = width 155 | self.num_heads = heads 156 | self.num_patches = (input_resolution // patch_size) ** 2 157 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 158 | 159 | scale = width ** -0.5 160 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 161 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) 162 | self.ln_pre = LayerNorm(width) 163 | 164 | self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) 165 | 166 | # self.ln_final = LayerNorm(width) 167 | 168 | def forward(self, x: torch.Tensor): 169 | 170 | x = self.conv1(x) # shape = [*, width, grid, grid] 171 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 172 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 173 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 174 | x = x + self.positional_embedding.to(x.dtype) 175 | x = self.ln_pre(x) 176 | 177 | x = x.permute(1, 0, 2) # NLD -> LND 178 | x = self.transformer(x) 179 | x = x.permute(1, 0, 2) # LND -> NLD 180 | 181 | # x = self.ln_final(x) 182 | return x 183 | 184 | def get_num_layer(self, var_name=""): 185 | if var_name in ("class_embedding", "positional_embedding", "conv1", "ln_pre"): 186 | return 0 187 | elif var_name.startswith("transformer.resblocks"): 188 | layer_id = int(var_name.split('.')[2]) 189 | return layer_id + 1 190 | else: 191 | return len(self.transformer.resblocks) 192 | 193 | 194 | # From PyTorch internals 195 | def _ntuple(n): 196 | def parse(x): 197 | if isinstance(x, collections.abc.Iterable): 198 | return x 199 | return tuple(repeat(x, n)) 200 | return parse 201 | to_2tuple = _ntuple(2) 202 | 203 | def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1): 204 | # Rescale the 
grid of position embeddings when loading from state_dict 205 | old_pos_embed = state_dict.get('positional_embedding', None) 206 | 207 | grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5) 208 | if old_pos_embed is None: 209 | return 210 | grid_size = to_2tuple(grid_size) 211 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 212 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 213 | if new_seq_len == old_pos_embed.shape[0]: 214 | return 215 | 216 | if extra_tokens: 217 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 218 | else: 219 | pos_emb_tok, pos_emb_img = None, old_pos_embed 220 | 221 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 222 | 223 | print('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 224 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 225 | pos_emb_img = F.interpolate( 226 | pos_emb_img, 227 | size=grid_size, 228 | mode=interpolation, 229 | align_corners=True, 230 | ) 231 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 232 | if pos_emb_tok is not None: 233 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 234 | else: 235 | new_pos_embed = pos_emb_img 236 | state_dict['positional_embedding'] = new_pos_embed 237 | 238 | 239 | def create_clip_vit_L(img_size=224,use_checkpoint=False,precision="fp16"): 240 | model = VisionTransformer( 241 | input_resolution=img_size, 242 | patch_size=14, 243 | width=1024, 244 | layers=23, 245 | heads=16, 246 | use_grad_checkpointing=use_checkpoint, 247 | ) 248 | url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth" 249 | cached_file = download_cached_file( 250 | url, check_hash=False, progress=True 251 | ) 252 | state_dict = torch.load(cached_file, map_location="cpu") 253 | interpolate_pos_embed(model,state_dict) 254 | 255 | incompatible_keys = model.load_state_dict(state_dict, strict=False) 256 | # print(incompatible_keys) 257 | 258 | if precision == "fp16": 259 | convert_weights_to_fp16(model) 260 | return model 261 | -------------------------------------------------------------------------------- /models/lavis/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from ..processors.base_processor import BaseProcessor 9 | from ..processors.blip_processors import ( 10 | BlipImageTrainProcessor, 11 | Blip2ImageTrainProcessor, 12 | BlipImageEvalProcessor, 13 | BlipCaptionProcessor, 14 | ) 15 | 16 | from ..common.registry import registry 17 | 18 | __all__ = [ 19 | "BaseProcessor", 20 | # ALPRO 21 | "AlproVideoTrainProcessor", 22 | "AlproVideoEvalProcessor", 23 | # BLIP 24 | "BlipImageTrainProcessor", 25 | "Blip2ImageTrainProcessor", 26 | "BlipImageEvalProcessor", 27 | "BlipCaptionProcessor", 28 | "BlipInstructionProcessor", 29 | # BLIP-Diffusion 30 | "BlipDiffusionInputImageProcessor", 31 | "BlipDiffusionTargetImageProcessor", 32 | # CLIP 33 | "ClipImageTrainProcessor", 34 | # GPT 35 | "GPTVideoFeatureProcessor", 36 | "GPTDialogueProcessor", 37 | # AUDIO 38 | "BeatsAudioProcessor", 39 | # 3D 40 | "ULIPPCProcessor", 41 | ] 42 | 43 | 44 | def load_processor(name, cfg=None): 45 | """ 46 | Example 47 | 48 | >>> processor = load_processor("alpro_video_train", cfg=None) 49 | """ 50 | processor = registry.get_processor_class(name).from_config(cfg) 51 | 52 | return processor 53 | -------------------------------------------------------------------------------- /models/lavis/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /models/lavis/processors/blip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | from torchvision.transforms.functional import InterpolationMode 8 | from ..processors.base_processor import BaseProcessor 9 | from ..processors.randaugment import RandomAugment 10 | from ..common.registry import registry 11 | from torchvision import transforms 12 | from omegaconf import OmegaConf 13 | import re 14 | 15 | class BlipImageBaseProcessor(BaseProcessor): 16 | def __init__(self, mean=None, std=None): 17 | if mean is None: 18 | mean = (0.48145466, 0.4578275, 0.40821073) 19 | if std is None: 20 | std = (0.26862954, 0.26130258, 0.27577711) 21 | 22 | self.normalize = transforms.Normalize(mean, std) 23 | 24 | 25 | @registry.register_processor("blip_caption") 26 | class BlipCaptionProcessor(BaseProcessor): 27 | def __init__(self, prompt="", max_words=50): 28 | self.prompt = prompt 29 | self.max_words = max_words 30 | 31 | def __call__(self, caption): 32 | caption = self.prompt + self.pre_caption(caption) 33 | 34 | return caption 35 | 36 | @classmethod 37 | def from_config(cls, cfg=None): 38 | if cfg is None: 39 | cfg = OmegaConf.create() 40 | 41 | prompt = cfg.get("prompt", "") 42 | max_words = cfg.get("max_words", 50) 43 | 44 | return cls(prompt=prompt, max_words=max_words) 45 | 46 | def pre_caption(self, caption): 47 | caption = re.sub( 48 | r"([.!\"()*#:;~])", 49 | " ", 50 | caption.lower(), 51 | ) 52 | caption = re.sub( 53 | r"\s{2,}", 54 | " ", 55 | caption, 56 | ) 57 | caption = caption.rstrip("\n") 58 | caption = caption.strip(" ") 59 | 60 | # truncate caption 61 | caption_words = caption.split(" ") 62 | if len(caption_words) > self.max_words: 63 | caption = " ".join(caption_words[: self.max_words]) 64 | 65 | return caption 66 | 67 | 68 | @registry.register_processor("blip_question") 69 | class BlipQuestionProcessor(BaseProcessor): 70 | def __init__(self, max_words=50): 71 | self.max_words = max_words 72 | 73 | def __call__(self, question): 74 | return self.pre_question(question) 75 | 76 | @classmethod 77 | def from_config(cls, cfg=None): 78 | if cfg is None: 79 | cfg = OmegaConf.create() 80 | 81 | max_words = cfg.get("max_words", 50) 82 | 83 | return cls(max_words=max_words) 84 | 85 | def pre_question(self, question): 86 | question = re.sub( 87 | r"([.!\"()*#:;~])", 88 | "", 89 | question.lower(), 90 | ) 91 | question = question.rstrip(" ") 92 | 93 | # truncate question 94 | question_words = question.split(" ") 95 | if len(question_words) > self.max_words: 96 | question = " ".join(question_words[: self.max_words]) 97 | 98 | return question 99 | 100 | 101 | @registry.register_processor("blip_image_train") 102 | class BlipImageTrainProcessor(BlipImageBaseProcessor): 103 | def __init__( 104 | self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0 105 | ): 106 | super().__init__(mean=mean, std=std) 107 | 108 | self.transform = transforms.Compose( 109 | [ 110 | transforms.RandomResizedCrop( 111 | image_size, 112 | scale=(min_scale, max_scale), 113 | interpolation=InterpolationMode.BICUBIC, 114 | ), 115 | transforms.RandomHorizontalFlip(), 116 | RandomAugment( 117 | 2, 118 | 5, 119 | isPIL=True, 120 | augs=[ 121 | "Identity", 122 | "AutoContrast", 123 | "Brightness", 124 | "Sharpness", 125 | "Equalize", 126 | "ShearX", 127 | "ShearY", 128 | "TranslateX", 129 | "TranslateY", 130 | "Rotate", 131 | ], 132 | ), 133 | transforms.ToTensor(), 134 | self.normalize, 135 | ] 136 | ) 137 | 138 | def __call__(self, 
item): 139 | return self.transform(item) 140 | 141 | @classmethod 142 | def from_config(cls, cfg=None): 143 | if cfg is None: 144 | cfg = OmegaConf.create() 145 | 146 | image_size = cfg.get("image_size", 384) 147 | 148 | mean = cfg.get("mean", None) 149 | std = cfg.get("std", None) 150 | 151 | min_scale = cfg.get("min_scale", 0.5) 152 | max_scale = cfg.get("max_scale", 1.0) 153 | 154 | return cls( 155 | image_size=image_size, 156 | mean=mean, 157 | std=std, 158 | min_scale=min_scale, 159 | max_scale=max_scale, 160 | ) 161 | 162 | 163 | @registry.register_processor("blip_image_eval") 164 | class BlipImageEvalProcessor(BlipImageBaseProcessor): 165 | def __init__(self, image_size=384, mean=None, std=None): 166 | super().__init__(mean=mean, std=std) 167 | 168 | self.transform = transforms.Compose( 169 | [ 170 | transforms.Resize( 171 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 172 | ), 173 | transforms.ToTensor(), 174 | self.normalize, 175 | ] 176 | ) 177 | 178 | def __call__(self, item): 179 | return self.transform(item) 180 | 181 | @classmethod 182 | def from_config(cls, cfg=None): 183 | if cfg is None: 184 | cfg = OmegaConf.create() 185 | 186 | image_size = cfg.get("image_size", 384) 187 | 188 | mean = cfg.get("mean", None) 189 | std = cfg.get("std", None) 190 | 191 | return cls(image_size=image_size, mean=mean, std=std) 192 | 193 | 194 | @registry.register_processor("blip2_image_train") 195 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor): 196 | def __init__( 197 | self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0 198 | ): 199 | super().__init__(mean=mean, std=std) 200 | 201 | self.transform = transforms.Compose( 202 | [ 203 | transforms.RandomResizedCrop( 204 | image_size, 205 | scale=(min_scale, max_scale), 206 | interpolation=InterpolationMode.BICUBIC, 207 | ), 208 | transforms.RandomHorizontalFlip(), 209 | transforms.ToTensor(), 210 | self.normalize, 211 | ] 212 | ) 213 | 214 | def __call__(self, item): 215 | return self.transform(item) 216 | 217 | @classmethod 218 | def from_config(cls, cfg=None): 219 | if cfg is None: 220 | cfg = OmegaConf.create() 221 | 222 | image_size = cfg.get("image_size", 364) 223 | 224 | mean = cfg.get("mean", None) 225 | std = cfg.get("std", None) 226 | 227 | min_scale = cfg.get("min_scale", 0.5) 228 | max_scale = cfg.get("max_scale", 1.0) 229 | 230 | return cls( 231 | image_size=image_size, 232 | mean=mean, 233 | std=std, 234 | min_scale=min_scale, 235 | max_scale=max_scale, 236 | ) -------------------------------------------------------------------------------- /models/lavis/processors/clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from torchvision.transforms.functional import InterpolationMode 9 | from processors.blip_processors import BlipImageBaseProcessor 10 | from ..common.registry import registry 11 | from torchvision import transforms 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def _convert_to_rgb(image): 16 | return image.convert("RGB") 17 | 18 | 19 | @registry.register_processor("clip_image_train") 20 | class ClipImageTrainProcessor(BlipImageBaseProcessor): 21 | def __init__( 22 | self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0 23 | ): 24 | 25 | super().__init__(mean=mean, std=std) 26 | 27 | self.transform = transforms.Compose( 28 | [ 29 | transforms.RandomResizedCrop( 30 | image_size, 31 | scale=(min_scale, max_scale), 32 | interpolation=InterpolationMode.BICUBIC, 33 | ), 34 | _convert_to_rgb, 35 | transforms.ToTensor(), 36 | self.normalize, 37 | ] 38 | ) 39 | 40 | @classmethod 41 | def from_config(cls, cfg=None): 42 | if cfg is None: 43 | cfg = OmegaConf.create() 44 | 45 | image_size = cfg.get("image_size", 224) 46 | 47 | mean = cfg.get("mean", None) 48 | std = cfg.get("std", None) 49 | 50 | min_scale = cfg.get("min_scale", 0.9) 51 | max_scale = cfg.get("max_scale", 1.0) 52 | 53 | return cls( 54 | image_size=image_size, 55 | mean=mean, 56 | std=std, 57 | min_scale=min_scale, 58 | max_scale=max_scale, 59 | ) 60 | 61 | 62 | @registry.register_processor("clip_image_eval") 63 | class ClipImageEvalProcessor(BlipImageBaseProcessor): 64 | def __init__(self, image_size=224, mean=None, std=None): 65 | 66 | super().__init__(mean=mean, std=std) 67 | 68 | self.transform = transforms.Compose( 69 | [ 70 | transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC), 71 | transforms.CenterCrop(image_size), 72 | _convert_to_rgb, 73 | transforms.ToTensor(), 74 | self.normalize, 75 | ] 76 | ) 77 | 78 | @classmethod 79 | def from_config(cls, cfg=None): 80 | if cfg is None: 81 | cfg = OmegaConf.create() 82 | 83 | image_size = cfg.get("image_size", 224) 84 | 85 | mean = cfg.get("mean", None) 86 | std = cfg.get("std", None) 87 | 88 | return cls( 89 | image_size=image_size, 90 | mean=mean, 91 | std=std, 92 | ) 93 | -------------------------------------------------------------------------------- /models/lavis/processors/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import cv2 11 | 12 | 13 | ## aug functions 14 | def identity_func(img): 15 | return img 16 | 17 | 18 | def autocontrast_func(img, cutoff=0): 19 | """ 20 | same output as PIL.ImageOps.autocontrast 21 | """ 22 | n_bins = 256 23 | 24 | def tune_channel(ch): 25 | n = ch.size 26 | cut = cutoff * n // 100 27 | if cut == 0: 28 | high, low = ch.max(), ch.min() 29 | else: 30 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 31 | low = np.argwhere(np.cumsum(hist) > cut) 32 | low = 0 if low.shape[0] == 0 else low[0] 33 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 34 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 35 | if high <= low: 36 | table = np.arange(n_bins) 37 | else: 38 | scale = (n_bins - 1) / (high - low) 39 | offset = -low * scale 40 | table = np.arange(n_bins) * scale + offset 41 | table[table < 0] = 0 42 | table[table > n_bins - 1] = n_bins - 1 43 | table = table.clip(0, 255).astype(np.uint8) 44 | return table[ch] 45 | 46 | channels = [tune_channel(ch) for ch in cv2.split(img)] 47 | out = cv2.merge(channels) 48 | return out 49 | 50 | 51 | def equalize_func(img): 52 | """ 53 | same output as PIL.ImageOps.equalize 54 | PIL's implementation is different from cv2.equalize 55 | """ 56 | n_bins = 256 57 | 58 | def tune_channel(ch): 59 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 60 | non_zero_hist = hist[hist != 0].reshape(-1) 61 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 62 | if step == 0: 63 | return ch 64 | n = np.empty_like(hist) 65 | n[0] = step // 2 66 | n[1:] = hist[:-1] 67 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 68 | return table[ch] 69 | 70 | channels = [tune_channel(ch) for ch in cv2.split(img)] 71 | out = cv2.merge(channels) 72 | return out 73 | 74 | 75 | def rotate_func(img, degree, fill=(0, 0, 0)): 76 | """ 77 | like PIL, rotate by degree, not radians 78 | """ 79 | H, W = img.shape[0], img.shape[1] 80 | center = W / 2, H / 2 81 | M = cv2.getRotationMatrix2D(center, degree, 1) 82 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 83 | return out 84 | 85 | 86 | def solarize_func(img, thresh=128): 87 | """ 88 | same output as PIL.ImageOps.posterize 89 | """ 90 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 91 | table = table.clip(0, 255).astype(np.uint8) 92 | out = table[img] 93 | return out 94 | 95 | 96 | def color_func(img, factor): 97 | """ 98 | same output as PIL.ImageEnhance.Color 99 | """ 100 | ## implementation according to PIL definition, quite slow 101 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 102 | # out = blend(degenerate, img, factor) 103 | # M = ( 104 | # np.eye(3) * factor 105 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) 106 | # )[np.newaxis, np.newaxis, :] 107 | M = np.float32( 108 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] 109 | ) * factor + np.float32([[0.114], [0.587], [0.299]]) 110 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 111 | return out 112 | 113 | 114 | def contrast_func(img, factor): 115 | """ 116 | same output as PIL.ImageEnhance.Contrast 117 | """ 118 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 119 | table = ( 120 | np.array([(el - mean) * factor + mean for el in range(256)]) 121 | .clip(0, 255) 122 | .astype(np.uint8) 123 | ) 124 | out = table[img] 125 | return out 126 | 127 | 128 | def brightness_func(img, factor): 129 | """ 130 | same output as PIL.ImageEnhance.Contrast 131 | """ 132 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 133 | out = table[img] 134 | return out 135 | 136 | 137 | def sharpness_func(img, factor): 138 | """ 139 | The differences the this result and PIL are all on the 4 boundaries, the center 140 | areas are same 141 | """ 142 | kernel = np.ones((3, 3), dtype=np.float32) 143 | kernel[1][1] = 5 144 | kernel /= 13 145 | degenerate = cv2.filter2D(img, -1, kernel) 146 | if factor == 0.0: 147 | out = degenerate 148 | elif factor == 1.0: 149 | out = img 150 | else: 151 | out = img.astype(np.float32) 152 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 153 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 154 | out = out.astype(np.uint8) 155 | return out 156 | 157 | 158 | def shear_x_func(img, factor, fill=(0, 0, 0)): 159 | H, W = img.shape[0], img.shape[1] 160 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 161 | out = cv2.warpAffine( 162 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 163 | ).astype(np.uint8) 164 | return out 165 | 166 | 167 | def translate_x_func(img, offset, fill=(0, 0, 0)): 168 | """ 169 | same output as PIL.Image.transform 170 | """ 171 | H, W = img.shape[0], img.shape[1] 172 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 173 | out = cv2.warpAffine( 174 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 175 | ).astype(np.uint8) 176 | return out 177 | 178 | 179 | def translate_y_func(img, offset, fill=(0, 0, 0)): 180 | """ 181 | same output as PIL.Image.transform 182 | """ 183 | H, W = img.shape[0], img.shape[1] 184 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 185 | out = cv2.warpAffine( 186 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 187 | ).astype(np.uint8) 188 | return out 189 | 190 | 191 | def posterize_func(img, bits): 192 | """ 193 | same output as PIL.ImageOps.posterize 194 | """ 195 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 196 | return out 197 | 198 | 199 | def shear_y_func(img, factor, fill=(0, 0, 0)): 200 | H, W = img.shape[0], img.shape[1] 201 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 202 | out = cv2.warpAffine( 203 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 204 | ).astype(np.uint8) 205 | return out 206 | 207 | 208 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 209 | replace = np.array(replace, dtype=np.uint8) 210 | H, W = img.shape[0], img.shape[1] 211 | rh, rw = np.random.random(2) 212 | pad_size = pad_size // 2 213 | ch, cw = int(rh * H), int(rw * W) 214 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 215 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 216 | out = img.copy() 217 | out[x1:x2, y1:y2, :] = replace 218 | return out 219 | 220 | 221 | ### level to args 222 | def 
enhance_level_to_args(MAX_LEVEL): 223 | def level_to_args(level): 224 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 225 | 226 | return level_to_args 227 | 228 | 229 | def shear_level_to_args(MAX_LEVEL, replace_value): 230 | def level_to_args(level): 231 | level = (level / MAX_LEVEL) * 0.3 232 | if np.random.random() > 0.5: 233 | level = -level 234 | return (level, replace_value) 235 | 236 | return level_to_args 237 | 238 | 239 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 240 | def level_to_args(level): 241 | level = (level / MAX_LEVEL) * float(translate_const) 242 | if np.random.random() > 0.5: 243 | level = -level 244 | return (level, replace_value) 245 | 246 | return level_to_args 247 | 248 | 249 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 250 | def level_to_args(level): 251 | level = int((level / MAX_LEVEL) * cutout_const) 252 | return (level, replace_value) 253 | 254 | return level_to_args 255 | 256 | 257 | def solarize_level_to_args(MAX_LEVEL): 258 | def level_to_args(level): 259 | level = int((level / MAX_LEVEL) * 256) 260 | return (level,) 261 | 262 | return level_to_args 263 | 264 | 265 | def none_level_to_args(level): 266 | return () 267 | 268 | 269 | def posterize_level_to_args(MAX_LEVEL): 270 | def level_to_args(level): 271 | level = int((level / MAX_LEVEL) * 4) 272 | return (level,) 273 | 274 | return level_to_args 275 | 276 | 277 | def rotate_level_to_args(MAX_LEVEL, replace_value): 278 | def level_to_args(level): 279 | level = (level / MAX_LEVEL) * 30 280 | if np.random.random() < 0.5: 281 | level = -level 282 | return (level, replace_value) 283 | 284 | return level_to_args 285 | 286 | 287 | func_dict = { 288 | "Identity": identity_func, 289 | "AutoContrast": autocontrast_func, 290 | "Equalize": equalize_func, 291 | "Rotate": rotate_func, 292 | "Solarize": solarize_func, 293 | "Color": color_func, 294 | "Contrast": contrast_func, 295 | "Brightness": brightness_func, 296 | "Sharpness": sharpness_func, 297 | "ShearX": shear_x_func, 298 | "TranslateX": translate_x_func, 299 | "TranslateY": translate_y_func, 300 | "Posterize": posterize_func, 301 | "ShearY": shear_y_func, 302 | } 303 | 304 | translate_const = 10 305 | MAX_LEVEL = 10 306 | replace_value = (128, 128, 128) 307 | arg_dict = { 308 | "Identity": none_level_to_args, 309 | "AutoContrast": none_level_to_args, 310 | "Equalize": none_level_to_args, 311 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), 312 | "Solarize": solarize_level_to_args(MAX_LEVEL), 313 | "Color": enhance_level_to_args(MAX_LEVEL), 314 | "Contrast": enhance_level_to_args(MAX_LEVEL), 315 | "Brightness": enhance_level_to_args(MAX_LEVEL), 316 | "Sharpness": enhance_level_to_args(MAX_LEVEL), 317 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), 318 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 319 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 320 | "Posterize": posterize_level_to_args(MAX_LEVEL), 321 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), 322 | } 323 | 324 | 325 | class RandomAugment(object): 326 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 327 | self.N = N 328 | self.M = M 329 | self.isPIL = isPIL 330 | if augs: 331 | self.augs = augs 332 | else: 333 | self.augs = list(arg_dict.keys()) 334 | 335 | def get_random_ops(self): 336 | sampled_ops = np.random.choice(self.augs, self.N) 337 | return [(op, 0.5, self.M) for op in sampled_ops] 338 | 339 | def __call__(self, img): 340 | if 
self.isPIL: 341 | img = np.array(img) 342 | ops = self.get_random_ops() 343 | for name, prob, level in ops: 344 | if np.random.random() > prob: 345 | continue 346 | args = arg_dict[name](level) 347 | img = func_dict[name](img, *args) 348 | return img 349 | 350 | 351 | class VideoRandomAugment(object): 352 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): 353 | self.N = N 354 | self.M = M 355 | self.p = p 356 | self.tensor_in_tensor_out = tensor_in_tensor_out 357 | if augs: 358 | self.augs = augs 359 | else: 360 | self.augs = list(arg_dict.keys()) 361 | 362 | def get_random_ops(self): 363 | sampled_ops = np.random.choice(self.augs, self.N, replace=False) 364 | return [(op, self.M) for op in sampled_ops] 365 | 366 | def __call__(self, frames): 367 | assert ( 368 | frames.shape[-1] == 3 369 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." 370 | 371 | if self.tensor_in_tensor_out: 372 | frames = frames.numpy().astype(np.uint8) 373 | 374 | num_frames = frames.shape[0] 375 | 376 | ops = num_frames * [self.get_random_ops()] 377 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] 378 | 379 | frames = torch.stack( 380 | list(map(self._aug, frames, ops, apply_or_not)), dim=0 381 | ).float() 382 | 383 | return frames 384 | 385 | def _aug(self, img, ops, apply_or_not): 386 | for i, (name, level) in enumerate(ops): 387 | if not apply_or_not[i]: 388 | continue 389 | args = arg_dict[name](level) 390 | img = func_dict[name](img, *args) 391 | return torch.from_numpy(img) 392 | 393 | 394 | if __name__ == "__main__": 395 | a = RandomAugment() 396 | img = np.random.randn(32, 32, 3) 397 | a(img) 398 | -------------------------------------------------------------------------------- /models/lavis/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from .runner_base import RunnerBase 9 | from .runner_iter import RunnerIter 10 | 11 | __all__ = ["RunnerBase", "RunnerIter"] 12 | -------------------------------------------------------------------------------- /models/lavis/runners/runner_iter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import os 11 | import time 12 | 13 | import torch 14 | import torch.distributed as dist 15 | import webdataset as wds 16 | from lavis.common.dist_utils import download_cached_file, is_main_process, main_process 17 | from lavis.common.registry import registry 18 | from lavis.common.utils import is_url 19 | from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split 20 | from lavis.runners.runner_base import RunnerBase 21 | from torch.utils.data.dataset import ChainDataset 22 | 23 | 24 | @registry.register_runner("runner_iter") 25 | class RunnerIter(RunnerBase): 26 | """ 27 | Run training based on the number of iterations. This is common when 28 | the training dataset size is large. 
Underhood logic is similar to 29 | epoch-based training by considering every #iters_per_inner_epoch as an 30 | inner epoch. 31 | 32 | In iter-based runner, after every #iters_per_inner_epoch steps, we 33 | 34 | 1) do a validation epoch; 35 | 2) schedule the learning rate; 36 | 3) save the checkpoint. 37 | 38 | We refer every #iters_per_inner_epoch steps as an inner epoch. 39 | """ 40 | 41 | def __init__(self, cfg, task, model, datasets, job_id): 42 | super().__init__(cfg, task, model, datasets, job_id) 43 | 44 | self.start_iters = 0 45 | 46 | self.max_iters = int(self.config.run_cfg.get("max_iters", -1)) 47 | assert self.max_iters > 0, "max_iters must be greater than 0." 48 | 49 | self.iters_per_inner_epoch = int( 50 | self.config.run_cfg.get("iters_per_inner_epoch", -1) 51 | ) 52 | assert ( 53 | self.iters_per_inner_epoch > 0 54 | ), "iters_per_inner_epoch must be greater than 0." 55 | 56 | @property 57 | def max_epoch(self): 58 | return int(self.max_iters / self.iters_per_inner_epoch) 59 | 60 | @property 61 | def cur_epoch(self): 62 | try: 63 | return self.train_loader.epoch 64 | except AttributeError: 65 | # pipeline data (e.g. LAION) is streaming, have no concept of epoch 66 | return 0 67 | 68 | def _progress(self, cur_iters): 69 | return "{}_iters={}".format(self.cur_epoch, cur_iters) 70 | 71 | def train(self): 72 | start_time = time.time() 73 | best_agg_metric = 0 74 | best_iters = 0 75 | 76 | self.log_config() 77 | 78 | # resume from checkpoint if specified 79 | if not self.evaluate_only and self.resume_ckpt_path is not None: 80 | self._load_checkpoint(self.resume_ckpt_path) 81 | 82 | for start_iters in range( 83 | self.start_iters, self.max_iters, self.iters_per_inner_epoch 84 | ): 85 | end_iters = start_iters + self.iters_per_inner_epoch 86 | 87 | # training phase 88 | if not self.evaluate_only: 89 | logging.info( 90 | "Start training, max_iters={}, in total {} inner epochs.".format( 91 | self.max_iters, int(self.max_iters / self.iters_per_inner_epoch) 92 | ) 93 | ) 94 | if start_iters == self.start_iters: 95 | self.task.before_training( 96 | model=self.unwrap_dist_model(self.model), 97 | dataset=self.datasets, 98 | ) 99 | train_stats = self.train_iters(self.cur_epoch, start_iters) 100 | self.log_stats(split_name="train", stats=train_stats) 101 | 102 | # evaluation phase 103 | if len(self.valid_splits) > 0 and (self.evaluate_only or (end_iters//self.iters_per_inner_epoch)%self.val_freq == 0): 104 | for split_name in self.valid_splits: 105 | logging.info("Evaluating on {}.".format(split_name)) 106 | 107 | val_log = self.eval_epoch( 108 | split_name=split_name, cur_epoch=self._progress(end_iters) 109 | ) 110 | if val_log is not None: 111 | if is_main_process(): 112 | assert ( 113 | "agg_metrics" in val_log 114 | ), "No agg_metrics found in validation log." 115 | 116 | agg_metrics = val_log["agg_metrics"] 117 | if agg_metrics > best_agg_metric and split_name == "val": 118 | best_iters, best_agg_metric = end_iters, agg_metrics 119 | 120 | self._save_checkpoint(end_iters, is_best=True) 121 | 122 | val_log.update({"best_iters": best_iters}) 123 | self.log_stats(val_log, split_name) 124 | 125 | else: 126 | # if no validation split is provided, we just save the checkpoint at the end of each inner epoch. 
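                # (see _save_checkpoint below: checkpoints are written to
                #  self.output_dir as "checkpoint_<iters>.pth", or
                #  "checkpoint_best.pth" when is_best=True)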
127 | if not self.evaluate_only: 128 | self._save_checkpoint(end_iters, is_best=False) 129 | 130 | if self.evaluate_only: 131 | break 132 | 133 | # save checkpoint according to save freq 134 | # if self.save_freq>0 and (end_iters//self.iters_per_inner_epoch)%self.save_freq == 0: 135 | self._save_checkpoint(end_iters, is_best=False) 136 | 137 | dist.barrier() 138 | 139 | # save last checkpoint 140 | if self.save_last and not self.evaluate_only: 141 | self._save_checkpoint(end_iters, is_best=False) 142 | 143 | # testing phase 144 | self.evaluate(cur_epoch=self.cur_epoch) 145 | 146 | total_time = time.time() - start_time 147 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 148 | logging.info("Training time {}".format(total_time_str)) 149 | 150 | def train_iters(self, epoch, start_iters): 151 | # train by iterations 152 | self.model.train() 153 | 154 | return self.task.train_iters( 155 | epoch=epoch, 156 | start_iters=start_iters, 157 | iters_per_inner_epoch=self.iters_per_inner_epoch, 158 | model=self.model, 159 | data_loader=self.train_loader, 160 | optimizer=self.optimizer, 161 | scaler=self.scaler, 162 | lr_scheduler=self.lr_scheduler, 163 | cuda_enabled=self.cuda_enabled, 164 | log_freq=self.log_freq, 165 | accum_grad_iters=self.accum_grad_iters, 166 | ) 167 | 168 | @main_process 169 | def _save_checkpoint(self, cur_iters, is_best=False, is_last=False): 170 | model_no_ddp = self.unwrap_dist_model(self.model) 171 | param_grad_dic = { 172 | k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() 173 | } 174 | 175 | state_dict = model_no_ddp.state_dict() 176 | for k in list(state_dict.keys()): 177 | if k in param_grad_dic.keys() and not param_grad_dic[k]: 178 | # delete parameters that do not require gradient 179 | del state_dict[k] 180 | 181 | save_obj = { 182 | "model": state_dict, 183 | "optimizer": self.optimizer.state_dict(), 184 | "config": self.config.to_dict(), 185 | "scaler": self.scaler.state_dict() if self.scaler else None, 186 | "iters": cur_iters, 187 | } 188 | save_to = os.path.join( 189 | self.output_dir, 190 | "checkpoint_{}.pth".format("best" if is_best else cur_iters), 191 | ) 192 | logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to)) 193 | torch.save(save_obj, save_to) 194 | 195 | def _load_checkpoint(self, url_or_filename): 196 | """ 197 | Resume from a checkpoint. 198 | """ 199 | if is_url(url_or_filename): 200 | cached_file = download_cached_file( 201 | url_or_filename, check_hash=False, progress=True 202 | ) 203 | checkpoint = torch.load(cached_file, map_location=self.device) 204 | elif os.path.isfile(url_or_filename): 205 | checkpoint = torch.load(url_or_filename, map_location=self.device) 206 | else: 207 | raise RuntimeError("checkpoint url or path is invalid") 208 | 209 | state_dict = checkpoint["model"] 210 | self.unwrap_dist_model(self.model).load_state_dict(state_dict) 211 | 212 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 213 | if self.scaler and "scaler" in checkpoint: 214 | self.scaler.load_state_dict(checkpoint["scaler"]) 215 | 216 | self.start_iters = checkpoint["iters"] + 1 217 | logging.info("Resume checkpoint from {}".format(url_or_filename)) 218 | 219 | @property 220 | def dataloaders(self) -> dict: 221 | """ 222 | A property to get and create dataloaders by split just in need. 223 | 224 | If no train_dataset_ratio is provided, concatenate map-style datasets and 225 | chain wds.DataPipe datasets separately. 
Training set becomes a tuple 226 | (ConcatDataset, ChainDataset), both are optional but at least one of them is 227 | required. The resultant ConcatDataset and ChainDataset will be sampled evenly. 228 | 229 | If train_dataset_ratio is provided, create a MultiIterLoader to sample 230 | each dataset by ratios during training. 231 | 232 | Currently do not support multiple datasets for validation and test. 233 | 234 | Returns: 235 | dict: {split_name: (tuples of) dataloader} 236 | """ 237 | if self._dataloaders is None: 238 | # reoganize datasets by split and concatenate/chain if necessary 239 | dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None) 240 | 241 | if dataset_ratios is None: 242 | # concatenate map-style datasets and chain wds.DataPipe datasets separately 243 | # training set becomes a tuple (ConcatDataset, ChainDataset), both are 244 | # optional but at least one of them is required. The resultant ConcatDataset 245 | # and ChainDataset will be sampled evenly. 246 | logging.info( 247 | "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." 248 | ) 249 | 250 | datasets = reorg_datasets_by_split(self.datasets) 251 | self.datasets = concat_datasets(datasets) 252 | else: 253 | # create multi-loader with the provided ratios, without concatenating or chaining 254 | missing_keys = [k for k in dataset_ratios if k not in self.datasets] 255 | if len(missing_keys) > 0: 256 | raise ValueError( 257 | "Datasets with the following split names are not found: {}".format( 258 | missing_keys 259 | ) 260 | ) 261 | 262 | unexpected_keys = [k for k in self.datasets if k not in dataset_ratios] 263 | if len(unexpected_keys) > 0: 264 | raise ValueError( 265 | "Datasets with the following split names are not expected: {}".format( 266 | unexpected_keys 267 | ) 268 | ) 269 | 270 | dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets] 271 | self.datasets = reorg_datasets_by_split(self.datasets) 272 | # to keep the same structure as return value of concat_datasets 273 | self.datasets = { 274 | k: v[0] if len(v) == 1 else v for k, v in self.datasets.items() 275 | } 276 | 277 | # print dataset statistics after concatenation/chaining 278 | for split_name in self.datasets: 279 | if isinstance(self.datasets[split_name], tuple) or isinstance( 280 | self.datasets[split_name], list 281 | ): 282 | # mixed wds.DataPipeline and torch.utils.data.Dataset 283 | num_records = sum( 284 | [ 285 | len(d) 286 | if not type(d) in [wds.DataPipeline, ChainDataset] 287 | else 0 288 | for d in self.datasets[split_name] 289 | ] 290 | ) 291 | 292 | else: 293 | try: 294 | # a single map-style dataset 295 | num_records = len(self.datasets[split_name]) 296 | except TypeError: 297 | # a single wds.DataPipeline or ChainDataset 298 | num_records = -1 299 | logging.info( 300 | "Only a single wds.DataPipeline dataset, no __len__ attribute." 
301 | ) 302 | 303 | if num_records >= 0: 304 | logging.info( 305 | "Loaded {} records for {} split from the dataset.".format( 306 | num_records, split_name 307 | ) 308 | ) 309 | 310 | # create dataloaders 311 | split_names = sorted(self.datasets.keys()) 312 | 313 | datasets = [self.datasets[split] for split in split_names] 314 | is_trains = [split in self.train_splits for split in split_names] 315 | 316 | batch_sizes = [ 317 | self.config.run_cfg.batch_size_train 318 | if split == "train" 319 | else self.config.run_cfg.batch_size_eval 320 | for split in split_names 321 | ] 322 | 323 | collate_fns = [] 324 | for dataset in datasets: 325 | if isinstance(dataset, tuple) or isinstance(dataset, list): 326 | collate_fns.append([getattr(d, "collater", None) for d in dataset]) 327 | else: 328 | collate_fns.append(getattr(dataset, "collater", None)) 329 | 330 | dataloaders = self.create_loaders( 331 | datasets=datasets, 332 | num_workers=self.config.run_cfg.num_workers, 333 | batch_sizes=batch_sizes, 334 | is_trains=is_trains, 335 | collate_fns=collate_fns, 336 | dataset_ratios=dataset_ratios, 337 | ) 338 | 339 | self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} 340 | 341 | return self._dataloaders 342 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | from util.misc import NestedTensor 2 | from torch import nn 3 | import torch 4 | import math 5 | 6 | class PositionEmbeddingSine(nn.Module): 7 | """ 8 | This is a more standard version of the position embedding, very similar to the one 9 | used by the Attention is all you need paper, generalized to work on images. 10 | """ 11 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 12 | super().__init__() 13 | self.num_pos_feats = num_pos_feats 14 | self.temperature = temperature 15 | self.normalize = normalize 16 | if scale is not None and normalize is False: 17 | raise ValueError("normalize should be True if scale is passed") 18 | if scale is None: 19 | scale = 2 * math.pi 20 | self.scale = scale 21 | 22 | def forward(self, tensor_list: NestedTensor): 23 | x = tensor_list.tensors 24 | mask = tensor_list.mask 25 | assert mask is not None 26 | not_mask = ~mask 27 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 28 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 29 | if self.normalize: 30 | eps = 1e-6 31 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 32 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 33 | 34 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 35 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 36 | 37 | pos_x = x_embed[:, :, :, None] / dim_t 38 | pos_y = y_embed[:, :, :, None] / dim_t 39 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 40 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 41 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 42 | return pos 43 | 44 | class PositionEmbeddingLearned(nn.Module): 45 | """ 46 | Absolute pos embedding, learned. 
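Returns a (batch, 2 * num_pos_feats, H, W) tensor built from separate learned row and column embeddings (feature maps up to 50 x 50 are supported).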
47 | """ 48 | def __init__(self, num_pos_feats=256): 49 | super().__init__() 50 | self.row_embed = nn.Embedding(50, num_pos_feats) 51 | self.col_embed = nn.Embedding(50, num_pos_feats) 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | nn.init.uniform_(self.row_embed.weight) 56 | nn.init.uniform_(self.col_embed.weight) 57 | 58 | def forward(self, tensor_list: NestedTensor): 59 | if torch.is_tensor(tensor_list): 60 | x = tensor_list 61 | else: 62 | x = tensor_list.tensors 63 | h, w = x.shape[-2:] 64 | i = torch.arange(w, device=x.device) 65 | j = torch.arange(h, device=x.device) 66 | x_emb = self.col_embed(i) 67 | y_emb = self.row_embed(j) 68 | pos = torch.cat([ 69 | x_emb.unsqueeze(0).repeat(h, 1, 1), 70 | y_emb.unsqueeze(1).repeat(1, w, 1), 71 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 72 | return pos 73 | 74 | def build_position_encoding(args): 75 | N_steps = args.hidden_dim // 2 76 | if args.position_embedding in ('v2', 'sine'): 77 | # TODO find a better way of exposing other arguments 78 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 79 | elif args.position_embedding in ('v3', 'learned'): 80 | position_embedding = PositionEmbeddingLearned(N_steps) 81 | else: 82 | raise ValueError(f"not supported {args.position_embedding}") 83 | 84 | return position_embedding 85 | 86 | class Positionalencoding1d(nn.Module): 87 | def __init__(self, d_model, length): 88 | super().__init__() 89 | self.d_model = d_model 90 | self.length = length 91 | 92 | def __call__(self): 93 | if self.d_model % 2 != 0: 94 | raise ValueError("Cannot use sin/cos positional encoding with " 95 | "odd dim (got dim={:d})".format(self.d_model)) 96 | pe = torch.zeros(self.length, self.d_model) 97 | position = torch.arange(0, self.length).unsqueeze(1) 98 | div_term = torch.exp((torch.arange(0, self.d_model, 2, dtype=torch.float) * 99 | -(math.log(10000.0) / self.d_model))) 100 | pe[:, 0::2] = torch.sin(position.float() * div_term) 101 | pe[:, 1::2] = torch.cos(position.float() * div_term) 102 | return pe 103 | 104 | class LearnablePositionalEncoding1D(nn.Module): 105 | def __init__(self, input_size, max_len=1000): 106 | super(LearnablePositionalEncoding1D, self).__init__() 107 | 108 | # Embedding layer for positional encoding 109 | self.positional_encoding = nn.Embedding(max_len, input_size) 110 | 111 | def forward(self, x): 112 | """ 113 | Args: 114 | x: Input tensor of shape (batch_size, seq_len, input_size) 115 | Returns: 116 | Positionally encoded tensor of the same shape as input 117 | """ 118 | batch_size, seq_len, _ = x.size() 119 | 120 | # Generate positional indices 121 | positions = torch.arange(0, seq_len).unsqueeze(0).repeat(batch_size, 1).to(x.device) 122 | 123 | # Get positional embeddings 124 | positional_embeddings = self.positional_encoding(positions) 125 | 126 | return positional_embeddings -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn, Tensor 3 | from typing import Optional 4 | import copy 5 | import torch 6 | 7 | class Transformer(nn.Module): 8 | 9 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=0, 10 | num_decoder_layers=3, dim_feedforward=2048, dropout=0.1, 11 | activation="relu", normalize_before=False, 12 | return_intermediate_dec=False, args=None): 13 | super().__init__() 14 | 15 | decoder_layer = 
TransformerDecoderLayer(d_model, nhead, dim_feedforward, 16 | dropout, activation, normalize_before) 17 | 18 | decoder_norm = nn.LayerNorm(d_model) 19 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec) 20 | self._reset_parameters() 21 | self.d_model = d_model 22 | self.nhead = nhead 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if p.dim() > 1: 27 | nn.init.xavier_uniform_(p) 28 | 29 | def forward(self, src, query_embed, pos_embed, init_query=None, mask=None, tgt_mask=None): 30 | bs, c, h, w = src.shape 31 | src = src.flatten(2).permute(2, 0, 1) 32 | query_embed = query_embed.permute(1,0,2) 33 | tgt = torch.zeros_like(query_embed) 34 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 35 | hs = self.decoder(tgt, src, tgt_mask=tgt_mask, memory_key_padding_mask=mask, 36 | pos=pos_embed, query_pos=query_embed) 37 | return hs.transpose(1, 2), src.permute(1, 2, 0).view(bs, c, h, w) 38 | 39 | 40 | 41 | class TransformerEncoder(nn.Module): 42 | 43 | def __init__(self, encoder_layer, num_layers, norm=None): 44 | super().__init__() 45 | self.layers = _get_clones(encoder_layer, num_layers) 46 | self.num_layers = num_layers 47 | self.norm = norm 48 | 49 | def forward(self, src, 50 | mask: Optional[Tensor] = None, 51 | src_key_padding_mask: Optional[Tensor] = None, 52 | pos: Optional[Tensor] = None): 53 | output = src 54 | 55 | for layer in self.layers: 56 | output = layer(output, src_mask=mask, 57 | src_key_padding_mask=src_key_padding_mask, pos=pos) 58 | 59 | if self.norm is not None: 60 | output = self.norm(output) 61 | 62 | return output 63 | 64 | 65 | class TransformerDecoder(nn.Module): 66 | 67 | def __init__(self, decoder_layer, num_layers, norm=None, decoder_last_norm=None, return_intermediate=False): 68 | super().__init__() 69 | self.layers = _get_clones(decoder_layer, num_layers) 70 | self.num_layers = num_layers 71 | self.norm = norm 72 | self.return_intermediate = return_intermediate 73 | 74 | def forward(self, tgt, memory, 75 | tgt_mask: Optional[Tensor] = None, 76 | memory_mask: Optional[Tensor] = None, 77 | tgt_key_padding_mask: Optional[Tensor] = None, 78 | memory_key_padding_mask: Optional[Tensor] = None, 79 | pos: Optional[Tensor] = None, 80 | query_pos: Optional[Tensor] = None): 81 | output = tgt 82 | intermediate = [] 83 | for layer in self.layers: 84 | output = layer(output, memory, tgt_mask=tgt_mask, 85 | memory_mask=memory_mask, 86 | tgt_key_padding_mask=tgt_key_padding_mask, 87 | memory_key_padding_mask=memory_key_padding_mask, 88 | pos=pos, query_pos=query_pos) 89 | 90 | if self.return_intermediate: 91 | intermediate.append(self.norm(output)) 92 | 93 | if self.norm is not None: 94 | output = self.norm(output) 95 | if self.return_intermediate: 96 | intermediate.pop() 97 | intermediate.append(output) 98 | 99 | if self.return_intermediate: 100 | return torch.stack(intermediate) 101 | 102 | return output.unsqueeze(0) 103 | 104 | 105 | class TransformerEncoderLayer(nn.Module): 106 | 107 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 108 | activation="relu", normalize_before=False): 109 | super().__init__() 110 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 111 | 112 | self.linear1 = nn.Linear(d_model, dim_feedforward) 113 | self.dropout = nn.Dropout(dropout) 114 | self.linear2 = nn.Linear(dim_feedforward, d_model) 115 | 116 | self.norm1 = nn.LayerNorm(d_model) 117 | self.norm2 = nn.LayerNorm(d_model) 118 | self.dropout1 = 
nn.Dropout(dropout) 119 | self.dropout2 = nn.Dropout(dropout) 120 | 121 | self.activation = _get_activation_fn(activation) 122 | self.normalize_before = normalize_before 123 | 124 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 125 | return tensor if pos is None else tensor + pos 126 | 127 | def forward_post(self, 128 | src, 129 | src_mask: Optional[Tensor] = None, 130 | src_key_padding_mask: Optional[Tensor] = None, 131 | pos: Optional[Tensor] = None): 132 | q = k = self.with_pos_embed(src, pos) 133 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 134 | key_padding_mask=src_key_padding_mask)[0] 135 | src = src + self.dropout1(src2) 136 | src = self.norm1(src) 137 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 138 | src = src + self.dropout2(src2) 139 | src = self.norm2(src) 140 | return src 141 | 142 | def forward_pre(self, src, 143 | src_mask: Optional[Tensor] = None, 144 | src_key_padding_mask: Optional[Tensor] = None, 145 | pos: Optional[Tensor] = None): 146 | src2 = self.norm1(src) 147 | q = k = self.with_pos_embed(src2, pos) 148 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 149 | key_padding_mask=src_key_padding_mask)[0] 150 | src = src + self.dropout1(src2) 151 | src2 = self.norm2(src) 152 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 153 | src = src + self.dropout2(src2) 154 | return src 155 | 156 | def forward(self, src, 157 | src_mask: Optional[Tensor] = None, 158 | src_key_padding_mask: Optional[Tensor] = None, 159 | pos: Optional[Tensor] = None): 160 | 161 | if self.normalize_before: 162 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 163 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 164 | 165 | 166 | 167 | 168 | class TransformerDecoderLayer(nn.Module): 169 | 170 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 171 | activation="relu", normalize_before=False): 172 | super().__init__() 173 | 174 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 175 | # Implementation of Feedforward model 176 | self.linear1 = nn.Linear(d_model, dim_feedforward) 177 | self.dropout = nn.Dropout(dropout) 178 | self.linear2 = nn.Linear(dim_feedforward, d_model) 179 | 180 | self.norm1 = nn.LayerNorm(d_model) 181 | self.norm2 = nn.LayerNorm(d_model) 182 | self.norm3 = nn.LayerNorm(d_model) 183 | 184 | self.dropout1 = nn.Dropout(dropout) 185 | self.dropout2 = nn.Dropout(dropout) 186 | self.dropout3 = nn.Dropout(dropout) 187 | 188 | self.activation = _get_activation_fn(activation) 189 | self.normalize_before = normalize_before 190 | 191 | 192 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 193 | return tensor if pos is None else tensor + pos 194 | 195 | def forward_post(self, tgt, memory, 196 | tgt_mask: Optional[Tensor] = None, 197 | memory_mask: Optional[Tensor] = None, 198 | tgt_key_padding_mask: Optional[Tensor] = None, 199 | memory_key_padding_mask: Optional[Tensor] = None, 200 | pos: Optional[Tensor] = None, 201 | query_pos: Optional[Tensor] = None): 202 | 203 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 204 | key=self.with_pos_embed(memory, pos), 205 | value=memory, attn_mask=memory_mask, 206 | key_padding_mask=memory_key_padding_mask)[0] 207 | tgt = tgt + self.dropout2(tgt2) 208 | tgt = self.norm2(tgt) 209 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 210 | tgt = tgt + self.dropout3(tgt2) 211 | tgt = self.norm3(tgt) 212 | return tgt 213 | 214 | def 
forward_pre(self, tgt, memory, 215 | tgt_mask: Optional[Tensor] = None, 216 | memory_mask: Optional[Tensor] = None, 217 | tgt_key_padding_mask: Optional[Tensor] = None, 218 | memory_key_padding_mask: Optional[Tensor] = None, 219 | pos: Optional[Tensor] = None, 220 | query_pos: Optional[Tensor] = None): 221 | tgt2 = self.norm1(tgt) 222 | q = k = self.with_pos_embed(tgt2, query_pos) 223 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 224 | key_padding_mask=tgt_key_padding_mask)[0] 225 | tgt = tgt + self.dropout1(tgt2) 226 | tgt2 = self.norm2(tgt) 227 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 228 | key=self.with_pos_embed(memory, pos), 229 | value=memory, attn_mask=memory_mask, 230 | key_padding_mask=memory_key_padding_mask)[0] 231 | tgt = tgt + self.dropout2(tgt2) 232 | tgt2 = self.norm3(tgt) 233 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 234 | tgt = tgt + self.dropout3(tgt2) 235 | return tgt 236 | 237 | def forward(self, tgt, memory, 238 | tgt_mask: Optional[Tensor] = None, 239 | memory_mask: Optional[Tensor] = None, 240 | tgt_key_padding_mask: Optional[Tensor] = None, 241 | memory_key_padding_mask: Optional[Tensor] = None, 242 | pos: Optional[Tensor] = None, 243 | query_pos: Optional[Tensor] = None): 244 | if self.normalize_before: 245 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 246 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 247 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 248 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 249 | 250 | 251 | def _get_clones(module, N): 252 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 253 | 254 | 255 | def build_transformer(args): 256 | return Transformer( 257 | d_model=args.hidden_dim, 258 | dropout=args.dropout, 259 | nhead=args.nheads, 260 | dim_feedforward=args.dim_feedforward, 261 | num_encoder_layers=args.enc_layers, 262 | num_decoder_layers=args.dec_layers, 263 | normalize_before=args.pre_norm, 264 | return_intermediate_dec=True, 265 | args=args 266 | ) 267 | 268 | 269 | def _get_activation_fn(activation): 270 | """Return an activation function given a string""" 271 | if activation == "relu": 272 | return F.relu 273 | if activation == "gelu": 274 | return F.gelu 275 | if activation == "glu": 276 | return F.glu 277 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 278 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | salesforce-lavis 2 | scikit-learn 3 | numpy==1.24.4 4 | -------------------------------------------------------------------------------- /tools/launch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------------------------------- 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # -------------------------------------------------------------------------------------------------------------------------- 6 | # Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py 7 | # -------------------------------------------------------------------------------------------------------------------------- 8 | 9 | r""" 10 | `torch.distributed.launch` is a module that spawns up multiple distributed 11 | training processes on each of the training nodes. 12 | The utility can be used for single-node distributed training, in which one or 13 | more processes per node will be spawned. The utility can be used for either 14 | CPU training or GPU training. If the utility is used for GPU training, 15 | each distributed process will be operating on a single GPU. This can achieve 16 | well-improved single-node training performance. It can also be used in 17 | multi-node distributed training, by spawning up multiple processes on each node 18 | for well-improved multi-node distributed training performance as well. 19 | This will especially be beneficial for systems with multiple Infiniband 20 | interfaces that have direct-GPU support, since all of them can be utilized for 21 | aggregated communication bandwidth. 22 | In both cases of single-node distributed training or multi-node distributed 23 | training, this utility will launch the given number of processes per node 24 | (``--nproc_per_node``). If used for GPU training, this number needs to be less 25 | than or equal to the number of GPUs on the current system (``nproc_per_node``), 26 | and each process will be operating on a single GPU from *GPU 0 to 27 | GPU (nproc_per_node - 1)*. 28 | **How to use this module:** 29 | 1. Single-Node multi-process distributed training 30 | :: 31 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 32 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other 33 | arguments of your training script) 34 | 2. Multi-Node multi-process distributed training: (e.g. two nodes) 35 | Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* 36 | :: 37 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 38 | --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" 39 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 40 | and all other arguments of your training script) 41 | Node 2: 42 | :: 43 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 44 | --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" 45 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 46 | and all other arguments of your training script) 47 | 3. To look up what optional arguments this module offers: 48 | :: 49 | >>> python -m torch.distributed.launch --help 50 | **Important Notices:** 51 | 1. This utility and multi-process distributed (single-node or 52 | multi-node) GPU training currently only achieves the best performance using 53 | the NCCL distributed backend. Thus NCCL backend is the recommended backend to 54 | use for GPU training. 55 | 2. In your training program, you must parse the command-line argument: 56 | ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module. 57 | If your training program uses GPUs, you should ensure that your code only 58 | runs on the GPU device of LOCAL_PROCESS_RANK.
This can be done by: 59 | Parsing the local_rank argument 60 | :: 61 | >>> import argparse 62 | >>> parser = argparse.ArgumentParser() 63 | >>> parser.add_argument("--local_rank", type=int) 64 | >>> args = parser.parse_args() 65 | Set your device to local rank using either 66 | :: 67 | >>> torch.cuda.set_device(arg.local_rank) # before your code runs 68 | or 69 | :: 70 | >>> with torch.cuda.device(arg.local_rank): 71 | >>> # your code to run 72 | 3. In your training program, you are supposed to call the following function 73 | at the beginning to start the distributed backend. You need to make sure that 74 | the init_method uses ``env://``, which is the only supported ``init_method`` 75 | by this module. 76 | :: 77 | torch.distributed.init_process_group(backend='YOUR BACKEND', 78 | init_method='env://') 79 | 4. In your training program, you can either use regular distributed functions 80 | or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your 81 | training program uses GPUs for training and you would like to use 82 | :func:`torch.nn.parallel.DistributedDataParallel` module, 83 | here is how to configure it. 84 | :: 85 | model = torch.nn.parallel.DistributedDataParallel(model, 86 | device_ids=[arg.local_rank], 87 | output_device=arg.local_rank) 88 | Please ensure that ``device_ids`` argument is set to be the only GPU device id 89 | that your code will be operating on. This is generally the local rank of the 90 | process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, 91 | and ``output_device`` needs to be ``args.local_rank`` in order to use this 92 | utility 93 | 5. Another way to pass ``local_rank`` to the subprocesses via environment variable 94 | ``LOCAL_RANK``. This behavior is enabled when you launch the script with 95 | ``--use_env=True``. You must adjust the subprocess example above to replace 96 | ``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher 97 | will not pass ``--local_rank`` when you specify this flag. 98 | .. warning:: 99 | ``local_rank`` is NOT globally unique: it is only unique per process 100 | on a machine. Thus, don't use it to decide if you should, e.g., 101 | write to a networked filesystem. See 102 | https://github.com/pytorch/pytorch/issues/12042 for an example of 103 | how things can go wrong if you don't do this correctly. 
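As a concrete illustration of the environment this launcher exports: running with ``--nnodes=2 --node_rank=1 --nproc_per_node=4`` gives every subprocess ``WORLD_SIZE=8``, ``LOCAL_RANK`` values 0-3 on that node, and global ``RANK`` values 4-7 (computed as ``nproc_per_node * node_rank + local_rank``).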
104 | """ 105 | 106 | 107 | import sys 108 | import subprocess 109 | import os 110 | import socket 111 | from argparse import ArgumentParser, REMAINDER 112 | 113 | import torch 114 | 115 | 116 | def parse_args(): 117 | """ 118 | Helper function parsing the command line options 119 | @retval ArgumentParser 120 | """ 121 | parser = ArgumentParser(description="PyTorch distributed training launch " 122 | "helper utility that will spawn up " 123 | "multiple distributed processes") 124 | 125 | # Optional arguments for the launch helper 126 | parser.add_argument("--nnodes", type=int, default=1, 127 | help="The number of nodes to use for distributed " 128 | "training") 129 | parser.add_argument("--node_rank", type=int, default=0, 130 | help="The rank of the node for multi-node distributed " 131 | "training") 132 | parser.add_argument("--nproc_per_node", type=int, default=1, 133 | help="The number of processes to launch on each node, " 134 | "for GPU training, this is recommended to be set " 135 | "to the number of GPUs in your system so that " 136 | "each process can be bound to a single GPU.") 137 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 138 | help="Master node (rank 0)'s address, should be either " 139 | "the IP address or the hostname of node 0, for " 140 | "single node multi-proc training, the " 141 | "--master_addr can simply be 127.0.0.1") 142 | parser.add_argument("--master_port", default=29500, type=int, 143 | help="Master node (rank 0)'s free port that needs to " 144 | "be used for communication during distributed " 145 | "training") 146 | 147 | # positional 148 | parser.add_argument("training_script", type=str, 149 | help="The full path to the single GPU training " 150 | "program/script to be launched in parallel, " 151 | "followed by all the arguments for the " 152 | "training script") 153 | 154 | # rest from the training program 155 | parser.add_argument('training_script_args', nargs=REMAINDER) 156 | return parser.parse_args() 157 | 158 | 159 | def main(): 160 | args = parse_args() 161 | 162 | # world size in terms of number of processes 163 | dist_world_size = args.nproc_per_node * args.nnodes 164 | 165 | # set PyTorch distributed related environmental variables 166 | current_env = os.environ.copy() 167 | current_env["MASTER_ADDR"] = args.master_addr 168 | current_env["MASTER_PORT"] = str(args.master_port) 169 | current_env["WORLD_SIZE"] = str(dist_world_size) 170 | 171 | processes = [] 172 | 173 | for local_rank in range(0, args.nproc_per_node): 174 | # each process's rank 175 | dist_rank = args.nproc_per_node * args.node_rank + local_rank 176 | current_env["RANK"] = str(dist_rank) 177 | current_env["LOCAL_RANK"] = str(local_rank) 178 | 179 | cmd = [args.training_script] + args.training_script_args 180 | #import pdb; pdb.set_trace() 181 | 182 | process = subprocess.Popen(cmd, env=current_env) 183 | 184 | processes.append(process) 185 | for process in processes: 186 | process.wait() 187 | if process.returncode != 0: 188 | raise subprocess.CalledProcessError(returncode=process.returncode, 189 | cmd=process.args) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() -------------------------------------------------------------------------------- /tools/run_dist_launch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPUS=$1 3 | RUN_COMMAND=${@:2} 4 | if [ $GPUS -lt 8 ]; then 5 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 6 | else 7 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 8 | fi 9 |
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | # Generate a random port number between 20000 and 52767 (bash RANDOM is 0-32767) 12 | MASTER_PORT=$((20000 + RANDOM % 40000)) 13 | NODE_RANK=${NODE_RANK:-0} 14 | 15 | let "NNODES=GPUS/GPUS_PER_NODE" 16 | 17 | python ./tools/launch.py \ 18 | --nnodes ${NNODES} \ 19 | --node_rank ${NODE_RANK} \ 20 | --master_addr ${MASTER_ADDR} \ 21 | --master_port ${MASTER_PORT} \ 22 | --nproc_per_node ${GPUS_PER_NODE} \ 23 | ${RUN_COMMAND} 24 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/util/__init__.py --------------------------------------------------------------------------------
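
The decoder in `models/transformer.py` keeps the DETR interface but is decoder-only (`num_encoder_layers` defaults to 0) and, unlike the stock DETR decoder layer, its post-norm path applies only cross-attention from the queries to the image memory before the feed-forward block. The minimal sketch below wires it to the learned positional encoding from `models/position_encoding.py`, using random tensors in place of real backbone features and attribute queries; the shapes and hyperparameters are illustrative only (not the repository's training configuration), and it assumes the repository root is on `PYTHONPATH`.

```python
import torch

from models.position_encoding import PositionEmbeddingLearned
from models.transformer import Transformer

# Illustrative sizes only: batch, channels, feature-map height/width, number of queries.
bs, c, h, w, num_queries = 2, 256, 7, 7, 20

# Learned 2D positional encoding: 128-dim row + 128-dim column embeddings -> c channels.
pos_enc = PositionEmbeddingLearned(num_pos_feats=c // 2)

# Decoder-only transformer that returns the output of every decoder layer.
model = Transformer(d_model=c, nhead=8, num_decoder_layers=3, return_intermediate_dec=True)

src = torch.randn(bs, c, h, w)                 # stand-in for backbone features
query_embed = torch.randn(bs, num_queries, c)  # stand-in for (super-)class queries
pos_embed = pos_enc(src)                       # (bs, c, h, w), same spatial size as src

hs, memory = model(src, query_embed, pos_embed)
print(hs.shape)      # torch.Size([3, 2, 20, 256]): (decoder layers, batch, queries, d_model)
print(memory.shape)  # torch.Size([2, 256, 7, 7])
```

With `return_intermediate_dec=True` the leading dimension of `hs` indexes decoder layers, the usual hook for DETR-style auxiliary supervision on intermediate outputs.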