├── README.md
├── assets
│   ├── .gitkeep
│   └── SugaFormer.png
├── configs
│   ├── .gitkeep
│   └── vaw
│       ├── eval_fs.sh
│       ├── eval_zs.sh
│       ├── train_fs.sh
│       └── train_zs.sh
├── data
│   └── .gitkeep
├── datasets
│   ├── __init__.py
│   ├── transforms.py
│   ├── vaw.py
│   └── vaw_eval.py
├── engine.py
├── main.py
├── models
│   ├── __init__.py
│   ├── backbone.py
│   ├── lavis
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── config.py
│   │   │   ├── dist_utils.py
│   │   │   ├── gradcam.py
│   │   │   ├── logger.py
│   │   │   ├── optims.py
│   │   │   ├── registry.py
│   │   │   └── utils.py
│   │   ├── configs
│   │   │   └── default.yaml
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── blip2_models
│   │   │   │   ├── Qformer.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blip2.py
│   │   │   │   └── blip2_qformer.py
│   │   │   ├── blip_models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blip.py
│   │   │   │   └── blip_outputs.py
│   │   │   ├── clip_vit.py
│   │   │   ├── eva_vit.py
│   │   │   ├── med.py
│   │   │   └── vit.py
│   │   ├── processors
│   │   │   ├── __init__.py
│   │   │   ├── base_processor.py
│   │   │   ├── blip_processors.py
│   │   │   ├── clip_processors.py
│   │   │   └── randaugment.py
│   │   └── runners
│   │       ├── __init__.py
│   │       ├── runner_base.py
│   │       └── runner_iter.py
│   ├── position_encoding.py
│   ├── sugaformer.py
│   └── transformer.py
├── requirements.txt
├── tools
│   ├── launch.py
│   └── run_dist_launch.sh
└── util
    ├── __init__.py
    └── misc.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Super-class guided Transformer for Zero-Shot Attribute Classification
3 |
4 | Sehyung Kim*, Chanhyeong Yang*, Jihwan Park, Taehoon Song, Hyunwoo J. Kim†
5 | AAAI 2025
6 |
7 |
12 |
13 |
14 | ----
15 |
16 | 
17 |
18 | This is the official implementation of the AAAI 2025 paper "Super-class guided Transformer for Zero-Shot Attribute Classification".
19 |
20 | ----
21 |
22 | ## Environment Setting
23 | ```bash
24 | git clone https://github.com/mlvlab/SugaFormer.git
25 | cd SugaFormer
26 | conda create -n sugaformer python==3.9
27 | conda activate sugaformer
28 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
29 | pip install -r requirements.txt
30 | ```
31 |
32 | ----
33 |
34 | ## Dataset Preparation
35 |
36 | To run experiments for VAW, you need both the images from the Visual Genome dataset and the annotation files. Follow the steps below:
37 |
38 | 1. Download the Visual Genome images from the [link](https://homes.cs.washington.edu/~ranjay/visualgenome/index.html).
39 | 2. Download the annotation files for VAW experiments from the [link](https://drive.google.com/drive/folders/1qW3HkMcdLHnsCDXn00TyFD4rAlErZRRW?usp=drive_link).
40 |
41 |
42 | ### Organize the Data
43 | After downloading the Visual Genome images and annotation files, organize them into the following directory structure:
44 |
45 | ```bash
46 |
47 | data/
48 | └── vaw/
49 |     ├── images/
50 |     │   ├── VG_100K/
51 |     │   └── VG_100K_2/
52 |     │
53 |     └── annotations/
54 |         ├── train.json
55 |         ├── test.json
56 |         └── ...
57 |
58 | ```
59 |
60 | ## Training
61 |
62 | ### VAW Fully-Supervised
63 | Train the model in the fully-supervised setting:
64 | ```bash
65 | ./configs/vaw/train_fs.sh
66 | ```
67 |
68 | ### VAW Zero-Shot (base2novel)
69 | Train the model in the zero-shot setting:
70 | ```bash
71 | ./configs/vaw/train_zs.sh
72 | ```
73 | ## Evaluation
74 |
75 | ### VAW Fully-Supervised
76 | Evaluate the model in the fully-supervised setting:
77 | ```bash
78 | ./configs/vaw/eval_fs.sh
79 | ```
80 | ### VAW Zero-Shot (base2novel)
81 | Evaluate the model in the zero-shot setting:
82 | ```bash
83 | ./configs/vaw/eval_zs.sh
84 | ```
85 |
86 | ## Acknowledgements
87 | This repository is built upon the following works:
88 |
89 | * [DETR (Facebook Research)](https://github.com/facebookresearch/detr): the codebase our implementation builds upon and the foundation of our base model.
90 |
91 | * [LAVIS (Salesforce)](https://github.com/salesforce/LAVIS): pre-trained vision-language models (BLIP-2) that we use for feature extraction and knowledge transfer.
92 |
93 | ## Contact
94 | If you have any questions, please create an issue on this repository or contact shkim129@korea.ac.kr.
95 |
96 | ## Citation
97 | If you find our work interesting, please consider giving a ⭐ and citing our paper.
98 | ```bibtex
99 | @article{kim2025super,
100 | title={Super-class guided Transformer for Zero-Shot Attribute Classification},
101 | author={Kim, Sehyung and Yang, Chanhyeong and Park, Jihwan and Song, Taehoon and Kim, Hyunwoo J},
102 | journal={arXiv preprint arXiv:2501.05728},
103 | year={2025}
104 | }
105 | ```
106 |
--------------------------------------------------------------------------------
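Before launching the scripts above, it can help to confirm that the `data/vaw` layout matches what `main.py` expects. A minimal sketch, assuming the default `data/vaw` root and the annotation filenames referenced in `main.py`; adjust the list if your download differs:

```python
# Sanity-check the expected data/vaw layout (sketch; paths follow the README and main.py defaults).
from pathlib import Path

root = Path("data/vaw")
expected = [
    root / "images" / "VG_100K",
    root / "images" / "VG_100K_2",
    root / "annotations" / "train.json",
    root / "annotations" / "test.json",
    root / "annotations" / "attribute_index.json",
    root / "annotations" / "head_tail.json",
]
missing = [str(p) for p in expected if not p.exists()]
print("All expected files found." if not missing else f"Missing: {missing}")
```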
/assets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/assets/.gitkeep
--------------------------------------------------------------------------------
/assets/SugaFormer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/assets/SugaFormer.png
--------------------------------------------------------------------------------
/configs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/configs/.gitkeep
--------------------------------------------------------------------------------
/configs/vaw/eval_fs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=4,5,6,7
4 | export GPUS_PER_NODE=4
5 |
6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \
7 | python -u main.py \
8 | --dataset_file vaw \
9 | --mode supervised \
10 | --eval \
11 | --zrse \
12 | --dec_layers 3 \
13 | --batch_size 4 \
14 | --pretrained exps/vaw/supervised/checkpoint.pth \
15 | --output_dir exps/vaw/supervised/
16 |
17 |
--------------------------------------------------------------------------------
/configs/vaw/eval_zs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 | export GPUS_PER_NODE=4
5 |
6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \
7 | python -u main.py \
8 | --dataset_file vaw \
9 | --mode zero_shot \
10 | --eval \
11 | --zrse \
12 | --dec_layers 3 \
13 | --batch_size 4 \
14 | --pretrained exps/vaw/zero_shot/checkpoint.pth \
15 | --output_dir exps/vaw/zero_shot/
16 |
--------------------------------------------------------------------------------
/configs/vaw/train_fs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=4,5,6,7
4 | export GPUS_PER_NODE=4
5 |
6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \
7 | python -u main.py \
8 | --dataset_file vaw \
9 | --mode supervised \
10 | --scr_coef 2 \
11 | --att_loss_coef 1 \
12 | --dec_layers 3 \
13 | --epochs 15 \
14 | --lr_drop 13 \
15 | --batch_size 4 \
16 | --output_dir exps/vaw/supervised/
--------------------------------------------------------------------------------
/configs/vaw/train_zs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 | export GPUS_PER_NODE=4
5 |
6 | ./tools/run_dist_launch.sh $GPUS_PER_NODE \
7 | python -u main.py \
8 | --dataset_file vaw \
9 | --mode zero_shot \
10 | --scr_coef 2 \
11 | --att_loss_coef 1 \
12 | --dec_layers 3 \
13 | --epochs 9 \
14 | --batch_size 4 \
15 | --output_dir exps/vaw/zero_shot/
16 |
17 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/data/.gitkeep
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .vaw import build as build_vaw
2 | def build_dataset(image_set, args):
3 |
4 | if args.dataset_file == 'vaw':
5 | return build_vaw(image_set, args)
6 |
7 | raise ValueError(f'dataset {args.dataset_file} not supported')
8 |
--------------------------------------------------------------------------------
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as F
2 | import torchvision.transforms as T
3 | import numpy as np
4 | import random
5 | import torch
6 | import PIL
7 | import cv2
8 |
9 | def crop(image, target, region):
10 | cropped_image = F.crop(image, *region)
11 |
12 | target = target.copy()
13 | i, j, h, w = region
14 |
15 | target["size"] = torch.tensor([h, w])
16 |
17 | fields = ["labels", "area", "iscrowd"]
18 | if "pos_att_classes" in target.keys():
19 | fields.append("pos_att_classes")
20 | if "neg_att_classes" in target.keys():
21 | fields.append("neg_att_classes")
22 | if "boxes" in target:
23 | boxes = target["boxes"]
24 | max_size = torch.as_tensor([w, h], dtype=torch.float32)
25 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
26 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
27 | cropped_boxes = cropped_boxes.clamp(min=0)
28 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
29 | target["boxes"] = cropped_boxes.reshape(-1, 4)
30 | target["area"] = area
31 | fields.append("boxes")
32 |
33 | if "masks" in target:
34 | # FIXME should we update the area here if there are no boxes?
35 | target['masks'] = target['masks'][:, i:i + h, j:j + w]
36 | fields.append("masks")
37 |
38 |     # remove elements whose boxes or masks have zero area
39 | if "boxes" in target or "masks" in target:
40 | # favor boxes selection when defining which elements to keep
41 | # this is compatible with previous implementation
42 | if "boxes" in target:
43 | cropped_boxes = target['boxes'].reshape(-1, 2, 2)
44 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
45 | else:
46 | keep = target['masks'].flatten(1).any(1)
47 |
48 | for field in fields:
49 | target[field] = target[field][keep]
50 |
51 | return cropped_image, target
52 |
53 |
54 | def hflip(image, target):
55 | flipped_image = F.hflip(image)
56 |
57 | w, h = image.size
58 |
59 | target = target.copy()
60 | if "boxes" in target:
61 | boxes = target["boxes"]
62 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
63 | target["boxes"] = boxes
64 |
65 | if "masks" in target:
66 | target['masks'] = target['masks'].flip(-1)
67 |
68 | return flipped_image, target
69 |
70 |
71 | def resize(image, target, size, max_size=None):
72 | def get_size_with_aspect_ratio(image_size, size, max_size=None):
73 | w, h = image_size
74 | if max_size is not None:
75 | min_original_size = float(min((w, h)))
76 | max_original_size = float(max((w, h)))
77 | if max_original_size / min_original_size * size > max_size:
78 | size = int(round(max_size * min_original_size / max_original_size))
79 |
80 | if (w <= h and w == size) or (h <= w and h == size):
81 | return (h, w)
82 |
83 | if w < h:
84 | ow = size
85 | oh = int(size * h / w)
86 | else:
87 | oh = size
88 | ow = int(size * w / h)
89 |
90 | return (oh, ow)
91 |
92 | def get_size(image_size, size, max_size=None):
93 | if isinstance(size, (list, tuple)):
94 | return size[::-1]
95 | else:
96 | return get_size_with_aspect_ratio(image_size, size, max_size)
97 |
98 | size = get_size(image.size, size, max_size)
99 | rescaled_image = F.resize(image, size)
100 |
101 | if target is None:
102 | return rescaled_image, None
103 |
104 | return rescaled_image, target
105 |
106 |
107 | def pad(image, target, padding):
108 | # assumes that we only pad on the bottom right corners
109 | padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
110 | if target is None:
111 | return padded_image, None
112 | target = target.copy()
113 | # should we do something wrt the original size?
114 |     target["size"] = torch.tensor(padded_image.size[::-1])
115 | if "masks" in target:
116 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
117 | return padded_image, target
118 |
119 |
120 | class RandomCrop(object):
121 | def __init__(self, size):
122 | self.size = size
123 |
124 | def __call__(self, img, target):
125 | region = T.RandomCrop.get_params(img, self.size)
126 | return crop(img, target, region)
127 |
128 |
129 | class RandomSizeCrop(object):
130 | def __init__(self, min_size: int, max_size: int):
131 | self.min_size = min_size
132 | self.max_size = max_size
133 |
134 | def __call__(self, img: PIL.Image.Image, target: dict):
135 | w = random.randint(self.min_size, min(img.width, self.max_size))
136 | h = random.randint(self.min_size, min(img.height, self.max_size))
137 | region = T.RandomCrop.get_params(img, [h, w])
138 | return crop(img, target, region)
139 |
140 |
141 | class CenterCrop(object):
142 | def __init__(self, size):
143 | self.size = size
144 |
145 | def __call__(self, img, target):
146 | image_width, image_height = img.size
147 | crop_height, crop_width = self.size
148 | crop_top = int(round((image_height - crop_height) / 2.))
149 | crop_left = int(round((image_width - crop_width) / 2.))
150 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
151 |
152 |
153 | class RandomHorizontalFlip(object):
154 | def __init__(self, p=0.5):
155 | self.p = p
156 |
157 | def __call__(self, img, target):
158 | if random.random() < self.p:
159 | return hflip(img, target)
160 | return img, target
161 |
162 |
163 | class RandomResize(object):
164 | def __init__(self, sizes, max_size=None):
165 | assert isinstance(sizes, (list, tuple))
166 | self.sizes = sizes
167 | self.max_size = max_size
168 |
169 | def __call__(self, img, target=None):
170 | size = random.choice(self.sizes)
171 | return resize(img, target, size, self.max_size)
172 |
173 |
174 | class RandomPad(object):
175 | def __init__(self, max_pad):
176 | self.max_pad = max_pad
177 |
178 | def __call__(self, img, target):
179 | pad_x = random.randint(0, self.max_pad)
180 | pad_y = random.randint(0, self.max_pad)
181 | return pad(img, target, (pad_x, pad_y))
182 |
183 |
184 | class RandomSelect(object):
185 | def __init__(self, transforms1, transforms2, p=0.5):
186 | self.transforms1 = transforms1
187 | self.transforms2 = transforms2
188 | self.p = p
189 |
190 | def __call__(self, img, target):
191 | if random.random() < self.p:
192 | return self.transforms1(img, target)
193 | return self.transforms2(img, target)
194 |
195 |
196 | class ToTensor(object):
197 | def __call__(self, img):
198 | return F.to_tensor(img)
199 |
200 |
201 | class RandomErasing(object):
202 |
203 | def __init__(self, *args, **kwargs):
204 | self.eraser = T.RandomErasing(*args, **kwargs)
205 |
206 | def __call__(self, img, target):
207 | return self.eraser(img), target
208 |
209 |
210 | class Normalize(object):
211 | def __init__(self, mean, std):
212 | self.mean = mean
213 | self.std = std
214 |
215 | def __call__(self, image, target=None):
216 | image = F.normalize(image, mean=self.mean, std=self.std)
217 | if target is None:
218 | return image
219 | return image, target
220 |
221 |
222 | class Compose(object):
223 | def __init__(self, transforms):
224 | self.transforms = transforms
225 |
226 | def __call__(self, image, target=None):
227 | for t in self.transforms:
228 | if isinstance(t, T.Normalize): # Apply only to the image
229 | image = t(image)
230 | else:
231 | image, target = t(image, target)
232 | return image, target
233 |
234 |
235 | class ColorJitter(object):
236 |     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
237 |         self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)
238 |
239 | def __call__(self, img, target):
240 | return self.color_jitter(img), target
241 |
242 |
243 | def polygon_to_mask(polygon, height, width):
244 | mask = np.zeros((height, width), dtype=np.uint8)
245 | polygon = np.array(polygon, dtype=np.int32).reshape((-1, 2))
246 | cv2.fillPoly(mask, [polygon], 1) # Fill the polygon area with 1
247 | return mask
248 |
--------------------------------------------------------------------------------
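The box arithmetic in `hflip` above mirrors the x-coordinates of `(x1, y1, x2, y2)` boxes around the image width. A small self-contained check of that formula on a toy box, reimplementing the expression inline rather than importing the module:

```python
# Toy check of the horizontal-flip box formula used in hflip() above.
import torch

w = 100  # image width
boxes = torch.tensor([[10., 20., 40., 60.]])  # one box in (x1, y1, x2, y2)

# x1' = w - x2, x2' = w - x1; the y-coordinates are unchanged.
flipped = boxes[:, [2, 1, 0, 3]] * torch.tensor([-1., 1., -1., 1.]) + torch.tensor([float(w), 0., float(w), 0.])
print(flipped)  # tensor([[60., 20., 90., 60.]])
```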
/datasets/vaw.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import datasets.transforms as T
3 | from pathlib import Path
4 | import torch.utils.data
5 | from PIL import Image
6 | import numpy as np
7 | import torch
8 | import json
9 |
10 | class VAW_Dataloader(torch.utils.data.Dataset):
11 | def __init__(self, img_set, img_folder, anno_file, attribute_index, transforms, args=None):
12 | self.att_ids = list(self.load_json(attribute_index).values())
13 | self.annotations = self.load_json(anno_file)
14 | self.img_set = img_set
15 | self.img_folder = img_folder
16 | self.transforms = transforms
17 | self.sc_token_path = args.use_scr
18 |
19 | def load_json(self, file_path):
20 | with open(file_path, 'r') as f:
21 | return json.load(f)
22 |
23 | def num_attributes(self):
24 | return len(self.att_ids)
25 |
26 | def __len__(self):
27 | return len(self.annotations)
28 |
29 | def __getitem__(self, idx):
30 | img_anno = self.annotations[idx]
31 | img_id = str(img_anno['image_id'])
32 | object_names = img_anno['object_name']
33 | file_dir = self.get_file_path(img_anno['file_name'])
34 | img = Image.open(self.img_folder / file_dir).convert('RGB')
35 | boxes, orig_size, cropped_masks, keep = self.process_boxes_and_mask(img_anno, img)
36 | pos_att_classes, neg_att_classes = self.create_attribute_classes(img_anno)
37 | pos_att_classes = pos_att_classes[keep]
38 | neg_att_classes = neg_att_classes[keep]
39 | if self.img_set == 'train':
40 | sc_mask_output = self.get_sc_mask_output(img_id)
41 | sc_mask_output = sc_mask_output[keep]
42 | target = self.create_target_dict(img_anno, boxes, pos_att_classes, neg_att_classes, orig_size, object_names, sc_mask_output)
43 | elif self.img_set == 'test':
44 | target = self.create_target_dict(img_anno, boxes, pos_att_classes, neg_att_classes, orig_size, object_names)
45 | transformed_img, target = self.apply_transforms(img, target)
46 | crop_imgs = self.crop_and_normalize_boxes(img, boxes)
47 | return transformed_img, crop_imgs, cropped_masks, target
48 |
49 | def get_file_path(self, file_name):
50 | return file_name.split('/')[-2] + '/' + file_name.split('/')[-1]
51 |
52 | def process_boxes_and_mask(self, img_anno, img):
53 | boxes = torch.as_tensor(img_anno['boxes'], dtype=torch.float32).reshape(-1, 4)
54 | w, h = img.size
55 | orig_size = torch.as_tensor([h, w])
56 | boxes[:, 2:] += boxes[:, :2]
57 | boxes[:, 0::2].clamp_(min=0, max=w)
58 | boxes[:, 1::2].clamp_(min=0, max=h)
59 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
60 | boxes = boxes[keep]
61 | mask_path = img_anno['mask']
62 | masks = torch.from_numpy(np.load(mask_path))
63 | masks = masks[keep]
64 | cropped_masks = self.crop_masks(masks, boxes, orig_size)
65 | return boxes, orig_size, cropped_masks, keep
66 |
67 | def create_attribute_classes(self, img_anno):
68 | pos_att_classes = torch.zeros((len(img_anno['boxes']), self.num_attributes()), dtype=torch.float32)
69 | neg_att_classes = torch.zeros((len(img_anno['boxes']), self.num_attributes()), dtype=torch.float32)
70 | for b, pos_id in zip(pos_att_classes, img_anno['pos_att_id']):
71 | b[pos_id] = 1
72 | for b, neg_id in zip(neg_att_classes, img_anno['neg_att_id']):
73 | b[neg_id] = 1
74 | return pos_att_classes, neg_att_classes
75 |
76 | def get_sc_mask_output(self, img_id):
77 | sc_mask_output = torch.load(self.sc_token_path + img_id + '.pt', map_location='cpu')
78 | return sc_mask_output
79 |
80 | def create_target_dict(self, img_anno, boxes, pos_att_classes, neg_att_classes, orig_size, object_names, sc_mask_output=None):
81 | target = {
82 | 'boxes': boxes,
83 | 'pos_att_classes': pos_att_classes,
84 | 'neg_att_classes': neg_att_classes,
85 | 'img_id': img_anno['image_id'],
86 | 'orig_size': orig_size,
87 | 'obj_names': object_names,
88 | }
89 | if sc_mask_output is not None:
90 | target['sc_token_output'] = sc_mask_output
91 | return target
92 |
93 |     def apply_transforms(self, img, target):
94 |         if self.transforms:
95 |             img, _ = self.transforms(img, target)
96 |         return img, target
97 |
98 | def crop_and_normalize_boxes(self, img, boxes):
99 | crop_imgs = []
100 | for box in boxes:
101 | box = box.int()
102 | cropped_img = img.crop((box[0].item(), box[1].item(), box[2].item(), box[3].item()))
103 | cropped_img = T.ToTensor()(cropped_img)
104 | cropped_img, _ = T.resize(cropped_img, None, size=(224, 224))
105 | cropped_img = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(cropped_img)
106 | crop_imgs.append(cropped_img)
107 | return torch.stack(crop_imgs)
108 |
109 | def crop_masks(self, masks, boxes, orig_size):
110 | resized_masks = []
111 | for mask, box in zip(masks, boxes):
112 | crop_resized_mask = self.mask_crop_resize(mask.unsqueeze(0), orig_size, box)
113 | resized_masks.append(crop_resized_mask)
114 | return torch.cat(resized_masks)
115 |
116 | def mask_crop_resize(self, mask, orig_size, box, img_size=224):
117 | orig_h, orig_w = orig_size[0].item(), orig_size[1].item()
118 | box[0::2].clamp_(min=0, max=orig_w)
119 | box[1::2].clamp_(min=0, max=orig_h)
120 | cropped_mask = mask[:, :orig_h, :orig_w][:, int(box[1]):int(box[3]), int(box[0]):int(box[2])]
121 |         if cropped_mask.any():
122 | resized_mask = F.interpolate(cropped_mask.unsqueeze(0).type(torch.float), size=(img_size, img_size), mode='bicubic', align_corners=False)
123 | resized_mask = (resized_mask > 0.5).float()
124 | else:
125 | resized_mask = torch.ones((1, 1, img_size, img_size))
126 | return resized_mask
127 |
128 | def make_vaw_transforms(image_set, img_size=224):
129 |
130 | def transform(img, target):
131 | img = T.ToTensor()(img)
132 | img, _ = T.resize(img, None, size=(img_size, img_size))
133 | img = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
134 | return img, target
135 |
136 | if image_set in ['train', 'test']:
137 | return transform
138 |
139 | def build(image_set, args):
140 | root = Path(args.vaw_path)
141 | assert root.exists(), f'Provided VAW path {root} does not exist'
142 | PATHS = {
143 | 'train': (root / 'images', root / 'annotations' / 'train.json'),
144 | 'val': (root / 'images', root / 'annotations' / 'test.json'),
145 | 'test': (root / 'images', root / 'annotations' / 'test.json')
146 | }
147 |
148 | attribute_index = root / 'annotations' / 'attribute_index.json'
149 |
150 | img_folder, anno_file = PATHS[image_set]
151 | assert img_folder.exists(), f"Image folder {img_folder} does not exist"
152 | assert anno_file.exists(), f"Annotation file {anno_file} does not exist"
153 |
154 | transforms = make_vaw_transforms(image_set)
155 | dataset = VAW_Dataloader(image_set, img_folder, anno_file, attribute_index, transforms, args)
156 |
157 | return dataset
158 |
--------------------------------------------------------------------------------
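`create_attribute_classes` above turns the per-box `pos_att_id` / `neg_att_id` index lists into multi-hot vectors over the attribute vocabulary. A toy, self-contained illustration of that encoding (the attribute count of 5 is arbitrary for the example):

```python
# Toy illustration of the multi-hot encoding in VAW_Dataloader.create_attribute_classes.
import torch

num_attributes = 5                  # arbitrary toy size; --num_att_classes defaults to 620
pos_att_id = [[0, 3], [2]]          # positive attribute indices for two boxes
neg_att_id = [[1], [0, 4]]          # negative attribute indices for the same boxes

pos = torch.zeros((len(pos_att_id), num_attributes))
neg = torch.zeros((len(neg_att_id), num_attributes))
for row, ids in zip(pos, pos_att_id):
    row[ids] = 1
for row, ids in zip(neg, neg_att_id):
    row[ids] = 1
print(pos)  # [[1., 0., 0., 1., 0.], [0., 0., 1., 0., 0.]]
print(neg)  # [[0., 1., 0., 0., 0.], [1., 0., 0., 0., 1.]]
```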
/datasets/vaw_eval.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import average_precision_score
2 | import numpy as np
3 | import torch
4 | import json
5 |
6 | K = 15
7 | def top_K_values(array):
8 | """Keeps only topK largest values in array.
9 | """
10 | indexes = np.argpartition(array, -K, axis=-1)[-K:]
11 | A = set(indexes)
12 | B = set(list(range(array.shape[0])))
13 | B -= A
14 | array[list(B)]=0
15 | return array
16 |
17 | class Evaluator(object):
18 | def __init__(
19 | self,
20 | fpath_attr2idx,
21 | fpath_attr_headtail,
22 | threshold=0.5,
23 | exclude_atts=[]
24 | ):
25 | """Initializes evaluator for attribute prediction on VAW dataset.
26 | Args:
27 | - fpath_attr2idx: path to attribute class index file.
28 | - fpath_attr_headtail: path to attribute head/mid/tail categorization file.
29 | - threshold: positive/negative threshold (for Accuracy metric).
30 | - exclude_atts: any attribute classes to be excluded from evaluation.
31 | """
32 |
33 | # Read file that maps from id to attribute name.
34 | with open(fpath_attr2idx, 'r') as f:
35 | self.attr2idx = json.load(f)
36 | self.idx2attr = {v:k for k, v in self.attr2idx.items()}
37 |
38 | # Read file that shows whether attribute is head/mid/tail.
39 | with open(fpath_attr_headtail, 'r') as f:
40 | self.attribute_head_tail = json.load(f)
41 |
42 | self.n_class = len(self.attr2idx)
43 | self.exclude_atts = exclude_atts
44 | self.threshold = threshold
45 |
46 | # Cache metric score for each class.
47 | self.score = {}
48 | self.score_topk = {}
49 |
50 | def _clear_cache(self):
51 | self.score = {}
52 | self.score_topk = {}
53 |
54 | def get_attr_head_tail(self, attr):
55 | """Finds whether attribute is in head/medium/tail group.
56 | """
57 | for group, L in self.attribute_head_tail.items():
58 | if attr in L:
59 | return group
60 | assert False, f"Can't find head/medium/tail group for {attr}"
61 |
62 | def evaluate(
63 | self,
64 | pred,
65 | gt_label,
66 | threshold_type='threshold'
67 | ):
68 | """Evaluates a prediction matrix against groundtruth label.
69 | Args:
70 | - pred: prediction matrix [n_instance, n_class].
71 | pred[i,j] is the j-th attribute score of instance i-th.
72 | These scores should be from 0 -> 1.
73 | - gt_label: groundtruth label matrix [n_instances, n_class].
74 | gt_label[i,j] = 1 if instance i is positively labeled with
75 | attribute j, = 0 if it is negatively labeled, and = 2 if
76 | it is unlabeled.
77 | - threshold_type: 'threshold' or 'topk'.
78 | Determines positive vs. negative prediction.
79 | """
80 | self.pred = pred
81 | self.gt_label = gt_label
82 | self.n_instance = self.gt_label.shape[0]
83 |
84 | # For topK metrics, we keep a version of the prediction matrix that sets
85 | # non-topK elements as 0 and topK elements as 1.
86 | P_topk = self.pred.detach().cpu().numpy().copy()
87 | P_topk = np.apply_along_axis(top_K_values, 1, P_topk)
88 | P_topk[P_topk > 0] = 1
89 | self.pred_topk = P_topk
90 | all_groups = ['all', 'head', 'medium', 'tail']
91 | groups_overall = {
92 | k: GroupClassMetric(metric_type='overall')
93 | for k in all_groups
94 | }
95 | groups_per_class = {
96 | k: GroupClassMetric(metric_type='per-class')
97 | for k in all_groups
98 | }
99 | class_metric_dict = {}
100 | for i_class in range(self.n_class):
101 | attr = self.idx2attr[i_class]
102 | if attr in self.exclude_atts:
103 | continue
104 |
105 | class_metric = self.get_score_class(i_class, threshold_type=threshold_type)
106 | class_metric_dict[i_class] = class_metric
107 |
108 | # Add to 'all' group.
109 | groups_overall['all'].add_class(class_metric)
110 | groups_per_class['all'].add_class(class_metric)
111 |
112 | # Add to head/medium/tail group.
113 | imbalance_group = self.get_attr_head_tail(attr)
114 | groups_overall[imbalance_group].add_class(class_metric)
115 | groups_per_class[imbalance_group].add_class(class_metric)
116 |
117 |
118 |
119 | # Aggregate final scores.
120 | # For overall, we're interested in F1.
121 | # For per-class, we're interested in mean AP, mean recall, mean balanced accuracy.
122 | scores_overall = {}
123 | for group_name, group in groups_overall.items():
124 | scores_overall[group_name] = {
125 | 'f1': group.get_f1(),
126 | 'precision': group.get_precision(),
127 | 'recall': group.get_recall(),
128 | 'tnr': group.get_tnr(),
129 | }
130 | scores_per_class = {}
131 | for group_name, group in groups_per_class.items():
132 | scores_per_class[group_name] = {
133 | 'ap': group.get_ap(),
134 | 'f1': group.get_f1(),
135 | 'precision': group.get_precision(),
136 | 'recall': group.get_recall(),
137 | 'bacc': group.get_bacc()
138 | }
139 |
140 | return scores_per_class
141 |
142 | def get_score_class(self, i_class, threshold_type='threshold'):
143 | """Computes all metrics for a given class.
144 | Args:
145 | - i_class: class index.
146 | - threshold_type: 'topk' or 'threshold'. This determines how a
147 | prediction is positive or negative.
148 | """
149 | if threshold_type == 'threshold':
150 | score = self.score
151 | else:
152 | score = self.score_topk
153 | if i_class in score:
154 | return score[i_class]
155 |
156 | if threshold_type == 'threshold':
157 | pred = self.pred[:,i_class]
158 |
159 | else:
160 | pred = self.pred_topk[:,i_class]
161 | gt_label = self.gt_label[:,i_class]
162 |
163 | # Find instances that are explicitly labeled (either positive or negative).
164 | mask_labeled = (gt_label < 2)
165 | if mask_labeled.sum() == 0:
166 | # None of the instances have label for this class.
167 | # assert False, f"0 labeled instances for attribute {self.idx2attr[i_class]}"
168 | pass
169 | else:
170 | # Select ony the labeled ones.
171 | pred = pred[mask_labeled]
172 | gt_label = gt_label[mask_labeled]
173 |
174 | if threshold_type == 'threshold':
175 | # Only computes AP when threshold_type is 'threshold'. This is because when
176 | # threshold_type is 'topk', pred is a binary matrix.
177 | ap = average_precision_score(gt_label, pred)
178 |
179 | # Make pred into binary matrix.
180 | pred[pred > self.threshold] = 1
181 | pred[pred <= self.threshold] = 0
182 |
183 | class_metric = SingleClassMetric(pred, gt_label)
184 | if threshold_type == 'threshold':
185 | class_metric.ap = ap
186 |
187 | # Cache results.
188 | score[i_class] = class_metric
189 |
190 | return class_metric
191 |
192 |
193 | class GroupClassMetric(object):
194 | def __init__(self, metric_type):
195 | """This class computes all metrics for a group of attributes.
196 | Args:
197 | - metric_type: 'overall' or 'per-class'.
198 | """
199 | self.metric_type = metric_type
200 |
201 | if metric_type == 'overall':
202 | # Keep track of all stats.
203 | self.true_pos = 0
204 | self.false_pos = 0
205 | self.true_neg = 0
206 | self.false_neg = 0
207 | self.n_pos = 0
208 | self.n_neg = 0
209 | else:
210 | self.metric = {
211 | name: []
212 | for name in ['recall', 'tnr', 'acc', 'bacc', 'precision', 'f1', 'ap']
213 | }
214 |
215 | def add_class(self, class_metric):
216 | """Adds computed metrics of a class into this group.
217 | """
218 | if self.metric_type == 'overall':
219 | self.true_pos += class_metric.true_pos
220 | self.false_pos += class_metric.false_pos
221 | self.true_neg += class_metric.true_neg
222 | self.false_neg += class_metric.false_neg
223 | self.n_pos += class_metric.n_pos
224 | self.n_neg += class_metric.n_neg
225 | else:
226 | self.metric['recall'].append(class_metric.get_recall())
227 | self.metric['tnr'].append(class_metric.get_tnr())
228 | self.metric['acc'].append(class_metric.get_acc())
229 | self.metric['bacc'].append(class_metric.get_bacc())
230 | self.metric['precision'].append(class_metric.get_precision())
231 | self.metric['f1'].append(class_metric.get_f1())
232 | self.metric['ap'].append(class_metric.ap)
233 |
234 | def get_recall(self):
235 | """Computes recall.
236 | """
237 | if self.metric_type == 'overall':
238 | n_pos_pred = self.true_pos + self.false_pos
239 | if n_pos_pred == 0:
240 | # Model makes 0 positive prediction.
241 | # This is a special case: we fall back to precision = 1 and recall = 0.
242 | return 0
243 |
244 | if self.n_pos > 0:
245 | return self.true_pos / self.n_pos
246 | return -1
247 | else:
248 | if -1 not in self.metric['recall']:
249 | return np.mean(self.metric['recall'])
250 | return -1
251 |
252 | def get_tnr(self):
253 | """Computes true negative rate.
254 | """
255 | if self.metric_type == 'overall':
256 | if self.n_neg > 0:
257 | return self.true_neg / self.n_neg
258 | return -1
259 | else:
260 | if -1 not in self.metric['tnr']:
261 | return np.mean(self.metric['tnr'])
262 | return -1
263 |
264 | def get_acc(self):
265 | """Computes accuracy.
266 | """
267 | if self.metric_type == 'overall':
268 | if self.n_pos + self.n_neg > 0:
269 | return (self.true_pos + self.true_neg) / (self.n_pos + self.n_neg)
270 | return -1
271 | else:
272 | if -1 not in self.metric['acc']:
273 | return np.mean(self.metric['acc'])
274 | return -1
275 |
276 | def get_bacc(self):
277 | """Computes balanced accuracy.
278 | """
279 | if self.metric_type == 'overall':
280 | recall = self.get_recall()
281 | tnr = self.get_tnr()
282 | if recall == -1 or tnr == -1:
283 | return -1
284 | return (recall + tnr) / 2.0
285 | else:
286 | if -1 not in self.metric['bacc']:
287 | return np.mean(self.metric['bacc'])
288 | return -1
289 |
290 | def get_precision(self):
291 | """Computes precision.
292 | """
293 | if self.metric_type == 'overall':
294 | n_pos_pred = self.true_pos + self.false_pos
295 | if n_pos_pred == 0:
296 | # Model makes 0 positive prediction.
297 | # This is a special case: we fall back to precision = 1 and recall = 0.
298 | return 1
299 | return self.true_pos / n_pos_pred
300 | else:
301 | if -1 not in self.metric['precision']:
302 | return np.mean(self.metric['precision'])
303 | return -1
304 |
305 | def get_f1(self):
306 | """Computes F1.
307 | """
308 | if self.metric_type == 'overall':
309 | recall = self.get_recall()
310 | precision = self.get_precision()
311 | if precision + recall > 0:
312 | return 2 * precision * recall / (precision + recall)
313 | return 0
314 | else:
315 | if -1 not in self.metric['f1']:
316 | return np.mean(self.metric['f1'])
317 | return -1
318 |
319 | def get_ap(self):
320 | """Computes mAP.
321 | """
322 | assert self.metric_type == 'per-class'
323 | return np.mean(self.metric['ap'])
324 |
325 |
326 | class SingleClassMetric(object):
327 | def __init__(self, pred, gt_label):
328 | """This class computes all metrics for a single attribute.
329 | Args:
330 | - pred: np.array of shape [n_instance] -> binary prediction.
331 | - gt_label: np.array of shape [n_instance] -> groundtruth binary label.
332 | """
333 | if pred is None or gt_label is None:
334 | self.true_pos = 0
335 | self.false_pos = 0
336 | self.true_neg = 0
337 | self.false_neg = 0
338 | self.n_pos = 0
339 | self.n_neg = 0
340 | self.ap = -1
341 | return
342 |
343 | self.true_pos = ((gt_label == 1) & (pred == 1)).sum()
344 | self.false_pos = ((gt_label == 0) & (pred == 1)).sum()
345 | self.true_neg = ((gt_label == 0) & (pred == 0)).sum()
346 | self.false_neg = ((gt_label == 1) & (pred == 0)).sum()
347 |
348 | # Number of groundtruth positives & negatives.
349 | self.n_pos = self.true_pos + self.false_neg
350 | self.n_neg = self.false_pos + self.true_neg
351 |
352 | # AP score.
353 | self.ap = -1
354 |
355 | def get_recall(self):
356 | """Computes recall.
357 | """
358 | n_pos_pred = self.true_pos + self.false_pos
359 | if n_pos_pred == 0:
360 | # Model makes 0 positive prediction.
361 | # This is a special case: we fall back to precision = 1 and recall = 0.
362 | return 0
363 |
364 | if self.n_pos > 0:
365 | return self.true_pos / self.n_pos
366 | return -1
367 |
368 | def get_tnr(self):
369 | """Computes true negative rate.
370 | """
371 | if self.n_neg > 0:
372 | return self.true_neg / self.n_neg
373 | return -1
374 |
375 | def get_acc(self):
376 | """Computes accuracy.
377 | """
378 | if self.n_pos + self.n_neg > 0:
379 | return (self.true_pos + self.true_neg) / (self.n_pos + self.n_neg)
380 | return -1
381 |
382 | def get_bacc(self):
383 | """Computes balanced accuracy.
384 | """
385 | recall = self.get_recall()
386 | tnr = self.get_tnr()
387 | if recall == -1 or tnr == -1:
388 | return -1
389 | return (recall + tnr) / 2.0
390 |
391 | def get_precision(self):
392 | """Computes precision.
393 | """
394 | n_pos_pred = self.true_pos + self.false_pos
395 | if n_pos_pred == 0:
396 | # Model makes 0 positive prediction.
397 | # This is a special case: we fall back to precision = 1 and recall = 0.
398 | return 1
399 | return self.true_pos / n_pos_pred
400 |
401 | def get_f1(self):
402 | """Computes F1.
403 | """
404 | recall = self.get_recall()
405 | precision = self.get_precision()
406 |
407 | if precision + recall > 0:
408 | return 2 * precision * recall / (precision + recall)
409 | return 0
410 |
411 |
412 | def preprocess_pos_neg(targets):
413 | """
414 | Preprocess positive and negative attribute indices to create ground truth labels.
415 |
416 | Args:
417 | targets (list of dict): List of target dictionaries, each containing:
418 | - 'pos_att_classes' (torch.Tensor): Positive attribute indices (binary format).
419 | - 'neg_att_classes' (torch.Tensor): Negative attribute indices (binary format).
420 |
421 | Returns:
422 | list: List of ground truth labels where:
423 | 1 indicates positive,
424 | 0 indicates negative,
425 | 2 indicates neutral/unlabeled.
426 | """
427 | pos_atts = torch.cat([target['pos_att_classes'] for target in targets], dim=0)
428 | neg_atts = torch.cat([target['neg_att_classes'] for target in targets], dim=0)
429 |
430 | gts = []
431 | for pos_att, neg_att in zip(pos_atts, neg_atts):
432 | gt = 2 * np.ones(pos_att.unsqueeze(0).shape, dtype=np.int64)
433 | gt[:, pos_att.cpu() == 1] = np.int64(1)
434 | gt[:, neg_att.cpu() == 1] = np.int64(0)
435 | gts.extend(gt)
436 |
437 | return gts
--------------------------------------------------------------------------------
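The evaluator consumes a ground-truth matrix where 1 means positive, 0 negative, and 2 unlabeled, as produced by `preprocess_pos_neg`. A toy sketch of the per-class bookkeeping for a single attribute column; the confusion counts from `SingleClassMetric` are reimplemented inline so the snippet runs without the repo on the path:

```python
# Toy per-class computation mirroring get_score_class / SingleClassMetric for one attribute.
import numpy as np

gt = np.array([1, 0, 2, 1, 0])             # 1 = positive, 0 = negative, 2 = unlabeled
scores = np.array([0.9, 0.2, 0.8, 0.4, 0.7])

labeled = gt < 2                            # keep only explicitly labeled instances
pred = (scores[labeled] > 0.5).astype(int)  # threshold at 0.5, as in the evaluator default
gt_l = gt[labeled]

tp = ((gt_l == 1) & (pred == 1)).sum()
fp = ((gt_l == 0) & (pred == 1)).sum()
fn = ((gt_l == 1) & (pred == 0)).sum()
precision = tp / max(tp + fp, 1)
recall = tp / max(tp + fn, 1)
print(precision, recall)                    # 0.5 0.5 for this toy example
```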
/engine.py:
--------------------------------------------------------------------------------
1 | from datasets.vaw_eval import Evaluator, preprocess_pos_neg
2 | from torch.cuda.amp import autocast
3 | from typing import Iterable
4 | import util.misc as utils
5 | import numpy as np
6 | import torch
7 | import math
8 | import json
9 | import sys
10 |
11 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0, args=None):
12 | model.train()
13 | criterion.train()
14 | metric_logger = utils.MetricLogger(delimiter=" ")
15 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
16 | header = 'Epoch: [{}]'.format(epoch)
17 | print_freq = 10
18 | for samples, crop_samples, mask_samples, targets in metric_logger.log_every(data_loader, print_freq, header):
19 | samples = samples.tensors.to(device)
20 | crop_samples = [crop_sample.to(device) for crop_sample in crop_samples]
21 | crop_masks = [crop_mask.to(device) for crop_mask in mask_samples]
22 | targets = [{k: v.to(device) if k != 'obj_names' and type(v) != int else v for k, v in t.items()} for t in targets]
23 | inputs = [
24 | {
25 | "samples": samples,
26 | "crop_samples": crop_samples,
27 | "crop_masks": crop_masks
28 | }
29 | ]
30 | with autocast():
31 | outputs = model(inputs, targets, args)
32 |
33 | loss_dict = criterion(outputs, targets)
34 | weight_dict = criterion.weight_dict
35 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
36 | loss_dict_reduced = utils.reduce_dict(loss_dict)
37 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
38 | loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
39 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
40 | loss_value = losses_reduced_scaled.item()
41 |
42 | if not math.isfinite(loss_value):
43 | print("Loss is {}, stopping training".format(loss_value))
44 | print(loss_dict_reduced)
45 | sys.exit(1)
46 |
47 | optimizer.zero_grad()
48 | losses.backward()
49 | if max_norm > 0:
50 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
51 | optimizer.step()
52 |
53 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
54 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
55 |
56 | metric_logger.synchronize_between_processes()
57 | print("Averaged stats:", metric_logger)
58 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
59 |
60 | @torch.no_grad()
61 | def evaluate_att(model, data_loader, device, args=None):
62 | metric_logger = utils.MetricLogger(delimiter=" ")
63 | print_freq = 10
64 | model.eval()
65 | header = 'Test:'
66 | preds = []
67 | gts = []
68 | for samples, crop_samples, crop_masks, targets in metric_logger.log_every(data_loader, print_freq, header):
69 | samples = samples.tensors.to(device)
70 | crop_samples = [crop_sample.to(device) for crop_sample in crop_samples]
71 | crop_masks = [crop_mask.to(device) for crop_mask in crop_masks]
72 | inputs = [{"samples": samples, "crop_samples": crop_samples, "crop_masks": crop_masks}]
73 | targets = [{k: v.to(device) if k != 'obj_names' and type(v) != int else v for k, v in t.items()} for t in targets]
74 | with autocast():
75 | outputs = model(inputs, targets, args)['pred_logits']
76 | outputs = outputs.sigmoid()
77 | preds.extend(outputs.detach().cpu().numpy())
78 | gt = preprocess_pos_neg(targets)
79 | gts.extend(gt)
80 |
81 | metric_logger.synchronize_between_processes()
82 | preds = torch.cat(utils.all_gather(torch.from_numpy(np.array(preds))))
83 | annos = torch.cat(utils.all_gather(torch.from_numpy(np.array(gts))))
84 | evaluator = Evaluator(args.fpath_attribute_index, args.fpath_head_tail)
85 | scores_per_class = evaluator.evaluate(preds, annos)
86 | CATEGORIES = ['all', 'head', 'medium', 'tail']
87 | stats = {f'mAP_{category}': scores_per_class[category]['ap'] for category in CATEGORIES}
88 | if args.mode == 'zero_shot':
89 | stats.update(compute_zero_shot_mAP(evaluator, args))
90 |
91 | return stats
92 |
93 | def compute_zero_shot_mAP(evaluator, args):
94 | base_novel_dict = json.load(open(args.base_novel_dict, 'r'))
95 | base_class = [v for _, v in base_novel_dict['base'].items()]
96 | novel_class = [v for _, v in base_novel_dict['novel'].items()]
97 | base_mAP = sum(evaluator.get_score_class(i_class).ap for i_class in base_class) / len(base_class)
98 | novel_mAP = sum(evaluator.get_score_class(i_class).ap for i_class in novel_class) / len(novel_class)
99 | return {'mAP_base': base_mAP, 'mAP_novel': novel_mAP}
--------------------------------------------------------------------------------
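`compute_zero_shot_mAP` above averages per-class AP over the base and novel index lists taken from `base_novel_dict.json`. A toy sketch of that split; the attribute names, indices, and AP values here are hypothetical, and the real dictionary maps attribute names to class indices under the `base` and `novel` keys:

```python
# Toy sketch of the base/novel mAP split performed in compute_zero_shot_mAP.
import numpy as np

per_class_ap = {0: 0.62, 1: 0.48, 2: 0.31, 3: 0.25}   # hypothetical per-class AP values
base_novel_dict = {"base": {"red": 0, "wooden": 1}, "novel": {"striped": 2, "rusty": 3}}

base_idx = [v for _, v in base_novel_dict["base"].items()]
novel_idx = [v for _, v in base_novel_dict["novel"].items()]
print("mAP_base:", np.mean([per_class_ap[i] for i in base_idx]))    # 0.55
print("mAP_novel:", np.mean([per_class_ap[i] for i in novel_idx]))  # 0.28
```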
/main.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, DistributedSampler
2 | from engine import evaluate_att, train_one_epoch
3 | from datasets import build_dataset
4 | from models import build_model
5 | import util.misc as utils
6 | from pathlib import Path
7 | import numpy as np
8 | import argparse
9 | import datetime
10 | import random
11 | import json
12 | import time
13 | import torch
14 | import os
15 |
16 | def get_args_parser():
17 | parser = argparse.ArgumentParser('Set SugaFormer', add_help=False)
18 | parser.add_argument('--lr', default=1e-4, type=float)
19 | parser.add_argument('--lr_backbone', default=1e-5, type=float)
20 | parser.add_argument('--batch_size', default=2, type=int)
21 | parser.add_argument('--weight_decay', default=1e-4, type=float)
22 | parser.add_argument('--epochs', default=10, type=int)
23 | parser.add_argument('--lr_drop', default=8, type=int)
24 | parser.add_argument('--clip_max_norm', default=0.1, type=float, help="gradient clipping max norm")
25 | parser.add_argument('--pretrained', type=str, default='')
26 |     parser.add_argument('--gamma_neg', default=4, type=int, help="gamma_neg for the asymmetric loss (ASL)")
27 | parser.add_argument('--clip', default=0.05, type=float)
28 | parser.add_argument('--position_embedding', default='learned', type=str, choices=('sine', 'learned'),
29 | help="Type of positional embedding to use on top of the image features")
30 | parser.add_argument('--freeze_backbone', action='store_true', help="freezing backbone flag")
31 |
32 | parser.add_argument('--dataset_file', default='', type=str)
33 | parser.add_argument('--vaw_path', default='data/vaw', type=str)
34 | parser.add_argument('--output_dir', default='', help="path where to save, empty for no saving")
35 |
36 | parser.add_argument('--device', default='cuda', help="device to use for training / testing")
37 | parser.add_argument('--seed', default=42, type=int)
38 | parser.add_argument('--resume', default='', help="resume from checkpoint")
39 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help="start epoch")
40 | parser.add_argument('--eval', action='store_true')
41 | parser.add_argument('--num_workers', default=4, type=int)
42 | parser.add_argument('--world_size', default=1, type=int, help="number of distributed processes")
43 | parser.add_argument('--dist_url', default='env://', help="url used to set up distributed training")
44 |
45 |
46 | parser.add_argument('--enc_layers', default=0, type=int, help="Number of encoding layers in the transformer")
47 | parser.add_argument('--dec_layers', default=3, type=int, help="Number of decoding layers in the transformer")
48 | parser.add_argument('--dim_feedforward', default=2048, type=int, help="Intermediate size of the feedforward layers in the transformer blocks")
49 | parser.add_argument('--hidden_dim', default=256, type=int, help="Size of the embeddings (dimension of the transformer)")
50 | parser.add_argument('--image_hidden_dim', default=1408, type=int, help="Visual Backbone output dimension")
51 | parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
52 | parser.add_argument('--nheads', default=8, type=int, help="Number of attention heads inside the transformer's attentions")
53 | parser.add_argument('--pre_norm', action='store_true')
54 |
55 | parser.add_argument('--num_obj_classes', default=2260, type=int, help="Number of object classes")
56 | parser.add_argument('--num_att_classes', default=620, type=int, help="Number of attribute classes")
57 | parser.add_argument('--att_loss_coef', default=1, type=float, help="attribute loss coefficient")
58 | parser.add_argument('--scr_coef', default=2, type=float, help="scr loss coefficient")
59 |
60 | parser.add_argument('--fpath_attribute_index', type=str, default='data/vaw/annotations/attribute_index.json')
61 | parser.add_argument('--fpath_head_tail', type=str, default='data/vaw/annotations/head_tail.json')
62 | parser.add_argument('--sc_feats', default='data/vaw/annotations/sc_embedding.pt', type=str, help="super-class text embedding")
63 | parser.add_argument('--att_feats', default='data/vaw/annotations/att_embedding.pt', type=str, help="attribute text embedding")
64 | parser.add_argument('--hierarchy', default='data/vaw/annotations/hierarchy.json', type=str, help="super-class to attribute hierarchy")
65 | parser.add_argument('--att_class_weight', default='data/vaw/annotations/att_class_weight.pt', type=str, help="weight for attribute classes")
66 | parser.add_argument('--base_novel_dict', default="data/vaw/annotations/base_novel_dict.json", type=str, help="base, novel index in attribute classes")
67 |
68 |     parser.add_argument('--mode', default='', type=str, help="training setting: supervised or zero_shot (base-to-novel)", choices=('zero_shot', 'supervised'))
69 | parser.add_argument('--use_scr', default='data/vaw/annotations/scr_tokens/', type=str, help="[MASK] feature for scr loss")
70 | parser.add_argument('--zrse', action='store_true', help="use zero-shot retrieval-based score enhancement")
71 | parser.add_argument('--zrse_scale', default=2, type=float)
72 | parser.add_argument('--ztopk',default=2, type=int)
73 |
74 | return parser
75 |
76 |
77 |
78 | def main(args):
79 | utils.init_distributed_mode(args)
80 | device = torch.device(args.device)
81 | seed = args.seed + utils.get_rank()
82 | torch.manual_seed(seed)
83 | np.random.seed(seed)
84 | random.seed(seed)
85 | torch.backends.cudnn.benchmark = False
86 | os.environ["PYTHONHASHSEED"] = str(seed)
87 |
88 | model, criterion = build_model(args)
89 | model.to(device)
90 | model_without_ddp = model
91 |
92 | if args.distributed:
93 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
94 | model_without_ddp = model.module
95 |
96 | param_dicts = [
97 | {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
98 | {
99 | "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
100 | "lr": args.lr_backbone,
101 | },
102 | ]
103 |
104 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
105 | weight_decay=args.weight_decay)
106 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
107 |
108 | dataset_train = build_dataset(image_set='train', args=args)
109 | dataset_test = build_dataset(image_set='test', args=args)
110 |
111 | if args.distributed:
112 | sampler_train = DistributedSampler(dataset_train)
113 | sampler_test = DistributedSampler(dataset_test, shuffle=False)
114 |
115 | else:
116 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
117 | sampler_test = torch.utils.data.SequentialSampler(dataset_test)
118 |
119 | batch_sampler_train = torch.utils.data.BatchSampler(
120 | sampler_train, args.batch_size, drop_last=True)
121 |
122 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
123 | collate_fn=utils.collate_fn, num_workers=args.num_workers)
124 | data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=sampler_test,
125 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
126 |
127 | output_dir = Path(args.output_dir)
128 | if args.resume:
129 | if args.resume.startswith("https"):
130 | checkpoint = torch.hub.load_state_dict_from_url(
131 | args.resume, map_location='cpu', check_hash=True)
132 | else:
133 | checkpoint = torch.load(args.resume, map_location=("cpu"))
134 |
135 | model_without_ddp.load_state_dict(checkpoint['model'])
136 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
137 | optimizer.load_state_dict(checkpoint['optimizer'])
138 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
139 | args.start_epoch = checkpoint['epoch'] + 1
140 |
141 | if args.pretrained:
142 | if args.eval:
143 | checkpoint = torch.load(args.pretrained, map_location="cpu")
144 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
145 |
146 | test_stats = evaluate_att(model, data_loader_test, device, args)
147 | log_stats = {**{f'test_{k}': v for k, v in test_stats.items()}}
148 |
149 | if args.output_dir and utils.is_main_process():
150 | with (output_dir / "log.txt").open("a") as f:
151 | f.write(json.dumps(log_stats) + "\n")
152 | return
153 |
154 | print("Start training")
155 | start_time = time.time()
156 | for epoch in range(args.start_epoch, args.epochs):
157 |
158 | if args.distributed:
159 | sampler_train.set_epoch(epoch)
160 |
161 | train_one_epoch(
162 | model, criterion, data_loader_train, optimizer, device, epoch,
163 | args.clip_max_norm, args)
164 |
165 | lr_scheduler.step()
166 | if args.output_dir:
167 | checkpoint_paths = [output_dir / 'checkpoint.pth']
168 | if (epoch + 1) % 5 == 0:
169 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
170 | for checkpoint_path in checkpoint_paths:
171 | utils.save_on_master({
172 | 'model': model_without_ddp.state_dict(),
173 | 'optimizer': optimizer.state_dict(),
174 | 'lr_scheduler': lr_scheduler.state_dict(),
175 | 'epoch': epoch,
176 | 'args': args,
177 | }, checkpoint_path)
178 |
179 | total_time = time.time() - start_time
180 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
181 | print('Training time {}'.format(total_time_str))
182 |
183 |
184 | if __name__ == "__main__":
185 | parser = argparse.ArgumentParser('SugaFormer training and evaluation script', parents=[get_args_parser()])
186 | args = parser.parse_args()
187 | if args.output_dir:
188 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
189 | main(args)
--------------------------------------------------------------------------------
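`main.py` splits the parameters into two optimizer groups so that anything whose name contains `backbone` trains at `--lr_backbone` while the rest uses `--lr`. A minimal, self-contained sketch of that split with a dummy model; the module names here are made up for illustration:

```python
# Minimal sketch of the two-group AdamW setup used in main(): backbone vs. the rest.
import torch
from torch import nn

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # picked up by the "backbone" name filter
        self.decoder = nn.Linear(4, 4)   # falls into the default-lr group

model = Dummy()
lr, lr_backbone = 1e-4, 1e-5
param_dicts = [
    {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
    {"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], "lr": lr_backbone},
]
optimizer = torch.optim.AdamW(param_dicts, lr=lr, weight_decay=1e-4)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 1e-05]
```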
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from .sugaformer import build
3 |
4 | def build_model(args):
5 | return build(args)
6 |
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | from util.misc import NestedTensor, nested_tensor_from_tensor_list
2 | from .lavis.models import load_model_and_preprocess
3 | from .position_encoding import build_position_encoding
4 | from typing import List
5 | from torch import nn
6 | import torch
7 | import math
8 |
9 | class Joiner(nn.Sequential):
10 | def __init__(self, backbone, position_embedding):
11 | super().__init__(backbone, position_embedding)
12 |
13 | def forward(self, tensor_list: NestedTensor):
14 | xs_list = self[0](tensor_list.tensors)
15 | out: List[NestedTensor] = []
16 | pos = []
17 | for xs in xs_list:
18 | xs = xs[:,1:,:]
19 | B, D, C = xs.shape
20 | D_s = int(math.sqrt(D))
21 | x = xs.permute(0,2,1).view(B,C,D_s,D_s)
22 | x = nested_tensor_from_tensor_list([x_s for x_s in x])
23 | out.append(x)
24 | pos.append(self[1](x).to(x.tensors.dtype))
25 | return out, pos
26 |
27 | def build_blip_backbone(args):
28 | position_embedding = build_position_encoding(args)
29 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
30 | model = load_model_and_preprocess(name="blip2", model_type="pretrain", device=device)
31 | visual_encoder = model.visual_encoder
32 | visual_encoder = Joiner(visual_encoder, position_embedding)
33 | return visual_encoder, model
34 |
--------------------------------------------------------------------------------
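`Joiner.forward` above drops the ViT class token and folds the remaining patch tokens back into a square feature map before the positional encoding is applied. A self-contained sketch of that reshape with toy dimensions:

```python
# Toy reshape mirroring Joiner.forward: (B, 1 + D, C) ViT tokens -> (B, C, sqrt(D), sqrt(D)).
import math
import torch

B, D, C = 2, 16 * 16, 32           # batch size, number of patch tokens, channel dim (toy sizes)
tokens = torch.randn(B, 1 + D, C)  # index 0 along dim 1 is the [CLS] token

xs = tokens[:, 1:, :]              # drop the class token
D_s = int(math.sqrt(xs.shape[1]))
fmap = xs.permute(0, 2, 1).view(B, C, D_s, D_s)
print(fmap.shape)                  # torch.Size([2, 32, 16, 16])
```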
/models/lavis/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from .models import BaseModel, Blip2Base, Blip2Qformer
9 | from .common.registry import registry
10 | from .processors import BaseProcessor
11 | from omegaconf import OmegaConf
12 | import sys
13 | import os
14 |
15 | # Exported symbols for the package
16 | __all__ = ["registry", "BaseModel", "Blip2Base", "Blip2Qformer", "BaseProcessor"]
17 |
18 | root_dir = os.path.dirname(os.path.abspath(__file__))
19 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
20 |
21 | registry.register_path("library_root", root_dir)
22 | repo_root = os.path.join(root_dir, "..")
23 | registry.register_path("repo_root", repo_root)
24 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
25 | registry.register_path("cache_root", cache_root)
26 |
27 | registry.register("MAX_INT", sys.maxsize)
28 | registry.register("SPLIT_NAMES", ["train", "val", "test"])
29 |
--------------------------------------------------------------------------------
/models/lavis/common/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import functools
10 | import os
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import timm.models.hub as timm_hub
15 |
16 |
17 | def setup_for_distributed(is_master):
18 | """
19 | This function disables printing when not in master process
20 | """
21 | import builtins as __builtin__
22 |
23 | builtin_print = __builtin__.print
24 |
25 | def print(*args, **kwargs):
26 | force = kwargs.pop("force", False)
27 | if is_master or force:
28 | builtin_print(*args, **kwargs)
29 |
30 | __builtin__.print = print
31 |
32 |
33 | def is_dist_avail_and_initialized():
34 | if not dist.is_available():
35 | return False
36 | if not dist.is_initialized():
37 | return False
38 | return True
39 |
40 |
41 | def get_world_size():
42 | if not is_dist_avail_and_initialized():
43 | return 1
44 | return dist.get_world_size()
45 |
46 |
47 | def get_rank():
48 | if not is_dist_avail_and_initialized():
49 | return 0
50 | return dist.get_rank()
51 |
52 |
53 | def is_main_process():
54 | return get_rank() == 0
55 |
56 |
57 | def init_distributed_mode(args):
58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59 | args.rank = int(os.environ["RANK"])
60 | args.world_size = int(os.environ["WORLD_SIZE"])
61 | args.gpu = int(os.environ["LOCAL_RANK"])
62 | elif "SLURM_PROCID" in os.environ:
63 | args.rank = int(os.environ["SLURM_PROCID"])
64 | args.gpu = args.rank % torch.cuda.device_count()
65 | else:
66 | print("Not using distributed mode")
67 | args.distributed = False
68 | return
69 |
70 | args.distributed = True
71 |
72 | torch.cuda.set_device(args.gpu)
73 | args.dist_backend = "nccl"
74 | print(
75 | "| distributed init (rank {}, world {}): {}".format(
76 | args.rank, args.world_size, args.dist_url
77 | ),
78 | flush=True,
79 | )
80 | torch.distributed.init_process_group(
81 | backend=args.dist_backend,
82 | init_method=args.dist_url,
83 | world_size=args.world_size,
84 | rank=args.rank,
85 | timeout=datetime.timedelta(
86 | days=365
87 | ), # allow auto-downloading and de-compressing
88 | )
89 | torch.distributed.barrier()
90 | setup_for_distributed(args.rank == 0)
91 |
92 |
93 | def get_dist_info():
94 | if torch.__version__ < "1.0":
95 | initialized = dist._initialized
96 | else:
97 | initialized = dist.is_initialized()
98 | if initialized:
99 | rank = dist.get_rank()
100 | world_size = dist.get_world_size()
101 | else: # non-distributed training
102 | rank = 0
103 | world_size = 1
104 | return rank, world_size
105 |
106 |
107 | def main_process(func):
108 | @functools.wraps(func)
109 | def wrapper(*args, **kwargs):
110 | rank, _ = get_dist_info()
111 | if rank == 0:
112 | return func(*args, **kwargs)
113 |
114 | return wrapper
115 |
116 |
117 | def download_cached_file(url, check_hash=True, progress=False):
118 | """
119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121 | """
122 |
123 | def get_cached_file_path():
124 | # a hack to sync the file path across processes
125 | parts = torch.hub.urlparse(url)
126 | filename = os.path.basename(parts.path)
127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128 |
129 | return cached_file
130 |
131 | if is_main_process():
132 | timm_hub.download_cached_file(url, check_hash, progress)
133 |
134 | if is_dist_avail_and_initialized():
135 | dist.barrier()
136 |
137 | return get_cached_file_path()
138 |
--------------------------------------------------------------------------------
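A minimal sketch of how these helpers behave in a single-process (non-distributed) run; `save_report` is a hypothetical function used only for illustration.

```python
# Minimal sketch: the helpers fall back to sane defaults when
# torch.distributed has not been initialized.
from models.lavis.common.dist_utils import (
    get_rank, get_world_size, is_main_process, main_process,
)

print(get_world_size())    # 1 when not initialized
print(get_rank())          # 0 when not initialized
print(is_main_process())   # True when not initialized

@main_process
def save_report():
    # hypothetical helper: runs only on rank 0; a no-op on other ranks
    print("writing report from the main process")

save_report()
```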
/models/lavis/common/gradcam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from scipy.ndimage import filters
4 | from skimage import transform as skimage_transform
5 |
6 |
7 | def getAttMap(img, attMap, blur=True, overlap=True):
8 | attMap -= attMap.min()
9 | if attMap.max() > 0:
10 | attMap /= attMap.max()
11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12 | if blur:
13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14 | attMap -= attMap.min()
15 | attMap /= attMap.max()
16 | cmap = plt.get_cmap("jet")
17 | attMapV = cmap(attMap)
18 | attMapV = np.delete(attMapV, 3, 2)
19 | if overlap:
20 | attMap = (
21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23 | )
24 | return attMap
25 |
--------------------------------------------------------------------------------
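A minimal sketch of `getAttMap` on synthetic inputs (a random image and a random low-resolution attention map); shapes and values are illustrative only.

```python
# Minimal sketch with synthetic data: overlay a random attention map on a random image.
import numpy as np
from models.lavis.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)   # HxWx3 float image in [0, 1]
att = np.random.rand(24, 24)        # low-resolution attention map
overlay = getAttMap(img, att, blur=True, overlap=True)
print(overlay.shape)                # (224, 224, 3)
```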
/models/lavis/common/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import logging
10 | import time
11 | from collections import defaultdict, deque
12 |
13 | import torch
14 | import torch.distributed as dist
15 |
16 | from models.lavis.common import dist_utils
17 |
18 |
19 | class SmoothedValue(object):
20 | """Track a series of values and provide access to smoothed values over a
21 | window or the global series average.
22 | """
23 |
24 | def __init__(self, window_size=20, fmt=None):
25 | if fmt is None:
26 | fmt = "{median:.4f} ({global_avg:.4f})"
27 | self.deque = deque(maxlen=window_size)
28 | self.total = 0.0
29 | self.count = 0
30 | self.fmt = fmt
31 |
32 | def update(self, value, n=1):
33 | self.deque.append(value)
34 | self.count += n
35 | self.total += value * n
36 |
37 | def synchronize_between_processes(self):
38 | """
39 | Warning: does not synchronize the deque!
40 | """
41 | if not dist_utils.is_dist_avail_and_initialized():
42 | return
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value,
79 | )
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, **kwargs):
88 | for k, v in kwargs.items():
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError(
100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101 | )
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append("{}: {}".format(name, str(meter)))
107 | return self.delimiter.join(loss_str)
108 |
109 | def global_avg(self):
110 | loss_str = []
111 | for name, meter in self.meters.items():
112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113 | return self.delimiter.join(loss_str)
114 |
115 | def synchronize_between_processes(self):
116 | for meter in self.meters.values():
117 | meter.synchronize_between_processes()
118 |
119 | def add_meter(self, name, meter):
120 | self.meters[name] = meter
121 |
122 | def log_every(self, iterable, print_freq, header=None):
123 | i = 0
124 | if not header:
125 | header = ""
126 | start_time = time.time()
127 | end = time.time()
128 | iter_time = SmoothedValue(fmt="{avg:.4f}")
129 | data_time = SmoothedValue(fmt="{avg:.4f}")
130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131 | log_msg = [
132 | header,
133 | "[{0" + space_fmt + "}/{1}]",
134 | "eta: {eta}",
135 | "{meters}",
136 | "time: {time}",
137 | "data: {data}",
138 | ]
139 | if torch.cuda.is_available():
140 | log_msg.append("max mem: {memory:.0f}")
141 | log_msg = self.delimiter.join(log_msg)
142 | MB = 1024.0 * 1024.0
143 | for obj in iterable:
144 | data_time.update(time.time() - end)
145 | yield obj
146 | iter_time.update(time.time() - end)
147 | if i % print_freq == 0 or i == len(iterable) - 1:
148 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150 | if torch.cuda.is_available():
151 | print(
152 | log_msg.format(
153 | i,
154 | len(iterable),
155 | eta=eta_string,
156 | meters=str(self),
157 | time=str(iter_time),
158 | data=str(data_time),
159 | memory=torch.cuda.max_memory_allocated() / MB,
160 | )
161 | )
162 | else:
163 | print(
164 | log_msg.format(
165 | i,
166 | len(iterable),
167 | eta=eta_string,
168 | meters=str(self),
169 | time=str(iter_time),
170 | data=str(data_time),
171 | )
172 | )
173 | i += 1
174 | end = time.time()
175 | total_time = time.time() - start_time
176 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177 | print(
178 | "{} Total time: {} ({:.4f} s / it)".format(
179 | header, total_time_str, total_time / len(iterable)
180 | )
181 | )
182 |
183 |
184 | class AttrDict(dict):
185 | def __init__(self, *args, **kwargs):
186 | super(AttrDict, self).__init__(*args, **kwargs)
187 | self.__dict__ = self
188 |
189 |
190 | def setup_logger():
191 | logging.basicConfig(
192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193 | format="%(asctime)s [%(levelname)s] %(message)s",
194 | handlers=[logging.StreamHandler()],
195 | )
196 |
--------------------------------------------------------------------------------
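A minimal sketch of `MetricLogger`/`SmoothedValue` on a toy loop; in the real code the iterable would be a data loader, and the values below are placeholders.

```python
# Minimal sketch: smoothed metric tracking over a toy loop.
import torch
from models.lavis.common.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in metric_logger.log_every(range(10), print_freq=5, header="Toy:"):
    loss = torch.rand(1)                 # stand-in for a training loss
    metric_logger.update(loss=loss, lr=1e-4)

print("Averaged stats:", metric_logger.global_avg())
```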
/models/lavis/common/optims.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import math
9 |
10 | from models.lavis.common.registry import registry
11 |
12 |
13 | @registry.register_lr_scheduler("linear_warmup_step_lr")
14 | class LinearWarmupStepLRScheduler:
15 | def __init__(
16 | self,
17 | optimizer,
18 | max_epoch,
19 | min_lr,
20 | init_lr,
21 | decay_rate=1,
22 | warmup_start_lr=-1,
23 | warmup_steps=0,
24 | **kwargs
25 | ):
26 | self.optimizer = optimizer
27 |
28 | self.max_epoch = max_epoch
29 | self.min_lr = min_lr
30 |
31 | self.decay_rate = decay_rate
32 |
33 | self.init_lr = init_lr
34 | self.warmup_steps = warmup_steps
35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36 |
37 | def step(self, cur_epoch, cur_step):
38 | if cur_epoch == 0:
39 | warmup_lr_schedule(
40 | step=cur_step,
41 | optimizer=self.optimizer,
42 | max_step=self.warmup_steps,
43 | init_lr=self.warmup_start_lr,
44 | max_lr=self.init_lr,
45 | )
46 | else:
47 | step_lr_schedule(
48 | epoch=cur_epoch,
49 | optimizer=self.optimizer,
50 | init_lr=self.init_lr,
51 | min_lr=self.min_lr,
52 | decay_rate=self.decay_rate,
53 | )
54 |
55 |
56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57 | class LinearWarmupCosineLRScheduler:
58 | def __init__(
59 | self,
60 | optimizer,
61 | max_epoch,
62 | min_lr,
63 | init_lr,
64 | warmup_steps=0,
65 | warmup_start_lr=-1,
66 | **kwargs
67 | ):
68 | self.optimizer = optimizer
69 |
70 | self.max_epoch = max_epoch
71 | self.min_lr = min_lr
72 |
73 | self.init_lr = init_lr
74 | self.warmup_steps = warmup_steps
75 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
76 |
77 | def step(self, cur_epoch, cur_step):
78 |         # assuming the warmup iterations span less than one epoch
79 | if cur_epoch == 0:
80 | warmup_lr_schedule(
81 | step=cur_step,
82 | optimizer=self.optimizer,
83 | max_step=self.warmup_steps,
84 | init_lr=self.warmup_start_lr,
85 | max_lr=self.init_lr,
86 | )
87 | else:
88 | cosine_lr_schedule(
89 | epoch=cur_epoch,
90 | optimizer=self.optimizer,
91 | max_epoch=self.max_epoch,
92 | init_lr=self.init_lr,
93 | min_lr=self.min_lr,
94 | )
95 |
96 |
97 | @registry.register_lr_scheduler("constant_lr")
98 | class ConstantLRScheduler:
99 | def __init__(self, optimizer, init_lr, warmup_start_lr=-1, warmup_steps=0, **kwargs):
100 | self.optimizer = optimizer
101 | self.lr = init_lr
102 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
103 | self.warmup_steps = warmup_steps
104 |
105 | def step(self, cur_epoch, cur_step):
106 | if cur_epoch == 0:
107 | warmup_lr_schedule(
108 | step=cur_step,
109 | optimizer=self.optimizer,
110 | max_step=self.warmup_steps,
111 | init_lr=self.warmup_start_lr,
112 | max_lr=self.lr,
113 | )
114 | else:
115 | for param_group in self.optimizer.param_groups:
116 | param_group["lr"] = self.lr
117 |
118 |
119 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
120 | """Decay the learning rate"""
121 | lr = (init_lr - min_lr) * 0.5 * (
122 | 1.0 + math.cos(math.pi * epoch / max_epoch)
123 | ) + min_lr
124 | for param_group in optimizer.param_groups:
125 | param_group["lr"] = lr
126 |
127 |
128 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
129 | """Warmup the learning rate"""
130 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
131 | for param_group in optimizer.param_groups:
132 | param_group["lr"] = lr
133 |
134 |
135 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
136 | """Decay the learning rate"""
137 | lr = max(min_lr, init_lr * (decay_rate**epoch))
138 | for param_group in optimizer.param_groups:
139 | param_group["lr"] = lr
140 |
--------------------------------------------------------------------------------
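A minimal sketch of driving `LinearWarmupCosineLRScheduler` with a dummy optimizer; the epoch/step counts and learning rates are illustrative, not the values used by the training scripts.

```python
# Minimal sketch: linear warmup in epoch 0, cosine decay afterwards.
import torch
from models.lavis.common.optims import LinearWarmupCosineLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = LinearWarmupCosineLRScheduler(
    optimizer, max_epoch=10, min_lr=1e-6, init_lr=1e-4,
    warmup_steps=100, warmup_start_lr=1e-7,
)

for epoch in range(10):
    for step in range(100):
        scheduler.step(cur_epoch=epoch, cur_step=step)
    print(f"epoch {epoch}: lr = {optimizer.param_groups[0]['lr']:.2e}")
```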
/models/lavis/common/registry.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 |
9 | class Registry:
10 | mapping = {
11 | "builder_name_mapping": {},
12 | "task_name_mapping": {},
13 | "processor_name_mapping": {},
14 | "model_name_mapping": {},
15 | "lr_scheduler_name_mapping": {},
16 | "runner_name_mapping": {},
17 | "state": {},
18 | "paths": {},
19 | }
20 |
21 | # @classmethod
22 | # def register_builder(cls, name):
23 | # r"""Register a dataset builder to registry with key 'name'
24 |
25 | # Args:
26 | # name: Key with which the builder will be registered.
27 |
28 | # Usage:
29 |
30 | # from lavis.common.registry import registry
31 | # from lavis.datasets.base_dataset_builder import BaseDatasetBuilder
32 | # """
33 |
34 | # def wrap(builder_cls):
35 | # from LAVIS.lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36 |
37 | # assert issubclass(
38 | # builder_cls, BaseDatasetBuilder
39 | # ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40 | # builder_cls
41 | # )
42 | # if name in cls.mapping["builder_name_mapping"]:
43 | # raise KeyError(
44 | # "Name '{}' already registered for {}.".format(
45 | # name, cls.mapping["builder_name_mapping"][name]
46 | # )
47 | # )
48 | # cls.mapping["builder_name_mapping"][name] = builder_cls
49 | # return builder_cls
50 |
51 | # return wrap
52 |
53 | @classmethod
54 | def register_task(cls, name):
55 | r"""Register a task to registry with key 'name'
56 |
57 | Args:
58 | name: Key with which the task will be registered.
59 |
60 | Usage:
61 |
62 | from lavis.common.registry import registry
63 | """
64 |
65 | def wrap(task_cls):
66 | from models.lavis.tasks.base_task import BaseTask
67 |
68 | assert issubclass(
69 | task_cls, BaseTask
70 | ), "All tasks must inherit BaseTask class"
71 | if name in cls.mapping["task_name_mapping"]:
72 | raise KeyError(
73 | "Name '{}' already registered for {}.".format(
74 | name, cls.mapping["task_name_mapping"][name]
75 | )
76 | )
77 | cls.mapping["task_name_mapping"][name] = task_cls
78 | return task_cls
79 |
80 | return wrap
81 |
82 | @classmethod
83 | def register_model(cls, name):
84 | r"""Register a task to registry with key 'name'
85 |
86 | Args:
87 | name: Key with which the task will be registered.
88 |
89 | Usage:
90 |
91 | from lavis.common.registry import registry
92 | """
93 |
94 | def wrap(model_cls):
95 | from models.lavis.models import BaseModel
96 | assert issubclass(
97 | model_cls, BaseModel
98 | ), "All models must inherit BaseModel class"
99 | if name in cls.mapping["model_name_mapping"]:
100 | raise KeyError(
101 | "Name '{}' already registered for {}.".format(
102 | name, cls.mapping["model_name_mapping"][name]
103 | )
104 | )
105 | cls.mapping["model_name_mapping"][name] = model_cls
106 | return model_cls
107 |
108 | return wrap
109 |
110 | @classmethod
111 | def register_processor(cls, name):
112 | r"""Register a processor to registry with key 'name'
113 |
114 | Args:
115 |             name: Key with which the processor will be registered.
116 |
117 | Usage:
118 |
119 | from lavis.common.registry import registry
120 | """
121 |
122 | def wrap(processor_cls):
123 | from models.lavis.processors import BaseProcessor
124 | assert issubclass(
125 | processor_cls, BaseProcessor
126 | ), "All processors must inherit BaseProcessor class"
127 | if name in cls.mapping["processor_name_mapping"]:
128 | raise KeyError(
129 | "Name '{}' already registered for {}.".format(
130 | name, cls.mapping["processor_name_mapping"][name]
131 | )
132 | )
133 | cls.mapping["processor_name_mapping"][name] = processor_cls
134 | return processor_cls
135 |
136 | return wrap
137 |
138 | @classmethod
139 | def register_lr_scheduler(cls, name):
140 | r"""Register a model to registry with key 'name'
141 |
142 | Args:
143 | name: Key with which the task will be registered.
144 |
145 | Usage:
146 |
147 | from lavis.common.registry import registry
148 | """
149 |
150 | def wrap(lr_sched_cls):
151 | if name in cls.mapping["lr_scheduler_name_mapping"]:
152 | raise KeyError(
153 | "Name '{}' already registered for {}.".format(
154 | name, cls.mapping["lr_scheduler_name_mapping"][name]
155 | )
156 | )
157 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
158 | return lr_sched_cls
159 |
160 | return wrap
161 |
162 | @classmethod
163 | def register_runner(cls, name):
164 | r"""Register a model to registry with key 'name'
165 |
166 | Args:
167 | name: Key with which the task will be registered.
168 |
169 | Usage:
170 |
171 | from lavis.common.registry import registry
172 | """
173 |
174 | def wrap(runner_cls):
175 | if name in cls.mapping["runner_name_mapping"]:
176 | raise KeyError(
177 | "Name '{}' already registered for {}.".format(
178 | name, cls.mapping["runner_name_mapping"][name]
179 | )
180 | )
181 | cls.mapping["runner_name_mapping"][name] = runner_cls
182 | return runner_cls
183 |
184 | return wrap
185 |
186 | @classmethod
187 | def register_path(cls, name, path):
188 | r"""Register a path to registry with key 'name'
189 |
190 | Args:
191 | name: Key with which the path will be registered.
192 |
193 | Usage:
194 |
195 | from lavis.common.registry import registry
196 | """
197 |         assert isinstance(path, str), "All paths must be str."
198 | if name in cls.mapping["paths"]:
199 | raise KeyError("Name '{}' already registered.".format(name))
200 | cls.mapping["paths"][name] = path
201 |
202 | @classmethod
203 | def register(cls, name, obj):
204 | r"""Register an item to registry with key 'name'
205 |
206 | Args:
207 | name: Key with which the item will be registered.
208 |
209 | Usage::
210 |
211 | from lavis.common.registry import registry
212 |
213 | registry.register("config", {})
214 | """
215 | path = name.split(".")
216 | current = cls.mapping["state"]
217 |
218 | for part in path[:-1]:
219 | if part not in current:
220 | current[part] = {}
221 | current = current[part]
222 |
223 | current[path[-1]] = obj
224 |
225 | # @classmethod
226 | # def get_trainer_class(cls, name):
227 | # return cls.mapping["trainer_name_mapping"].get(name, None)
228 |
229 | @classmethod
230 | def get_builder_class(cls, name):
231 | return cls.mapping["builder_name_mapping"].get(name, None)
232 |
233 | @classmethod
234 | def get_model_class(cls, name):
235 | return cls.mapping["model_name_mapping"].get(name, None)
236 |
237 | @classmethod
238 | def get_task_class(cls, name):
239 | return cls.mapping["task_name_mapping"].get(name, None)
240 |
241 | @classmethod
242 | def get_processor_class(cls, name):
243 | return cls.mapping["processor_name_mapping"].get(name, None)
244 |
245 | @classmethod
246 | def get_lr_scheduler_class(cls, name):
247 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
248 |
249 | @classmethod
250 | def get_runner_class(cls, name):
251 | return cls.mapping["runner_name_mapping"].get(name, None)
252 |
253 | @classmethod
254 | def list_runners(cls):
255 | return sorted(cls.mapping["runner_name_mapping"].keys())
256 |
257 | @classmethod
258 | def list_models(cls):
259 | return sorted(cls.mapping["model_name_mapping"].keys())
260 |
261 | @classmethod
262 | def list_tasks(cls):
263 | return sorted(cls.mapping["task_name_mapping"].keys())
264 |
265 | @classmethod
266 | def list_processors(cls):
267 | return sorted(cls.mapping["processor_name_mapping"].keys())
268 |
269 | @classmethod
270 | def list_lr_schedulers(cls):
271 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
272 |
273 | @classmethod
274 | def list_datasets(cls):
275 | return sorted(cls.mapping["builder_name_mapping"].keys())
276 |
277 | @classmethod
278 | def get_path(cls, name):
279 | return cls.mapping["paths"].get(name, None)
280 |
281 | @classmethod
282 | def get(cls, name, default=None, no_warning=False):
283 | r"""Get an item from registry with key 'name'
284 |
285 | Args:
286 | name (string): Key whose value needs to be retrieved.
287 | default: If passed and key is not in registry, default value will
288 | be returned with a warning. Default: None
289 | no_warning (bool): If passed as True, warning when key doesn't exist
290 | will not be generated. Useful for MMF's
291 | internal operations. Default: False
292 | """
293 | original_name = name
294 | name = name.split(".")
295 | value = cls.mapping["state"]
296 | for subname in name:
297 | value = value.get(subname, default)
298 | if value is default:
299 | break
300 |
301 | if (
302 | "writer" in cls.mapping["state"]
303 | and value == default
304 | and no_warning is False
305 | ):
306 | cls.mapping["state"]["writer"].warning(
307 | "Key {} is not present in registry, returning default value "
308 | "of {}".format(original_name, default)
309 | )
310 | return value
311 |
312 | @classmethod
313 | def unregister(cls, name):
314 | r"""Remove an item from registry with key 'name'
315 |
316 | Args:
317 | name: Key which needs to be removed.
318 | Usage::
319 |
320 | from mmf.common.registry import registry
321 |
322 | config = registry.unregister("config")
323 | """
324 | return cls.mapping["state"].pop(name, None)
325 |
326 |
327 | registry = Registry()
328 |
--------------------------------------------------------------------------------
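A minimal sketch of the register/retrieve round trip; the scheduler name `toy_constant` and the `ToyConstantScheduler` class are hypothetical and exist only for illustration.

```python
# Minimal sketch: register a custom object and fetch it back by name.
from models.lavis.common.registry import registry

@registry.register_lr_scheduler("toy_constant")   # hypothetical name
class ToyConstantScheduler:
    def __init__(self, optimizer, init_lr, **kwargs):
        self.optimizer = optimizer
        self.lr = init_lr

    def step(self, cur_epoch, cur_step):
        for group in self.optimizer.param_groups:
            group["lr"] = self.lr

cls = registry.get_lr_scheduler_class("toy_constant")
print(cls is ToyConstantScheduler)      # True
print(registry.list_lr_schedulers())    # includes "toy_constant"
```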
/models/lavis/common/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import io
9 | import json
10 | import logging
11 | import os
12 | import pickle
13 | import re
14 | import shutil
15 | import tarfile
16 | import urllib
17 | import urllib.error
18 | import urllib.request
19 | from typing import Optional
20 | from urllib.parse import urlparse
21 |
22 | import numpy as np
23 | import pandas as pd
24 | import yaml
25 | from iopath.common.download import download
26 | from iopath.common.file_io import file_lock, g_pathmgr
27 | from models.lavis.common.dist_utils import download_cached_file
28 | from models.lavis.common.registry import registry
29 | from torch.utils.model_zoo import tqdm
30 | from torchvision.datasets.utils import (
31 | check_integrity,
32 | download_file_from_google_drive,
33 | extract_archive,
34 | )
35 |
36 |
37 | def now():
38 | from datetime import datetime
39 |
40 | return datetime.now().strftime("%Y%m%d%H%M")[:-1]
41 |
42 |
43 | def is_url(url_or_filename):
44 | parsed = urlparse(url_or_filename)
45 | return parsed.scheme in ("http", "https")
46 |
47 |
48 | def get_cache_path(rel_path):
49 | return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
50 |
51 |
52 | def get_abs_path(rel_path):
53 | return os.path.join(registry.get_path("library_root"), rel_path)
54 |
55 |
56 | def load_json(filename):
57 | with open(filename, "r") as f:
58 | return json.load(f)
59 |
60 |
61 | # The following are adapted from torchvision and vissl
62 | # torchvision: https://github.com/pytorch/vision
63 | # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
64 |
65 |
66 | def makedir(dir_path):
67 | """
68 | Create the directory if it does not exist.
69 | """
70 | is_success = False
71 | try:
72 | if not g_pathmgr.exists(dir_path):
73 | g_pathmgr.mkdirs(dir_path)
74 | is_success = True
75 | except BaseException:
76 | print(f"Error creating directory: {dir_path}")
77 | return is_success
78 |
79 |
80 | def get_redirected_url(url: str):
81 | """
82 | Given a URL, returns the URL it redirects to or the
83 | original URL in case of no indirection
84 | """
85 | import requests
86 |
87 | with requests.Session() as session:
88 | with session.get(url, stream=True, allow_redirects=True) as response:
89 | if response.history:
90 | return response.url
91 | else:
92 | return url
93 |
94 |
95 | def to_google_drive_download_url(view_url: str) -> str:
96 | """
97 | Utility function to transform a view URL of google drive
98 | to a download URL for google drive
99 | Example input:
100 | https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
101 | Example output:
102 | https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
103 | """
104 | splits = view_url.split("/")
105 | assert splits[-1] == "view"
106 | file_id = splits[-2]
107 | return f"https://drive.google.com/uc?export=download&id={file_id}"
108 |
109 |
110 | def download_google_drive_url(url: str, output_path: str, output_file_name: str):
111 | """
112 | Download a file from google drive
113 | Downloading an URL from google drive requires confirmation when
114 | the file of the size is too big (google drive notifies that
115 | anti-viral checks cannot be performed on such files)
116 | """
117 | import requests
118 |
119 | with requests.Session() as session:
120 |
121 | # First get the confirmation token and append it to the URL
122 | with session.get(url, stream=True, allow_redirects=True) as response:
123 | for k, v in response.cookies.items():
124 | if k.startswith("download_warning"):
125 | url = url + "&confirm=" + v
126 |
127 | # Then download the content of the file
128 | with session.get(url, stream=True, verify=True) as response:
129 | makedir(output_path)
130 | path = os.path.join(output_path, output_file_name)
131 | total_size = int(response.headers.get("Content-length", 0))
132 | with open(path, "wb") as file:
133 | from tqdm import tqdm
134 |
135 | with tqdm(total=total_size) as progress_bar:
136 | for block in response.iter_content(
137 | chunk_size=io.DEFAULT_BUFFER_SIZE
138 | ):
139 | file.write(block)
140 | progress_bar.update(len(block))
141 |
142 |
143 | def _get_google_drive_file_id(url: str) -> Optional[str]:
144 | parts = urlparse(url)
145 |
146 | if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
147 | return None
148 |
149 |     match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
150 | if match is None:
151 | return None
152 |
153 | return match.group("id")
154 |
155 |
156 | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
157 | with open(filename, "wb") as fh:
158 | with urllib.request.urlopen(
159 | urllib.request.Request(url, headers={"User-Agent": "vissl"})
160 | ) as response:
161 | with tqdm(total=response.length) as pbar:
162 | for chunk in iter(lambda: response.read(chunk_size), ""):
163 | if not chunk:
164 | break
165 | pbar.update(chunk_size)
166 | fh.write(chunk)
167 |
168 |
169 | def download_url(
170 | url: str,
171 | root: str,
172 | filename: Optional[str] = None,
173 | md5: Optional[str] = None,
174 | ) -> None:
175 | """Download a file from a url and place it in root.
176 | Args:
177 | url (str): URL to download file from
178 | root (str): Directory to place downloaded file in
179 | filename (str, optional): Name to save the file under.
180 | If None, use the basename of the URL.
181 | md5 (str, optional): MD5 checksum of the download. If None, do not check
182 | """
183 | root = os.path.expanduser(root)
184 | if not filename:
185 | filename = os.path.basename(url)
186 | fpath = os.path.join(root, filename)
187 |
188 | makedir(root)
189 |
190 | # check if file is already present locally
191 | if check_integrity(fpath, md5):
192 | print("Using downloaded and verified file: " + fpath)
193 | return
194 |
195 | # expand redirect chain if needed
196 | url = get_redirected_url(url)
197 |
198 | # check if file is located on Google Drive
199 | file_id = _get_google_drive_file_id(url)
200 | if file_id is not None:
201 | return download_file_from_google_drive(file_id, root, filename, md5)
202 |
203 | # download the file
204 | try:
205 | print("Downloading " + url + " to " + fpath)
206 | _urlretrieve(url, fpath)
207 | except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
208 | if url[:5] == "https":
209 | url = url.replace("https:", "http:")
210 | print(
211 | "Failed download. Trying https -> http instead."
212 | " Downloading " + url + " to " + fpath
213 | )
214 | _urlretrieve(url, fpath)
215 | else:
216 | raise e
217 |
218 | # check integrity of downloaded file
219 | if not check_integrity(fpath, md5):
220 | raise RuntimeError("File not found or corrupted.")
221 |
222 |
223 | def download_and_extract_archive(
224 | url: str,
225 | download_root: str,
226 | extract_root: Optional[str] = None,
227 | filename: Optional[str] = None,
228 | md5: Optional[str] = None,
229 | remove_finished: bool = False,
230 | ) -> None:
231 | download_root = os.path.expanduser(download_root)
232 | if extract_root is None:
233 | extract_root = download_root
234 | if not filename:
235 | filename = os.path.basename(url)
236 |
237 | download_url(url, download_root, filename, md5)
238 |
239 | archive = os.path.join(download_root, filename)
240 | print("Extracting {} to {}".format(archive, extract_root))
241 | extract_archive(archive, extract_root, remove_finished)
242 |
243 |
244 | def cache_url(url: str, cache_dir: str) -> str:
245 | """
246 | This implementation downloads the remote resource and caches it locally.
247 | The resource will only be downloaded if not previously requested.
248 | """
249 | parsed_url = urlparse(url)
250 | dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
251 | makedir(dirname)
252 | filename = url.split("/")[-1]
253 | cached = os.path.join(dirname, filename)
254 | with file_lock(cached):
255 | if not os.path.isfile(cached):
256 | logging.info(f"Downloading {url} to {cached} ...")
257 | cached = download(url, dirname, filename=filename)
258 | logging.info(f"URL {url} cached in {cached}")
259 | return cached
260 |
261 |
262 | # TODO (prigoyal): convert this into RAII-style API
263 | def create_file_symlink(file1, file2):
264 | """
265 | Simply create the symlinks for a given file1 to file2.
266 | Useful during model checkpointing to symlinks to the
267 | latest successful checkpoint.
268 | """
269 | try:
270 | if g_pathmgr.exists(file2):
271 | g_pathmgr.rm(file2)
272 | g_pathmgr.symlink(file1, file2)
273 | except Exception as e:
274 | logging.info(f"Could NOT create symlink. Error: {e}")
275 |
276 |
277 | def save_file(data, filename, append_to_json=True, verbose=True):
278 | """
279 | Common i/o utility to handle saving data to various file formats.
280 | Supported:
281 | .pkl, .pickle, .npy, .json
282 | Specifically for .json, users have the option to either append (default)
283 | or rewrite by passing in Boolean value to append_to_json.
284 | """
285 | if verbose:
286 | logging.info(f"Saving data to file: {filename}")
287 | file_ext = os.path.splitext(filename)[1]
288 | if file_ext in [".pkl", ".pickle"]:
289 | with g_pathmgr.open(filename, "wb") as fopen:
290 | pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
291 | elif file_ext == ".npy":
292 | with g_pathmgr.open(filename, "wb") as fopen:
293 | np.save(fopen, data)
294 | elif file_ext == ".json":
295 | if append_to_json:
296 | with g_pathmgr.open(filename, "a") as fopen:
297 | fopen.write(json.dumps(data, sort_keys=True) + "\n")
298 | fopen.flush()
299 | else:
300 | with g_pathmgr.open(filename, "w") as fopen:
301 | fopen.write(json.dumps(data, sort_keys=True) + "\n")
302 | fopen.flush()
303 | elif file_ext == ".yaml":
304 | with g_pathmgr.open(filename, "w") as fopen:
305 | dump = yaml.dump(data)
306 | fopen.write(dump)
307 | fopen.flush()
308 | else:
309 | raise Exception(f"Saving {file_ext} is not supported yet")
310 |
311 | if verbose:
312 | logging.info(f"Saved data to file: {filename}")
313 |
314 |
315 | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
316 | """
317 | Common i/o utility to handle loading data from various file formats.
318 | Supported:
319 | .pkl, .pickle, .npy, .json
320 | For the npy files, we support reading the files in mmap_mode.
321 | If the mmap_mode of reading is not successful, we load data without the
322 | mmap_mode.
323 | """
324 | if verbose:
325 | logging.info(f"Loading data from file: {filename}")
326 |
327 | file_ext = os.path.splitext(filename)[1]
328 | if file_ext == ".txt":
329 | with g_pathmgr.open(filename, "r") as fopen:
330 | data = fopen.readlines()
331 | elif file_ext in [".pkl", ".pickle"]:
332 | with g_pathmgr.open(filename, "rb") as fopen:
333 | data = pickle.load(fopen, encoding="latin1")
334 | elif file_ext == ".npy":
335 | if mmap_mode:
336 | try:
337 | with g_pathmgr.open(filename, "rb") as fopen:
338 | data = np.load(
339 | fopen,
340 | allow_pickle=allow_pickle,
341 | encoding="latin1",
342 | mmap_mode=mmap_mode,
343 | )
344 | except ValueError as e:
345 | logging.info(
346 | f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
347 | )
348 | data = np.load(
349 | filename,
350 | allow_pickle=allow_pickle,
351 | encoding="latin1",
352 | mmap_mode=mmap_mode,
353 | )
354 | logging.info("Successfully loaded without g_pathmgr")
355 | except Exception:
356 | logging.info("Could not mmap without g_pathmgr. Trying without mmap")
357 | with g_pathmgr.open(filename, "rb") as fopen:
358 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
359 | else:
360 | with g_pathmgr.open(filename, "rb") as fopen:
361 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
362 | elif file_ext == ".json":
363 | with g_pathmgr.open(filename, "r") as fopen:
364 | data = json.load(fopen)
365 | elif file_ext == ".yaml":
366 | with g_pathmgr.open(filename, "r") as fopen:
367 | data = yaml.load(fopen, Loader=yaml.FullLoader)
368 | elif file_ext == ".csv":
369 | with g_pathmgr.open(filename, "r") as fopen:
370 | data = pd.read_csv(fopen)
371 | else:
372 | raise Exception(f"Reading from {file_ext} is not supported yet")
373 | return data
374 |
375 |
376 | def abspath(resource_path: str):
377 | """
378 | Make a path absolute, but take into account prefixes like
379 | "http://" or "manifold://"
380 | """
381 | regex = re.compile(r"^\w+://")
382 | if regex.match(resource_path) is None:
383 | return os.path.abspath(resource_path)
384 | else:
385 | return resource_path
386 |
387 |
388 | def makedir(dir_path):
389 | """
390 | Create the directory if it does not exist.
391 | """
392 | is_success = False
393 | try:
394 | if not g_pathmgr.exists(dir_path):
395 | g_pathmgr.mkdirs(dir_path)
396 | is_success = True
397 | except BaseException:
398 | logging.info(f"Error creating directory: {dir_path}")
399 | return is_success
400 |
401 |
402 | def is_url(input_url):
403 | """
404 | Check if an input string is a url. look for http(s):// and ignoring the case
405 | """
406 | is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
407 | return is_url
408 |
409 |
410 | def download_and_untar(url):
411 | cached_file = download_cached_file(
412 | url, check_hash=False, progress=True
413 | )
414 | # get path to untarred directory
415 | untarred_dir = os.path.basename(url).split(".")[0]
416 | parent_dir = os.path.dirname(cached_file)
417 |
418 | full_dir = os.path.join(parent_dir, untarred_dir)
419 |
420 | if not os.path.exists(full_dir):
421 | with tarfile.open(cached_file) as tar:
422 | tar.extractall(parent_dir)
423 |
424 | return full_dir
425 |
426 | def cleanup_dir(dir):
427 | """
428 | Utility for deleting a directory. Useful for cleaning the storage space
429 | that contains various training artifacts like checkpoints, data etc.
430 | """
431 | if os.path.exists(dir):
432 | logging.info(f"Deleting directory: {dir}")
433 | shutil.rmtree(dir)
434 | logging.info(f"Deleted contents of directory: {dir}")
435 |
436 |
437 | def get_file_size(filename):
438 | """
439 | Given a file, get the size of file in MB
440 | """
441 | size_in_mb = os.path.getsize(filename) / float(1024**2)
442 | return size_in_mb
443 |
444 | def is_serializable(value):
445 | """
446 | This function checks if the provided value can be serialized into a JSON string.
447 | """
448 | try:
449 | json.dumps(value)
450 | return True
451 | except (TypeError, OverflowError):
452 | return False
453 |
454 | def is_convertible_to_int(value):
455 | return bool(re.match(r'^-?\d+$', str(value)))
--------------------------------------------------------------------------------
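A minimal sketch of the generic save/load helpers doing a JSON round trip; the file path is arbitrary, and `append_to_json=False` rewrites the file instead of appending.

```python
# Minimal sketch: JSON round trip with save_file / load_file.
from models.lavis.common.utils import save_file, load_file, is_serializable

record = {"attribute": "striped", "score": 0.87}   # illustrative payload
assert is_serializable(record)

save_file(record, "/tmp/example_record.json", append_to_json=False)
print(load_file("/tmp/example_record.json"))       # {'attribute': 'striped', 'score': 0.87}
```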
/models/lavis/configs/default.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | env:
7 | # For default users
8 | # cache_root: "cache"
9 | # For internal use with persistent storage
10 | cache_root: "/export/home/.cache/lavis"
--------------------------------------------------------------------------------
/models/lavis/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from ..common.registry import registry
9 | from .base_model import BaseModel
10 | from .blip2_models.blip2 import Blip2Base
11 | from .blip2_models.blip2_qformer import Blip2Qformer
12 | from ..processors.base_processor import BaseProcessor
13 | from omegaconf import OmegaConf
14 | import logging
15 | import torch
16 |
17 | __all__ = ["registry", "BaseModel", "Blip2Base", "Blip2Qformer", "BaseProcessor"]
18 |
19 |
20 | # __all__ = [
21 | # "load_model",
22 | # "AlbefClassification",
23 | # "AlbefFeatureExtractor",
24 | # "AlbefNLVR",
25 | # "AlbefVQA",
26 | # "AlbefPretrain",
27 | # "AlbefRetrieval",
28 | # "AlproQA",
29 | # "AlproRetrieval",
30 | # "BaseModel",
31 | # "BlipBase",
32 | # "BlipFeatureExtractor",
33 | # "BlipCaption",
34 | # "BlipClassification",
35 | # "BlipITM",
36 | # "BlipNLVR",
37 | # "BlipPretrain",
38 | # "BlipRetrieval",
39 | # "BlipVQA",
40 | # "Blip2Qformer",
41 | # "Blip2Base",
42 | # "Blip2ITM",
43 | # "Blip2OPT",
44 | # "Blip2T5",
45 | # "PNPVQA",
46 | # "Img2PromptVQA",
47 | # "PNPUnifiedQAv2FiD",
48 | # "CLIP",
49 | # "VisionTransformerEncoder",
50 | # "XBertLMHeadDecoder",
51 | # "GPTDialogue",
52 | # ]
53 |
54 |
55 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
56 | """
57 | Load supported models.
58 |
59 | To list all available models and types in registry:
60 | >>> from lavis.models import model_zoo
61 | >>> print(model_zoo)
62 |
63 | Args:
64 | name (str): name of the model.
65 | model_type (str): type of the model.
66 | is_eval (bool): whether the model is in eval mode. Default: False.
67 | device (str): device to use. Default: "cpu".
68 |         checkpoint (str): path or URL to a checkpoint. Default: None.
69 |             Note that the checkpoint is expected to have the same keys in its state_dict as the model.
70 |
71 | Returns:
72 | model (torch.nn.Module): model.
73 | """
74 |
75 |     model = registry.get_model_class(name).from_pretrained(model_type=model_type)
76 |
77 | if checkpoint is not None:
78 | model.load_checkpoint(checkpoint)
79 |
80 | if is_eval:
81 | model.eval()
82 |
83 | if device == "cpu":
84 | model = model.float()
85 |
86 | return model.to(device)
87 |
88 |
89 | def load_preprocess(config):
90 | """
91 | Load preprocessor configs and construct preprocessors.
92 |
93 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
94 |
95 | Args:
96 | config (dict): preprocessor configs.
97 |
98 | Returns:
99 | vis_processors (dict): preprocessors for visual inputs.
100 | txt_processors (dict): preprocessors for text inputs.
101 |
102 | Key is "train" or "eval" for processors used in training and evaluation respectively.
103 | """
104 |
105 | def _build_proc_from_cfg(cfg):
106 | return (
107 | registry.get_processor_class(cfg.name).from_config(cfg)
108 | if cfg is not None
109 | else BaseProcessor()
110 | )
111 |
112 | vis_processors = dict()
113 | txt_processors = dict()
114 |
115 | vis_proc_cfg = config.get("vis_processor")
116 | txt_proc_cfg = config.get("text_processor")
117 |
118 | if vis_proc_cfg is not None:
119 | vis_train_cfg = vis_proc_cfg.get("train")
120 | vis_eval_cfg = vis_proc_cfg.get("eval")
121 | else:
122 | vis_train_cfg = None
123 | vis_eval_cfg = None
124 |
125 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
126 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
127 |
128 | if txt_proc_cfg is not None:
129 | txt_train_cfg = txt_proc_cfg.get("train")
130 | txt_eval_cfg = txt_proc_cfg.get("eval")
131 | else:
132 | txt_train_cfg = None
133 | txt_eval_cfg = None
134 |
135 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
136 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
137 |
138 | return vis_processors, txt_processors
139 |
140 |
141 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
142 | """
143 | Load model and its related preprocessors.
144 |
145 | List all available models and types in registry:
146 | >>> from lavis.models import model_zoo
147 | >>> print(model_zoo)
148 |
149 | Args:
150 | name (str): name of the model.
151 | model_type (str): type of the model.
152 | is_eval (bool): whether the model is in eval mode. Default: False.
153 | device (str): device to use. Default: "cpu".
154 |
155 | Returns:
156 |         model (torch.nn.Module): the loaded model. In this version, the visual
157 |         and text preprocessors are built from the model's default config but
158 |         are not returned; only the model is returned.
159 | """
160 | model_cls = registry.get_model_class(name)
161 |
162 | # load model
163 | model = model_cls.from_pretrained(model_type=model_type)
164 |
165 | if is_eval:
166 | model.eval()
167 |
168 | # load preprocess
169 | cfg = OmegaConf.load(model_cls.default_config_path(model_type))
170 | if cfg is not None:
171 | preprocess_cfg = cfg.preprocess
172 | vis_processors, txt_processors = load_preprocess(preprocess_cfg)
173 | else:
174 | vis_processors, txt_processors = None, None
175 | logging.info(
176 | f"""No default preprocess for model {name} ({model_type}).
177 | This can happen if the model is not finetuned on downstream datasets,
178 | or it is not intended for direct use without finetuning.
179 | """
180 | )
181 |
182 | if device == "cpu" or device == torch.device("cpu"):
183 | model = model.float()
184 |
185 | return model.to(device)
186 |
187 |
188 | class ModelZoo:
189 | """
190 | A utility class to create string representation of available model architectures and types.
191 |
192 | >>> from lavis.models import model_zoo
193 | >>> # list all available models
194 | >>> print(model_zoo)
195 | >>> # show total number of models
196 | >>> print(len(model_zoo))
197 | """
198 |
199 | def __init__(self) -> None:
200 | self.model_zoo = {
201 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
202 | for k, v in registry.mapping["model_name_mapping"].items()
203 | }
204 |
205 | def __str__(self) -> str:
206 | return (
207 | "=" * 50
208 | + "\n"
209 | + f"{'Architectures':<30} {'Types'}\n"
210 | + "=" * 50
211 | + "\n"
212 | + "\n".join(
213 | [
214 | f"{name:<30} {', '.join(types)}"
215 | for name, types in self.model_zoo.items()
216 | ]
217 | )
218 | )
219 |
220 | def __iter__(self):
221 | return iter(self.model_zoo.items())
222 |
223 | def __len__(self):
224 | return sum([len(v) for v in self.model_zoo.values()])
225 |
226 |
227 | model_zoo = ModelZoo()
228 |
--------------------------------------------------------------------------------
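A minimal sketch mirroring how `backbone.py` uses this module: load the BLIP-2 Q-Former (`name="blip2"`, `model_type="pretrain"`) and grab its ViT visual encoder. Note that in this repository `load_model_and_preprocess` returns only the model; pretrained weights are downloaded on first use, and a GPU is optional.

```python
# Minimal sketch: load BLIP-2 and access its visual encoder, as backbone.py does.
import torch
from models.lavis.models import load_model_and_preprocess

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model_and_preprocess(
    name="blip2", model_type="pretrain", is_eval=True, device=device
)
visual_encoder = model.visual_encoder   # ViT backbone reused by SugaFormer's backbone
print(type(model).__name__, sum(p.numel() for p in visual_encoder.parameters()))
```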
/models/lavis/models/base_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from ..common.dist_utils import download_cached_file, is_dist_avail_and_initialized
9 | from ..common.utils import get_abs_path, is_url
10 | from omegaconf import OmegaConf
11 | import torch.nn as nn
12 | import numpy as np
13 | import logging
14 | import torch
15 | import os
16 |
17 |
18 | class BaseModel(nn.Module):
19 | """Base class for models."""
20 |
21 | def __init__(self):
22 | super().__init__()
23 |
24 | @property
25 | def device(self):
26 | return list(self.parameters())[0].device
27 |
28 | def load_checkpoint(self, url_or_filename):
29 | """
30 | Load from a finetuned checkpoint.
31 |
32 |         This expects no mismatch between the model keys and the checkpoint keys.
33 | """
34 |
35 | if is_url(url_or_filename):
36 | cached_file = download_cached_file(
37 | url_or_filename, check_hash=False, progress=True
38 | )
39 | checkpoint = torch.load(cached_file, map_location="cpu")
40 | elif os.path.isfile(url_or_filename):
41 | checkpoint = torch.load(url_or_filename, map_location="cpu")
42 | else:
43 | raise RuntimeError("checkpoint url or path is invalid")
44 |
45 | if "model" in checkpoint.keys():
46 | state_dict = checkpoint["model"]
47 | else:
48 | state_dict = checkpoint
49 |
50 | msg = self.load_state_dict(state_dict, strict=False)
51 |
52 | logging.info("Missing keys {}".format(msg.missing_keys))
53 | logging.info("load checkpoint from %s" % url_or_filename)
54 |
55 | return msg
56 |
57 | @classmethod
58 | def from_pretrained(cls, model_type):
59 | """
60 | Build a pretrained model from default configuration file, specified by model_type.
61 |
62 | Args:
63 | - model_type (str): model type, specifying architecture and checkpoints.
64 |
65 | Returns:
66 | - model (nn.Module): pretrained or finetuned model, depending on the configuration.
67 | """
68 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
69 | model = cls.from_config(model_cfg)
70 |
71 | return model
72 |
73 | @classmethod
74 | def default_config_path(cls, model_type):
75 | assert (
76 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
77 | ), "Unknown model type {}".format(model_type)
78 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
79 |
80 | def load_checkpoint_from_config(self, cfg, **kwargs):
81 | """
82 | Load checkpoint as specified in the config file.
83 |
84 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
85 |         When loading the pretrained model, each task-specific architecture may define its
86 | own load_from_pretrained() method.
87 | """
88 | load_finetuned = cfg.get("load_finetuned", True)
89 | if load_finetuned:
90 | finetune_path = cfg.get("finetuned", None)
91 | assert (
92 | finetune_path is not None
93 | ), "Found load_finetuned is True, but finetune_path is None."
94 | self.load_checkpoint(url_or_filename=finetune_path)
95 | else:
96 | load_pretrained = cfg.get("load_pretrained", True)
97 | if load_pretrained:
98 | # load pre-trained weights
99 | pretrain_path = cfg.get("pretrained", None)
100 | assert "Found load_finetuned is False, but pretrain_path is None."
101 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
102 |
103 | def before_training(self, **kwargs):
104 | pass
105 |
106 | def get_optimizer_params(self, weight_decay, lr_scale=1):
107 | p_wd, p_non_wd = [], []
108 | for n, p in self.named_parameters():
109 | if not p.requires_grad:
110 | continue # frozen weights
111 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
112 | p_non_wd.append(p)
113 | else:
114 | p_wd.append(p)
115 | optim_params = [
116 | {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale},
117 | {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale},
118 | ]
119 | return optim_params
120 |
121 | def before_evaluation(self, **kwargs):
122 | pass
123 |
124 | def show_n_params(self, return_str=True):
125 | tot = 0
126 | for p in self.parameters():
127 | w = 1
128 | for x in p.shape:
129 | w *= x
130 | tot += w
131 | if return_str:
132 | if tot >= 1e6:
133 | return "{:.1f}M".format(tot / 1e6)
134 | else:
135 | return "{:.1f}K".format(tot / 1e3)
136 | else:
137 | return tot
138 |
139 |
140 | class BaseEncoder(nn.Module):
141 | """
142 | Base class for primitive encoders, such as ViT, TimeSformer, etc.
143 | """
144 |
145 | def __init__(self):
146 | super().__init__()
147 |
148 | def forward_features(self, samples, **kwargs):
149 | raise NotImplementedError
150 |
151 | @property
152 | def device(self):
153 | return list(self.parameters())[0].device
154 |
155 |
156 | class SharedQueueMixin:
157 | @torch.no_grad()
158 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
159 | # gather keys before updating queue
160 | image_feats = concat_all_gather(image_feat)
161 | text_feats = concat_all_gather(text_feat)
162 |
163 | batch_size = image_feats.shape[0]
164 |
165 | ptr = int(self.queue_ptr)
166 | assert self.queue_size % batch_size == 0 # for simplicity
167 |
168 | # replace the keys at ptr (dequeue and enqueue)
169 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
170 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
171 |
172 | if idxs is not None:
173 | idxs = concat_all_gather(idxs)
174 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
175 |
176 | ptr = (ptr + batch_size) % self.queue_size # move pointer
177 | self.queue_ptr[0] = ptr
178 |
179 |
180 | class MomentumDistilationMixin:
181 | @torch.no_grad()
182 | def copy_params(self):
183 | for model_pair in self.model_pairs:
184 | for param, param_m in zip(
185 | model_pair[0].parameters(), model_pair[1].parameters()
186 | ):
187 | param_m.data.copy_(param.data) # initialize
188 | param_m.requires_grad = False # not update by gradient
189 |
190 | @torch.no_grad()
191 | def _momentum_update(self):
192 | for model_pair in self.model_pairs:
193 | for param, param_m in zip(
194 | model_pair[0].parameters(), model_pair[1].parameters()
195 | ):
196 | param_m.data = param_m.data * self.momentum + param.data * (
197 | 1.0 - self.momentum
198 | )
199 |
200 |
201 | class GatherLayer(torch.autograd.Function):
202 | """
203 | Gather tensors from all workers with support for backward propagation:
204 | This implementation does not cut the gradients as torch.distributed.all_gather does.
205 | """
206 |
207 | @staticmethod
208 | def forward(ctx, x):
209 | output = [
210 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
211 | ]
212 | torch.distributed.all_gather(output, x)
213 | return tuple(output)
214 |
215 | @staticmethod
216 | def backward(ctx, *grads):
217 | all_gradients = torch.stack(grads)
218 | torch.distributed.all_reduce(all_gradients)
219 | return all_gradients[torch.distributed.get_rank()]
220 |
221 |
222 | def all_gather_with_grad(tensors):
223 | """
224 | Performs all_gather operation on the provided tensors.
225 | Graph remains connected for backward grad computation.
226 | """
227 | # Queue the gathered tensors
228 | world_size = torch.distributed.get_world_size()
229 | # There is no need for reduction in the single-proc case
230 | if world_size == 1:
231 | return tensors
232 |
233 | # tensor_all = GatherLayer.apply(tensors)
234 | tensor_all = GatherLayer.apply(tensors)
235 |
236 | return torch.cat(tensor_all, dim=0)
237 |
238 |
239 | @torch.no_grad()
240 | def concat_all_gather(tensor):
241 | """
242 | Performs all_gather operation on the provided tensors.
243 | *** Warning ***: torch.distributed.all_gather has no gradient.
244 | """
245 |     # if not using distributed training, there is nothing to gather
246 | if not is_dist_avail_and_initialized():
247 | return tensor
248 |
249 | tensors_gather = [
250 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
251 | ]
252 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
253 |
254 | output = torch.cat(tensors_gather, dim=0)
255 | return output
256 |
257 |
258 | def tile(x, dim, n_tile):
259 | init_dim = x.size(dim)
260 | repeat_idx = [1] * x.dim()
261 | repeat_idx[dim] = n_tile
262 | x = x.repeat(*(repeat_idx))
263 | order_index = torch.LongTensor(
264 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
265 | )
266 | return torch.index_select(x, dim, order_index.to(x.device))
267 |
--------------------------------------------------------------------------------
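A minimal sketch of the `BaseModel` utilities on a toy subclass; `ToyModel` and its layer sizes are hypothetical, and the weight-decay split simply keeps biases and normalization layers decay-free.

```python
# Minimal sketch: parameter counting and weight-decay grouping on a toy model.
import torch
import torch.nn as nn
from models.lavis.models.base_model import BaseModel

class ToyModel(BaseModel):          # hypothetical subclass for illustration
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)
        self.ln = nn.LayerNorm(4)

model = ToyModel()
print(model.show_n_params())                        # "0.1K"
param_groups = model.get_optimizer_params(weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
print([len(g["params"]) for g in param_groups])     # [1, 3]: weights vs. bias/LayerNorm
```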
/models/lavis/models/blip2_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/models/lavis/models/blip2_models/__init__.py
--------------------------------------------------------------------------------
/models/lavis/models/blip2_models/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | from ..blip2_models.Qformer import BertConfig, BertLMHeadModel
8 | from ...common.dist_utils import download_cached_file
9 | import models.lavis.common.dist_utils as dist_utils
10 | from ...models.clip_vit import create_clip_vit_L
11 | from ...common.logger import MetricLogger
12 | from ..eva_vit import create_eva_vit_g
13 | from transformers import BertTokenizer
14 | from ..base_model import BaseModel
15 | from ...common.utils import is_url
16 | import torch.distributed as dist
17 | import torch.nn.functional as F
18 | import torch.nn as nn
19 | import contextlib
20 | import datetime
21 | import logging
22 | import torch
23 | import time
24 | import os
25 |
26 | class Blip2Base(BaseModel):
27 | @classmethod
28 | def init_tokenizer(cls):
29 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
30 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
31 | return tokenizer
32 |
33 | def maybe_autocast(self, dtype=torch.float16):
34 | # if on cpu, don't use autocast
35 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
36 | enable_autocast = self.device != torch.device("cpu")
37 |
38 | if enable_autocast:
39 | return torch.cuda.amp.autocast(dtype=dtype)
40 | else:
41 | return contextlib.nullcontext()
42 |
43 | @classmethod
44 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
45 | encoder_config = BertConfig.from_pretrained("bert-base-uncased")
46 | encoder_config.encoder_width = vision_width
47 | # insert cross-attention layer every other block
48 | encoder_config.add_cross_attention = True
49 | encoder_config.cross_attention_freq = cross_attention_freq
50 | encoder_config.query_length = num_query_token
51 | Qformer = BertLMHeadModel.from_pretrained(
52 | "bert-base-uncased", config=encoder_config
53 | )
54 | query_tokens = nn.Parameter(
55 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
56 | )
57 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58 | return Qformer, query_tokens
59 |
60 | @classmethod
61 | def init_vision_encoder(
62 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63 | ):
64 | assert model_name in [
65 | "eva_clip_g",
66 | "clip_L",
67 | ], "vit model must be eva_clip_g or clip_L"
68 | if model_name == "eva_clip_g":
69 | visual_encoder = create_eva_vit_g(
70 | img_size, drop_path_rate, use_grad_checkpoint, precision
71 | )
72 | elif model_name == "clip_L":
73 | visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
74 | ln_vision = LayerNorm(visual_encoder.num_features)
75 | return visual_encoder, ln_vision
76 |
77 | def load_from_pretrained(self, url_or_filename):
78 | if is_url(url_or_filename):
79 | cached_file = download_cached_file(
80 | url_or_filename, check_hash=False, progress=True
81 | )
82 | checkpoint = torch.load(cached_file, map_location="cpu")
83 | elif os.path.isfile(url_or_filename):
84 | checkpoint = torch.load(url_or_filename, map_location="cpu")
85 | else:
86 | raise RuntimeError("checkpoint url or path is invalid")
87 |
88 | state_dict = checkpoint["model"]
89 |
90 | msg = self.load_state_dict(state_dict, strict=False)
91 |
92 | # logging.info("Missing keys {}".format(msg.missing_keys))
93 | logging.info("load checkpoint from %s" % url_or_filename)
94 |
95 | return msg
96 |
97 |
98 | def disabled_train(self, mode=True):
99 | """Overwrite model.train with this function to make sure train/eval mode
100 | does not change anymore."""
101 | return self
102 |
103 |
104 | class LayerNorm(nn.LayerNorm):
105 | """Subclass torch's LayerNorm to handle fp16."""
106 |
107 | def forward(self, x: torch.Tensor):
108 | orig_type = x.dtype
109 | ret = super().forward(x.type(torch.float32))
110 | return ret.type(orig_type)
111 |
112 |
113 | def compute_sim_matrix(model, data_loader, **kwargs):
114 | k_test = kwargs.pop("k_test")
115 |
116 | metric_logger = MetricLogger(delimiter=" ")
117 | header = "Evaluation:"
118 |
119 | logging.info("Computing features for evaluation...")
120 | start_time = time.time()
121 |
122 | texts = data_loader.dataset.text
123 | num_text = len(texts)
124 | text_bs = 256
125 | text_ids = []
126 | text_embeds = []
127 | text_atts = []
128 | for i in range(0, num_text, text_bs):
129 | text = texts[i : min(num_text, i + text_bs)]
130 | text_input = model.tokenizer(
131 | text,
132 | padding="max_length",
133 | truncation=True,
134 | max_length=35,
135 | return_tensors="pt",
136 | ).to(model.device)
137 | text_feat = model.forward_text(text_input)
138 | text_embed = F.normalize(model.text_proj(text_feat))
139 | text_embeds.append(text_embed)
140 | text_ids.append(text_input.input_ids)
141 | text_atts.append(text_input.attention_mask)
142 |
143 | text_embeds = torch.cat(text_embeds, dim=0)
144 | text_ids = torch.cat(text_ids, dim=0)
145 | text_atts = torch.cat(text_atts, dim=0)
146 |
147 | vit_feats = []
148 | image_embeds = []
149 | for samples in data_loader:
150 | image = samples["image"]
151 |
152 | image = image.to(model.device)
153 | image_feat, vit_feat = model.forward_image(image)
154 | image_embed = model.vision_proj(image_feat)
155 | image_embed = F.normalize(image_embed, dim=-1)
156 |
157 | vit_feats.append(vit_feat.cpu())
158 | image_embeds.append(image_embed)
159 |
160 | vit_feats = torch.cat(vit_feats, dim=0)
161 | image_embeds = torch.cat(image_embeds, dim=0)
162 |
163 | sims_matrix = []
164 | for image_embed in image_embeds:
165 | sim_q2t = image_embed @ text_embeds.t()
166 | sim_i2t, _ = sim_q2t.max(0)
167 | sims_matrix.append(sim_i2t)
168 | sims_matrix = torch.stack(sims_matrix, dim=0)
169 |
170 | score_matrix_i2t = torch.full(
171 | (len(data_loader.dataset.image), len(texts)), -100.0
172 | ).to(model.device)
173 |
174 | num_tasks = dist_utils.get_world_size()
175 | rank = dist_utils.get_rank()
176 | step = sims_matrix.size(0) // num_tasks + 1
177 | start = rank * step
178 | end = min(sims_matrix.size(0), start + step)
179 |
180 | for i, sims in enumerate(
181 | metric_logger.log_every(sims_matrix[start:end], 50, header)
182 | ):
183 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
184 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
185 | score = model.compute_itm(
186 | image_inputs=image_inputs,
187 | text_ids=text_ids[topk_idx],
188 | text_atts=text_atts[topk_idx],
189 | ).float()
190 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim
191 |
192 | sims_matrix = sims_matrix.t()
193 | score_matrix_t2i = torch.full(
194 | (len(texts), len(data_loader.dataset.image)), -100.0
195 | ).to(model.device)
196 |
197 | step = sims_matrix.size(0) // num_tasks + 1
198 | start = rank * step
199 | end = min(sims_matrix.size(0), start + step)
200 |
201 | for i, sims in enumerate(
202 | metric_logger.log_every(sims_matrix[start:end], 50, header)
203 | ):
204 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
205 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
206 | score = model.compute_itm(
207 | image_inputs=image_inputs,
208 | text_ids=text_ids[start + i].repeat(k_test, 1),
209 | text_atts=text_atts[start + i].repeat(k_test, 1),
210 | ).float()
211 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim
212 |
213 | if dist_utils.is_dist_avail_and_initialized():
214 | dist.barrier()
215 | torch.distributed.all_reduce(
216 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
217 | )
218 | torch.distributed.all_reduce(
219 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
220 | )
221 |
222 | total_time = time.time() - start_time
223 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
224 | logging.info("Evaluation time {}".format(total_time_str))
225 |
226 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
227 |
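A minimal usage sketch for `Blip2Base.init_Qformer` above; it is not part of the file and assumes the repo root is on `PYTHONPATH` so `models.lavis` resolves, that the packages in `requirements.txt` are installed, and that `bert-base-uncased` can be downloaded. The width 1408 is the feature dimension of the EVA ViT-g encoder used by `init_vision_encoder`.

```python
from models.lavis.models.blip2_models.blip2 import Blip2Base

# Build a Q-Former with 32 learnable query tokens on top of 1408-dim vision features.
Qformer, query_tokens = Blip2Base.init_Qformer(num_query_token=32, vision_width=1408)

print(query_tokens.shape)                   # torch.Size([1, 32, 768]) -- BERT hidden size
print(Qformer.config.cross_attention_freq)  # 2 -> cross-attention inserted every other block
```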
--------------------------------------------------------------------------------
/models/lavis/models/blip_models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from typing import List
9 | from torch import nn
10 | import logging
11 |
12 |
13 | def tie_encoder_decoder_weights(
14 | encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
15 | ):
16 | uninitialized_encoder_weights: List[str] = []
17 | if decoder.__class__ != encoder.__class__:
18 | logging.info(
19 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
20 | )
21 |
22 | def tie_encoder_to_decoder_recursively(
23 | decoder_pointer: nn.Module,
24 | encoder_pointer: nn.Module,
25 | module_name: str,
26 | uninitialized_encoder_weights: List[str],
27 | skip_key: str,
28 | depth=0,
29 | ):
30 | assert isinstance(decoder_pointer, nn.Module) and isinstance(
31 | encoder_pointer, nn.Module
32 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
33 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
34 | assert hasattr(encoder_pointer, "weight")
35 | encoder_pointer.weight = decoder_pointer.weight
36 | if hasattr(decoder_pointer, "bias"):
37 | assert hasattr(encoder_pointer, "bias")
38 | encoder_pointer.bias = decoder_pointer.bias
39 | print(module_name + " is tied")
40 | return
41 |
42 | encoder_modules = encoder_pointer._modules
43 | decoder_modules = decoder_pointer._modules
44 | if len(decoder_modules) > 0:
45 | assert (
46 | len(encoder_modules) > 0
47 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
48 |
49 | all_encoder_weights = set(
50 | [module_name + "/" + sub_name for sub_name in encoder_modules.keys()]
51 | )
52 | encoder_layer_pos = 0
53 | for name, module in decoder_modules.items():
54 | if name.isdigit():
55 | encoder_name = str(int(name) + encoder_layer_pos)
56 | decoder_name = name
57 | if not isinstance(
58 | decoder_modules[decoder_name],
59 | type(encoder_modules[encoder_name]),
60 | ) and len(encoder_modules) != len(decoder_modules):
61 |                     # this can happen if the name corresponds to the position in a list of layers (an nn.ModuleList)
62 |                     # in this case the decoder has added a cross-attention layer that the encoder does not have
63 |                     # thus skip this step and subtract one layer position from the encoder
64 | encoder_layer_pos -= 1
65 | continue
66 | elif name not in encoder_modules:
67 | continue
68 | elif depth > 500:
69 | raise ValueError(
70 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
71 | )
72 | else:
73 | decoder_name = encoder_name = name
74 | tie_encoder_to_decoder_recursively(
75 | decoder_modules[decoder_name],
76 | encoder_modules[encoder_name],
77 | module_name + "/" + name,
78 | uninitialized_encoder_weights,
79 | skip_key,
80 | depth=depth + 1,
81 | )
82 | all_encoder_weights.remove(module_name + "/" + encoder_name)
83 |
84 | uninitialized_encoder_weights += list(all_encoder_weights)
85 |
86 | # tie weights recursively
87 | tie_encoder_to_decoder_recursively(
88 | decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key
89 | )
90 |
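A self-contained sketch of what `tie_encoder_decoder_weights` does, using two toy `nn.ModuleDict` towers. The module names `embed` and `crossattention` are illustrative only; the import assumes the repo root is on `PYTHONPATH`.

```python
import torch.nn as nn
from models.lavis.models.blip_models import tie_encoder_decoder_weights

def make_tower():
    # Two identically structured modules; "crossattention" matches skip_key below.
    return nn.ModuleDict({
        "embed": nn.Linear(16, 16),
        "crossattention": nn.Linear(16, 16),
    })

encoder, decoder = make_tower(), make_tower()
tie_encoder_decoder_weights(encoder, decoder, base_model_prefix="", skip_key="crossattention")

# The encoder's "embed" now shares parameters with the decoder's; the skipped module does not.
assert encoder["embed"].weight is decoder["embed"].weight
assert encoder["crossattention"].weight is not decoder["crossattention"].weight
```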
--------------------------------------------------------------------------------
/models/lavis/models/blip_models/blip.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 |
9 | from ...common.dist_utils import download_cached_file
10 | from ..vit import interpolate_pos_embed
11 | from transformers import BertTokenizer
12 | from ...common.utils import is_url
13 | from ..base_model import BaseModel
14 | from packaging import version
15 | import transformers
16 | import logging
17 | import torch
18 | import os
19 |
20 | class BlipBase(BaseModel):
21 | def __init__(self):
22 | super().__init__()
23 | transformers_version = version.parse(transformers.__version__)
24 | assert transformers_version < version.parse("4.27"), "BLIP models are not compatible with transformers>=4.27, run pip install transformers==4.25 to downgrade"
25 |
26 | @classmethod
27 | def init_tokenizer(cls):
28 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
30 | tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
31 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
32 | return tokenizer
33 |
34 | def load_from_pretrained(self, url_or_filename):
35 | if is_url(url_or_filename):
36 | cached_file = download_cached_file(
37 | url_or_filename, check_hash=False, progress=True
38 | )
39 | checkpoint = torch.load(cached_file, map_location="cpu")
40 | elif os.path.isfile(url_or_filename):
41 | checkpoint = torch.load(url_or_filename, map_location="cpu")
42 | else:
43 | raise RuntimeError("checkpoint url or path is invalid")
44 |
45 | state_dict = checkpoint["model"]
46 |
47 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
48 | state_dict["visual_encoder.pos_embed"], self.visual_encoder
49 | )
50 | if "visual_encoder_m.pos_embed" in self.state_dict().keys():
51 | state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed(
52 | state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m
53 | )
54 |
55 | for key in self.state_dict().keys():
56 | if key in state_dict.keys():
57 | if state_dict[key].shape != self.state_dict()[key].shape:
58 | del state_dict[key]
59 |
60 | msg = self.load_state_dict(state_dict, strict=False)
61 |
62 | logging.info("Missing keys {}".format(msg.missing_keys))
63 | logging.info("load checkpoint from %s" % url_or_filename)
64 |
65 | return msg
66 |
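The key step in `load_from_pretrained` above is dropping checkpoint tensors whose shapes no longer match before calling `load_state_dict(strict=False)`. A standalone sketch of that pattern on a toy module (not actual BLIP weights):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
ckpt = {"weight": torch.zeros(6, 8), "bias": torch.zeros(4)}  # stale weight shape

# Drop any tensor whose shape disagrees with the current model, then load non-strictly.
for key in list(model.state_dict().keys()):
    if key in ckpt and ckpt[key].shape != model.state_dict()[key].shape:
        del ckpt[key]

msg = model.load_state_dict(ckpt, strict=False)
print(msg.missing_keys)  # ['weight'] -- kept at its fresh initialization instead of crashing
```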
--------------------------------------------------------------------------------
/models/lavis/models/blip_models/blip_outputs.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Optional
10 |
11 | import torch
12 | from transformers.modeling_outputs import (
13 | ModelOutput,
14 | BaseModelOutputWithPoolingAndCrossAttentions,
15 | CausalLMOutputWithCrossAttentions,
16 | )
17 |
18 |
19 | @dataclass
20 | class BlipSimilarity(ModelOutput):
21 | sim_i2t: torch.FloatTensor = None
22 | sim_t2i: torch.FloatTensor = None
23 |
24 | sim_i2t_m: Optional[torch.FloatTensor] = None
25 | sim_t2i_m: Optional[torch.FloatTensor] = None
26 |
27 | sim_i2t_targets: Optional[torch.FloatTensor] = None
28 | sim_t2i_targets: Optional[torch.FloatTensor] = None
29 |
30 |
31 | @dataclass
32 | class BlipIntermediateOutput(ModelOutput):
33 | """
34 | Data class for intermediate outputs of BLIP models.
35 |
36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38 |
39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41 |
42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44 |
45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46 | decoder_labels (torch.LongTensor): labels for the captioning loss.
47 |
48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50 |
51 | """
52 |
53 | # uni-modal features
54 | image_embeds: torch.FloatTensor = None
55 | text_embeds: Optional[torch.FloatTensor] = None
56 |
57 | image_embeds_m: Optional[torch.FloatTensor] = None
58 | text_embeds_m: Optional[torch.FloatTensor] = None
59 |
60 | # intermediate outputs of multimodal encoder
61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63 |
64 | itm_logits: Optional[torch.FloatTensor] = None
65 | itm_labels: Optional[torch.LongTensor] = None
66 |
67 | # intermediate outputs of multimodal decoder
68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69 | decoder_labels: Optional[torch.LongTensor] = None
70 |
71 |
72 | @dataclass
73 | class BlipOutput(ModelOutput):
74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75 | sims: Optional[BlipSimilarity] = None
76 |
77 | intermediate_output: BlipIntermediateOutput = None
78 |
79 | loss: Optional[torch.FloatTensor] = None
80 |
81 | loss_itc: Optional[torch.FloatTensor] = None
82 |
83 | loss_itm: Optional[torch.FloatTensor] = None
84 |
85 | loss_lm: Optional[torch.FloatTensor] = None
86 |
87 |
88 | @dataclass
89 | class BlipOutputWithLogits(BlipOutput):
90 | logits: torch.FloatTensor = None
91 | logits_m: torch.FloatTensor = None
92 |
93 |
94 | @dataclass
95 | class BlipOutputFeatures(ModelOutput):
96 | """
97 | Data class of features from BlipFeatureExtractor.
98 |
99 | Args:
100 |         image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
101 |         image_embeds_proj: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
102 |         text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
103 |         text_embeds_proj: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
104 |
105 | The first embedding or feature is for the [CLS] token.
106 |
107 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
108 | """
109 |
110 | image_embeds: Optional[torch.FloatTensor] = None
111 | image_embeds_proj: Optional[torch.FloatTensor] = None
112 |
113 | text_embeds: Optional[torch.FloatTensor] = None
114 | text_embeds_proj: Optional[torch.FloatTensor] = None
115 |
116 | multimodal_embeds: Optional[torch.FloatTensor] = None
117 |
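These dataclasses inherit from `transformers.ModelOutput`, so fields can be read by attribute or by key and unset fields stay `None`. A minimal sketch with random tensors (shapes are illustrative; the import assumes the repo root is on `PYTHONPATH`):

```python
import torch
from models.lavis.models.blip_models.blip_outputs import BlipOutputFeatures

feats = BlipOutputFeatures(
    image_embeds=torch.randn(2, 257, 768),       # [CLS] + 256 patch tokens
    image_embeds_proj=torch.randn(2, 257, 256),  # normalized low-dimensional projection
)

print(feats.image_embeds.shape)          # attribute access
print(feats["image_embeds_proj"].shape)  # dict-style access also works
print(feats.text_embeds)                 # None: the text side was not computed here
```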
--------------------------------------------------------------------------------
/models/lavis/models/clip_vit.py:
--------------------------------------------------------------------------------
1 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
2 | from ..common.dist_utils import download_cached_file
3 | from .eva_vit import convert_weights_to_fp16
4 | from collections import OrderedDict
5 | import torch.nn.functional as F
6 | from itertools import repeat
7 | import collections.abc
8 | from torch import nn
9 | import torch
10 | import math
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, inplanes, planes, stride=1):
16 | super().__init__()
17 |
18 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu1 = nn.ReLU(inplace=True)
22 |
23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.relu2 = nn.ReLU(inplace=True)
26 |
27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28 |
29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31 | self.relu3 = nn.ReLU(inplace=True)
32 |
33 | self.downsample = None
34 | self.stride = stride
35 |
36 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38 | self.downsample = nn.Sequential(OrderedDict([
39 | ("-1", nn.AvgPool2d(stride)),
40 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41 | ("1", nn.BatchNorm2d(planes * self.expansion))
42 | ]))
43 |
44 | def forward(self, x: torch.Tensor):
45 | identity = x
46 |
47 | out = self.relu1(self.bn1(self.conv1(x)))
48 | out = self.relu2(self.bn2(self.conv2(out)))
49 | out = self.avgpool(out)
50 | out = self.bn3(self.conv3(out))
51 |
52 | if self.downsample is not None:
53 | identity = self.downsample(x)
54 |
55 | out += identity
56 | out = self.relu3(out)
57 | return out
58 |
59 |
60 | class AttentionPool2d(nn.Module):
61 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62 | super().__init__()
63 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64 | self.k_proj = nn.Linear(embed_dim, embed_dim)
65 | self.q_proj = nn.Linear(embed_dim, embed_dim)
66 | self.v_proj = nn.Linear(embed_dim, embed_dim)
67 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68 | self.num_heads = num_heads
69 |
70 | def forward(self, x):
71 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
72 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74 | x, _ = F.multi_head_attention_forward(
75 | query=x, key=x, value=x,
76 | embed_dim_to_check=x.shape[-1],
77 | num_heads=self.num_heads,
78 | q_proj_weight=self.q_proj.weight,
79 | k_proj_weight=self.k_proj.weight,
80 | v_proj_weight=self.v_proj.weight,
81 | in_proj_weight=None,
82 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83 | bias_k=None,
84 | bias_v=None,
85 | add_zero_attn=False,
86 | dropout_p=0,
87 | out_proj_weight=self.c_proj.weight,
88 | out_proj_bias=self.c_proj.bias,
89 | use_separate_proj_weight=True,
90 | training=self.training,
91 | need_weights=False
92 | )
93 |
94 | return x[0]
95 |
96 |
97 | class LayerNorm(nn.LayerNorm):
98 | """Subclass torch's LayerNorm to handle fp16."""
99 |
100 | def forward(self, x: torch.Tensor):
101 | orig_type = x.dtype
102 | ret = super().forward(x.type(torch.float32))
103 | return ret.type(orig_type)
104 |
105 |
106 | class QuickGELU(nn.Module):
107 | def forward(self, x: torch.Tensor):
108 | return x * torch.sigmoid(1.702 * x)
109 |
110 |
111 | class ResidualAttentionBlock(nn.Module):
112 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
113 | super().__init__()
114 |
115 | self.attn = nn.MultiheadAttention(d_model, n_head)
116 | self.ln_1 = LayerNorm(d_model)
117 | self.mlp = nn.Sequential(OrderedDict([
118 | ("c_fc", nn.Linear(d_model, d_model * 4)),
119 | ("gelu", QuickGELU()),
120 | ("c_proj", nn.Linear(d_model * 4, d_model))
121 | ]))
122 | self.ln_2 = LayerNorm(d_model)
123 | self.attn_mask = attn_mask
124 |
125 | if use_grad_checkpointing:
126 | self.attn = checkpoint_wrapper(self.attn)
127 | self.mlp = checkpoint_wrapper(self.mlp)
128 |
129 | def attention(self, x: torch.Tensor):
130 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
131 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
132 |
133 | def forward(self, x: torch.Tensor):
134 | x = x + self.attention(self.ln_1(x))
135 | x = x + self.mlp(self.ln_2(x))
136 | return x
137 |
138 |
139 | class Transformer(nn.Module):
140 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
141 | super().__init__()
142 | self.width = width
143 | self.layers = layers
144 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
145 |
146 | def forward(self, x: torch.Tensor):
147 | return self.resblocks(x)
148 |
149 |
150 | class VisionTransformer(nn.Module):
151 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
152 | super().__init__()
153 | self.input_resolution = input_resolution
154 | self.num_features = width
155 | self.num_heads = heads
156 | self.num_patches = (input_resolution // patch_size) ** 2
157 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
158 |
159 | scale = width ** -0.5
160 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
161 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
162 | self.ln_pre = LayerNorm(width)
163 |
164 | self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
165 |
166 | # self.ln_final = LayerNorm(width)
167 |
168 | def forward(self, x: torch.Tensor):
169 |
170 | x = self.conv1(x) # shape = [*, width, grid, grid]
171 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
172 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
173 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
174 | x = x + self.positional_embedding.to(x.dtype)
175 | x = self.ln_pre(x)
176 |
177 | x = x.permute(1, 0, 2) # NLD -> LND
178 | x = self.transformer(x)
179 | x = x.permute(1, 0, 2) # LND -> NLD
180 |
181 | # x = self.ln_final(x)
182 | return x
183 |
184 | def get_num_layer(self, var_name=""):
185 | if var_name in ("class_embedding", "positional_embedding", "conv1", "ln_pre"):
186 | return 0
187 | elif var_name.startswith("transformer.resblocks"):
188 | layer_id = int(var_name.split('.')[2])
189 | return layer_id + 1
190 | else:
191 | return len(self.transformer.resblocks)
192 |
193 |
194 | # From PyTorch internals
195 | def _ntuple(n):
196 | def parse(x):
197 | if isinstance(x, collections.abc.Iterable):
198 | return x
199 | return tuple(repeat(x, n))
200 | return parse
201 | to_2tuple = _ntuple(2)
202 |
203 | def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
204 | # Rescale the grid of position embeddings when loading from state_dict
205 | old_pos_embed = state_dict.get('positional_embedding', None)
206 |
207 | grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5)
208 | if old_pos_embed is None:
209 | return
210 | grid_size = to_2tuple(grid_size)
211 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
212 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
213 | if new_seq_len == old_pos_embed.shape[0]:
214 | return
215 |
216 | if extra_tokens:
217 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
218 | else:
219 | pos_emb_tok, pos_emb_img = None, old_pos_embed
220 |
221 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
222 |
223 |     print('Resizing position embedding grid-size from %s to %s' % (old_grid_size, grid_size))
224 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
225 | pos_emb_img = F.interpolate(
226 | pos_emb_img,
227 | size=grid_size,
228 | mode=interpolation,
229 | align_corners=True,
230 | )
231 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
232 | if pos_emb_tok is not None:
233 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
234 | else:
235 | new_pos_embed = pos_emb_img
236 | state_dict['positional_embedding'] = new_pos_embed
237 |
238 |
239 | def create_clip_vit_L(img_size=224,use_checkpoint=False,precision="fp16"):
240 | model = VisionTransformer(
241 | input_resolution=img_size,
242 | patch_size=14,
243 | width=1024,
244 | layers=23,
245 | heads=16,
246 | use_grad_checkpointing=use_checkpoint,
247 | )
248 | url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
249 | cached_file = download_cached_file(
250 | url, check_hash=False, progress=True
251 | )
252 | state_dict = torch.load(cached_file, map_location="cpu")
253 | interpolate_pos_embed(model,state_dict)
254 |
255 | incompatible_keys = model.load_state_dict(state_dict, strict=False)
256 | # print(incompatible_keys)
257 |
258 | if precision == "fp16":
259 | convert_weights_to_fp16(model)
260 | return model
261 |
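A download-free sketch of `interpolate_pos_embed` above: a checkpoint positional embedding from a 16x16 patch grid (224 px with patch size 14) is resized for a model that expects a 24x24 grid (336 px). The tiny `width`/`layers` values only keep the toy model light; `fairscale` (imported at the top of this file) must be installed.

```python
import torch
from models.lavis.models.clip_vit import VisionTransformer, interpolate_pos_embed

model = VisionTransformer(input_resolution=336, patch_size=14, width=64,
                          layers=1, heads=4, use_grad_checkpointing=False)
state_dict = {"positional_embedding": torch.randn(16 * 16 + 1, 64)}  # 224-px checkpoint

interpolate_pos_embed(model, state_dict)  # rewrites the entry in place
print(state_dict["positional_embedding"].shape)  # torch.Size([577, 64]) = 24*24 + 1 tokens
```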
--------------------------------------------------------------------------------
/models/lavis/processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from ..processors.base_processor import BaseProcessor
9 | from ..processors.blip_processors import (
10 | BlipImageTrainProcessor,
11 | Blip2ImageTrainProcessor,
12 | BlipImageEvalProcessor,
13 | BlipCaptionProcessor,
14 | )
15 |
16 | from ..common.registry import registry
17 |
18 | __all__ = [
19 |     "BaseProcessor",
20 |     # BLIP
21 |     "BlipImageTrainProcessor",
22 |     "Blip2ImageTrainProcessor",
23 |     "BlipImageEvalProcessor",
24 |     "BlipCaptionProcessor",
25 | ]
26 |
27 |
28 | def load_processor(name, cfg=None):
29 |     """
30 |     Example
31 |
32 |     >>> processor = load_processor("blip_image_train", cfg=None)
33 |     """
34 |     processor = registry.get_processor_class(name).from_config(cfg)
35 |
36 |     return processor
37 |
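A minimal sketch of `load_processor`; importing `models.lavis.processors` runs the `@registry.register_processor` decorators in `blip_processors.py`, so the names below are registered. Assumes the repo root is on `PYTHONPATH`.

```python
from models.lavis.processors import load_processor

image_proc = load_processor("blip2_image_train", cfg=None)  # 364x364 random-resized crop + flip
text_proc = load_processor("blip_caption", cfg=None)

print(type(image_proc).__name__)        # Blip2ImageTrainProcessor
print(text_proc("A photo of a DOG!!"))  # "a photo of a dog" -- lower-cased, punctuation stripped
```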
--------------------------------------------------------------------------------
/models/lavis/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from omegaconf import OmegaConf
9 |
10 |
11 | class BaseProcessor:
12 | def __init__(self):
13 | self.transform = lambda x: x
14 | return
15 |
16 | def __call__(self, item):
17 | return self.transform(item)
18 |
19 | @classmethod
20 | def from_config(cls, cfg=None):
21 | return cls()
22 |
23 | def build(self, **kwargs):
24 | cfg = OmegaConf.create(kwargs)
25 |
26 | return self.from_config(cfg)
27 |
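A sketch of the intended extension point: a subclass only needs to set `self.transform` (and optionally override `from_config`); `BaseProcessor.__call__` then applies it. The `LowercaseProcessor` class and its `suffix` option are made up for illustration.

```python
from omegaconf import OmegaConf
from models.lavis.processors.base_processor import BaseProcessor

class LowercaseProcessor(BaseProcessor):
    def __init__(self, suffix=""):
        # __call__ from BaseProcessor applies whatever callable is stored here.
        self.transform = lambda text: text.lower() + suffix

    @classmethod
    def from_config(cls, cfg=None):
        cfg = cfg if cfg is not None else OmegaConf.create()
        return cls(suffix=cfg.get("suffix", ""))

proc = LowercaseProcessor.from_config(OmegaConf.create({"suffix": " [sep]"}))
print(proc("Hello World"))  # "hello world [sep]"
```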
--------------------------------------------------------------------------------
/models/lavis/processors/blip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | from torchvision.transforms.functional import InterpolationMode
8 | from ..processors.base_processor import BaseProcessor
9 | from ..processors.randaugment import RandomAugment
10 | from ..common.registry import registry
11 | from torchvision import transforms
12 | from omegaconf import OmegaConf
13 | import re
14 |
15 | class BlipImageBaseProcessor(BaseProcessor):
16 | def __init__(self, mean=None, std=None):
17 | if mean is None:
18 | mean = (0.48145466, 0.4578275, 0.40821073)
19 | if std is None:
20 | std = (0.26862954, 0.26130258, 0.27577711)
21 |
22 | self.normalize = transforms.Normalize(mean, std)
23 |
24 |
25 | @registry.register_processor("blip_caption")
26 | class BlipCaptionProcessor(BaseProcessor):
27 | def __init__(self, prompt="", max_words=50):
28 | self.prompt = prompt
29 | self.max_words = max_words
30 |
31 | def __call__(self, caption):
32 | caption = self.prompt + self.pre_caption(caption)
33 |
34 | return caption
35 |
36 | @classmethod
37 | def from_config(cls, cfg=None):
38 | if cfg is None:
39 | cfg = OmegaConf.create()
40 |
41 | prompt = cfg.get("prompt", "")
42 | max_words = cfg.get("max_words", 50)
43 |
44 | return cls(prompt=prompt, max_words=max_words)
45 |
46 | def pre_caption(self, caption):
47 | caption = re.sub(
48 | r"([.!\"()*#:;~])",
49 | " ",
50 | caption.lower(),
51 | )
52 | caption = re.sub(
53 | r"\s{2,}",
54 | " ",
55 | caption,
56 | )
57 | caption = caption.rstrip("\n")
58 | caption = caption.strip(" ")
59 |
60 | # truncate caption
61 | caption_words = caption.split(" ")
62 | if len(caption_words) > self.max_words:
63 | caption = " ".join(caption_words[: self.max_words])
64 |
65 | return caption
66 |
67 |
68 | @registry.register_processor("blip_question")
69 | class BlipQuestionProcessor(BaseProcessor):
70 | def __init__(self, max_words=50):
71 | self.max_words = max_words
72 |
73 | def __call__(self, question):
74 | return self.pre_question(question)
75 |
76 | @classmethod
77 | def from_config(cls, cfg=None):
78 | if cfg is None:
79 | cfg = OmegaConf.create()
80 |
81 | max_words = cfg.get("max_words", 50)
82 |
83 | return cls(max_words=max_words)
84 |
85 | def pre_question(self, question):
86 | question = re.sub(
87 | r"([.!\"()*#:;~])",
88 | "",
89 | question.lower(),
90 | )
91 | question = question.rstrip(" ")
92 |
93 | # truncate question
94 | question_words = question.split(" ")
95 | if len(question_words) > self.max_words:
96 | question = " ".join(question_words[: self.max_words])
97 |
98 | return question
99 |
100 |
101 | @registry.register_processor("blip_image_train")
102 | class BlipImageTrainProcessor(BlipImageBaseProcessor):
103 | def __init__(
104 | self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
105 | ):
106 | super().__init__(mean=mean, std=std)
107 |
108 | self.transform = transforms.Compose(
109 | [
110 | transforms.RandomResizedCrop(
111 | image_size,
112 | scale=(min_scale, max_scale),
113 | interpolation=InterpolationMode.BICUBIC,
114 | ),
115 | transforms.RandomHorizontalFlip(),
116 | RandomAugment(
117 | 2,
118 | 5,
119 | isPIL=True,
120 | augs=[
121 | "Identity",
122 | "AutoContrast",
123 | "Brightness",
124 | "Sharpness",
125 | "Equalize",
126 | "ShearX",
127 | "ShearY",
128 | "TranslateX",
129 | "TranslateY",
130 | "Rotate",
131 | ],
132 | ),
133 | transforms.ToTensor(),
134 | self.normalize,
135 | ]
136 | )
137 |
138 | def __call__(self, item):
139 | return self.transform(item)
140 |
141 | @classmethod
142 | def from_config(cls, cfg=None):
143 | if cfg is None:
144 | cfg = OmegaConf.create()
145 |
146 | image_size = cfg.get("image_size", 384)
147 |
148 | mean = cfg.get("mean", None)
149 | std = cfg.get("std", None)
150 |
151 | min_scale = cfg.get("min_scale", 0.5)
152 | max_scale = cfg.get("max_scale", 1.0)
153 |
154 | return cls(
155 | image_size=image_size,
156 | mean=mean,
157 | std=std,
158 | min_scale=min_scale,
159 | max_scale=max_scale,
160 | )
161 |
162 |
163 | @registry.register_processor("blip_image_eval")
164 | class BlipImageEvalProcessor(BlipImageBaseProcessor):
165 | def __init__(self, image_size=384, mean=None, std=None):
166 | super().__init__(mean=mean, std=std)
167 |
168 | self.transform = transforms.Compose(
169 | [
170 | transforms.Resize(
171 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
172 | ),
173 | transforms.ToTensor(),
174 | self.normalize,
175 | ]
176 | )
177 |
178 | def __call__(self, item):
179 | return self.transform(item)
180 |
181 | @classmethod
182 | def from_config(cls, cfg=None):
183 | if cfg is None:
184 | cfg = OmegaConf.create()
185 |
186 | image_size = cfg.get("image_size", 384)
187 |
188 | mean = cfg.get("mean", None)
189 | std = cfg.get("std", None)
190 |
191 | return cls(image_size=image_size, mean=mean, std=std)
192 |
193 |
194 | @registry.register_processor("blip2_image_train")
195 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
196 | def __init__(
197 | self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
198 | ):
199 | super().__init__(mean=mean, std=std)
200 |
201 | self.transform = transforms.Compose(
202 | [
203 | transforms.RandomResizedCrop(
204 | image_size,
205 | scale=(min_scale, max_scale),
206 | interpolation=InterpolationMode.BICUBIC,
207 | ),
208 | transforms.RandomHorizontalFlip(),
209 | transforms.ToTensor(),
210 | self.normalize,
211 | ]
212 | )
213 |
214 | def __call__(self, item):
215 | return self.transform(item)
216 |
217 | @classmethod
218 | def from_config(cls, cfg=None):
219 | if cfg is None:
220 | cfg = OmegaConf.create()
221 |
222 | image_size = cfg.get("image_size", 364)
223 |
224 | mean = cfg.get("mean", None)
225 | std = cfg.get("std", None)
226 |
227 | min_scale = cfg.get("min_scale", 0.5)
228 | max_scale = cfg.get("max_scale", 1.0)
229 |
230 | return cls(
231 | image_size=image_size,
232 | mean=mean,
233 | std=std,
234 | min_scale=min_scale,
235 | max_scale=max_scale,
236 | )
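A quick sketch of the eval-time pipeline defined above, run on a random dummy image; it needs Pillow, NumPy, and torchvision from `requirements.txt`, and the import path assumes the repo root is on `PYTHONPATH`.

```python
import numpy as np
from PIL import Image
from models.lavis.processors.blip_processors import BlipImageEvalProcessor

processor = BlipImageEvalProcessor()  # default 384x384 resize + CLIP-style normalization
dummy = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

tensor = processor(dummy)
print(tensor.shape)  # torch.Size([3, 384, 384])
```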
--------------------------------------------------------------------------------
/models/lavis/processors/clip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from torchvision.transforms.functional import InterpolationMode
9 | from ..processors.blip_processors import BlipImageBaseProcessor
10 | from ..common.registry import registry
11 | from torchvision import transforms
12 | from omegaconf import OmegaConf
13 |
14 |
15 | def _convert_to_rgb(image):
16 | return image.convert("RGB")
17 |
18 |
19 | @registry.register_processor("clip_image_train")
20 | class ClipImageTrainProcessor(BlipImageBaseProcessor):
21 | def __init__(
22 | self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0
23 | ):
24 |
25 | super().__init__(mean=mean, std=std)
26 |
27 | self.transform = transforms.Compose(
28 | [
29 | transforms.RandomResizedCrop(
30 | image_size,
31 | scale=(min_scale, max_scale),
32 | interpolation=InterpolationMode.BICUBIC,
33 | ),
34 | _convert_to_rgb,
35 | transforms.ToTensor(),
36 | self.normalize,
37 | ]
38 | )
39 |
40 | @classmethod
41 | def from_config(cls, cfg=None):
42 | if cfg is None:
43 | cfg = OmegaConf.create()
44 |
45 | image_size = cfg.get("image_size", 224)
46 |
47 | mean = cfg.get("mean", None)
48 | std = cfg.get("std", None)
49 |
50 | min_scale = cfg.get("min_scale", 0.9)
51 | max_scale = cfg.get("max_scale", 1.0)
52 |
53 | return cls(
54 | image_size=image_size,
55 | mean=mean,
56 | std=std,
57 | min_scale=min_scale,
58 | max_scale=max_scale,
59 | )
60 |
61 |
62 | @registry.register_processor("clip_image_eval")
63 | class ClipImageEvalProcessor(BlipImageBaseProcessor):
64 | def __init__(self, image_size=224, mean=None, std=None):
65 |
66 | super().__init__(mean=mean, std=std)
67 |
68 | self.transform = transforms.Compose(
69 | [
70 | transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
71 | transforms.CenterCrop(image_size),
72 | _convert_to_rgb,
73 | transforms.ToTensor(),
74 | self.normalize,
75 | ]
76 | )
77 |
78 | @classmethod
79 | def from_config(cls, cfg=None):
80 | if cfg is None:
81 | cfg = OmegaConf.create()
82 |
83 | image_size = cfg.get("image_size", 224)
84 |
85 | mean = cfg.get("mean", None)
86 | std = cfg.get("std", None)
87 |
88 | return cls(
89 | image_size=image_size,
90 | mean=mean,
91 | std=std,
92 | )
93 |
--------------------------------------------------------------------------------
/models/lavis/processors/randaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | import cv2
11 |
12 |
13 | ## aug functions
14 | def identity_func(img):
15 | return img
16 |
17 |
18 | def autocontrast_func(img, cutoff=0):
19 | """
20 | same output as PIL.ImageOps.autocontrast
21 | """
22 | n_bins = 256
23 |
24 | def tune_channel(ch):
25 | n = ch.size
26 | cut = cutoff * n // 100
27 | if cut == 0:
28 | high, low = ch.max(), ch.min()
29 | else:
30 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
31 | low = np.argwhere(np.cumsum(hist) > cut)
32 | low = 0 if low.shape[0] == 0 else low[0]
33 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
34 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
35 | if high <= low:
36 | table = np.arange(n_bins)
37 | else:
38 | scale = (n_bins - 1) / (high - low)
39 | offset = -low * scale
40 | table = np.arange(n_bins) * scale + offset
41 | table[table < 0] = 0
42 | table[table > n_bins - 1] = n_bins - 1
43 | table = table.clip(0, 255).astype(np.uint8)
44 | return table[ch]
45 |
46 | channels = [tune_channel(ch) for ch in cv2.split(img)]
47 | out = cv2.merge(channels)
48 | return out
49 |
50 |
51 | def equalize_func(img):
52 | """
53 | same output as PIL.ImageOps.equalize
54 | PIL's implementation is different from cv2.equalize
55 | """
56 | n_bins = 256
57 |
58 | def tune_channel(ch):
59 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
60 | non_zero_hist = hist[hist != 0].reshape(-1)
61 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
62 | if step == 0:
63 | return ch
64 | n = np.empty_like(hist)
65 | n[0] = step // 2
66 | n[1:] = hist[:-1]
67 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
68 | return table[ch]
69 |
70 | channels = [tune_channel(ch) for ch in cv2.split(img)]
71 | out = cv2.merge(channels)
72 | return out
73 |
74 |
75 | def rotate_func(img, degree, fill=(0, 0, 0)):
76 | """
77 | like PIL, rotate by degree, not radians
78 | """
79 | H, W = img.shape[0], img.shape[1]
80 | center = W / 2, H / 2
81 | M = cv2.getRotationMatrix2D(center, degree, 1)
82 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
83 | return out
84 |
85 |
86 | def solarize_func(img, thresh=128):
87 | """
88 |     same output as PIL.ImageOps.solarize
89 | """
90 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
91 | table = table.clip(0, 255).astype(np.uint8)
92 | out = table[img]
93 | return out
94 |
95 |
96 | def color_func(img, factor):
97 | """
98 | same output as PIL.ImageEnhance.Color
99 | """
100 | ## implementation according to PIL definition, quite slow
101 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
102 | # out = blend(degenerate, img, factor)
103 | # M = (
104 | # np.eye(3) * factor
105 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
106 | # )[np.newaxis, np.newaxis, :]
107 | M = np.float32(
108 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
109 | ) * factor + np.float32([[0.114], [0.587], [0.299]])
110 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
111 | return out
112 |
113 |
114 | def contrast_func(img, factor):
115 | """
116 | same output as PIL.ImageEnhance.Contrast
117 | """
118 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
119 | table = (
120 | np.array([(el - mean) * factor + mean for el in range(256)])
121 | .clip(0, 255)
122 | .astype(np.uint8)
123 | )
124 | out = table[img]
125 | return out
126 |
127 |
128 | def brightness_func(img, factor):
129 | """
130 |     same output as PIL.ImageEnhance.Brightness
131 | """
132 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
133 | out = table[img]
134 | return out
135 |
136 |
137 | def sharpness_func(img, factor):
138 | """
139 |     The differences between this result and PIL are all on the 4 boundaries; the center
140 |     areas are the same
141 | """
142 | kernel = np.ones((3, 3), dtype=np.float32)
143 | kernel[1][1] = 5
144 | kernel /= 13
145 | degenerate = cv2.filter2D(img, -1, kernel)
146 | if factor == 0.0:
147 | out = degenerate
148 | elif factor == 1.0:
149 | out = img
150 | else:
151 | out = img.astype(np.float32)
152 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
153 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
154 | out = out.astype(np.uint8)
155 | return out
156 |
157 |
158 | def shear_x_func(img, factor, fill=(0, 0, 0)):
159 | H, W = img.shape[0], img.shape[1]
160 | M = np.float32([[1, factor, 0], [0, 1, 0]])
161 | out = cv2.warpAffine(
162 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
163 | ).astype(np.uint8)
164 | return out
165 |
166 |
167 | def translate_x_func(img, offset, fill=(0, 0, 0)):
168 | """
169 | same output as PIL.Image.transform
170 | """
171 | H, W = img.shape[0], img.shape[1]
172 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
173 | out = cv2.warpAffine(
174 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
175 | ).astype(np.uint8)
176 | return out
177 |
178 |
179 | def translate_y_func(img, offset, fill=(0, 0, 0)):
180 | """
181 | same output as PIL.Image.transform
182 | """
183 | H, W = img.shape[0], img.shape[1]
184 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
185 | out = cv2.warpAffine(
186 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
187 | ).astype(np.uint8)
188 | return out
189 |
190 |
191 | def posterize_func(img, bits):
192 | """
193 | same output as PIL.ImageOps.posterize
194 | """
195 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
196 | return out
197 |
198 |
199 | def shear_y_func(img, factor, fill=(0, 0, 0)):
200 | H, W = img.shape[0], img.shape[1]
201 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
202 | out = cv2.warpAffine(
203 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
204 | ).astype(np.uint8)
205 | return out
206 |
207 |
208 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
209 | replace = np.array(replace, dtype=np.uint8)
210 | H, W = img.shape[0], img.shape[1]
211 | rh, rw = np.random.random(2)
212 | pad_size = pad_size // 2
213 | ch, cw = int(rh * H), int(rw * W)
214 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
215 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
216 | out = img.copy()
217 | out[x1:x2, y1:y2, :] = replace
218 | return out
219 |
220 |
221 | ### level to args
222 | def enhance_level_to_args(MAX_LEVEL):
223 | def level_to_args(level):
224 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
225 |
226 | return level_to_args
227 |
228 |
229 | def shear_level_to_args(MAX_LEVEL, replace_value):
230 | def level_to_args(level):
231 | level = (level / MAX_LEVEL) * 0.3
232 | if np.random.random() > 0.5:
233 | level = -level
234 | return (level, replace_value)
235 |
236 | return level_to_args
237 |
238 |
239 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
240 | def level_to_args(level):
241 | level = (level / MAX_LEVEL) * float(translate_const)
242 | if np.random.random() > 0.5:
243 | level = -level
244 | return (level, replace_value)
245 |
246 | return level_to_args
247 |
248 |
249 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
250 | def level_to_args(level):
251 | level = int((level / MAX_LEVEL) * cutout_const)
252 | return (level, replace_value)
253 |
254 | return level_to_args
255 |
256 |
257 | def solarize_level_to_args(MAX_LEVEL):
258 | def level_to_args(level):
259 | level = int((level / MAX_LEVEL) * 256)
260 | return (level,)
261 |
262 | return level_to_args
263 |
264 |
265 | def none_level_to_args(level):
266 | return ()
267 |
268 |
269 | def posterize_level_to_args(MAX_LEVEL):
270 | def level_to_args(level):
271 | level = int((level / MAX_LEVEL) * 4)
272 | return (level,)
273 |
274 | return level_to_args
275 |
276 |
277 | def rotate_level_to_args(MAX_LEVEL, replace_value):
278 | def level_to_args(level):
279 | level = (level / MAX_LEVEL) * 30
280 | if np.random.random() < 0.5:
281 | level = -level
282 | return (level, replace_value)
283 |
284 | return level_to_args
285 |
286 |
287 | func_dict = {
288 | "Identity": identity_func,
289 | "AutoContrast": autocontrast_func,
290 | "Equalize": equalize_func,
291 | "Rotate": rotate_func,
292 | "Solarize": solarize_func,
293 | "Color": color_func,
294 | "Contrast": contrast_func,
295 | "Brightness": brightness_func,
296 | "Sharpness": sharpness_func,
297 | "ShearX": shear_x_func,
298 | "TranslateX": translate_x_func,
299 | "TranslateY": translate_y_func,
300 | "Posterize": posterize_func,
301 | "ShearY": shear_y_func,
302 | }
303 |
304 | translate_const = 10
305 | MAX_LEVEL = 10
306 | replace_value = (128, 128, 128)
307 | arg_dict = {
308 | "Identity": none_level_to_args,
309 | "AutoContrast": none_level_to_args,
310 | "Equalize": none_level_to_args,
311 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
312 | "Solarize": solarize_level_to_args(MAX_LEVEL),
313 | "Color": enhance_level_to_args(MAX_LEVEL),
314 | "Contrast": enhance_level_to_args(MAX_LEVEL),
315 | "Brightness": enhance_level_to_args(MAX_LEVEL),
316 | "Sharpness": enhance_level_to_args(MAX_LEVEL),
317 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
318 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
319 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320 | "Posterize": posterize_level_to_args(MAX_LEVEL),
321 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
322 | }
323 |
324 |
325 | class RandomAugment(object):
326 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
327 | self.N = N
328 | self.M = M
329 | self.isPIL = isPIL
330 | if augs:
331 | self.augs = augs
332 | else:
333 | self.augs = list(arg_dict.keys())
334 |
335 | def get_random_ops(self):
336 | sampled_ops = np.random.choice(self.augs, self.N)
337 | return [(op, 0.5, self.M) for op in sampled_ops]
338 |
339 | def __call__(self, img):
340 | if self.isPIL:
341 | img = np.array(img)
342 | ops = self.get_random_ops()
343 | for name, prob, level in ops:
344 | if np.random.random() > prob:
345 | continue
346 | args = arg_dict[name](level)
347 | img = func_dict[name](img, *args)
348 | return img
349 |
350 |
351 | class VideoRandomAugment(object):
352 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
353 | self.N = N
354 | self.M = M
355 | self.p = p
356 | self.tensor_in_tensor_out = tensor_in_tensor_out
357 | if augs:
358 | self.augs = augs
359 | else:
360 | self.augs = list(arg_dict.keys())
361 |
362 | def get_random_ops(self):
363 | sampled_ops = np.random.choice(self.augs, self.N, replace=False)
364 | return [(op, self.M) for op in sampled_ops]
365 |
366 | def __call__(self, frames):
367 | assert (
368 | frames.shape[-1] == 3
369 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
370 |
371 | if self.tensor_in_tensor_out:
372 | frames = frames.numpy().astype(np.uint8)
373 |
374 | num_frames = frames.shape[0]
375 |
376 | ops = num_frames * [self.get_random_ops()]
377 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
378 |
379 | frames = torch.stack(
380 | list(map(self._aug, frames, ops, apply_or_not)), dim=0
381 | ).float()
382 |
383 | return frames
384 |
385 | def _aug(self, img, ops, apply_or_not):
386 | for i, (name, level) in enumerate(ops):
387 | if not apply_or_not[i]:
388 | continue
389 | args = arg_dict[name](level)
390 | img = func_dict[name](img, *args)
391 | return torch.from_numpy(img)
392 |
393 |
394 | if __name__ == "__main__":
395 | a = RandomAugment()
396 |     img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
397 | a(img)
398 |
--------------------------------------------------------------------------------
/models/lavis/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from .runner_base import RunnerBase
9 | from .runner_iter import RunnerIter
10 |
11 | __all__ = ["RunnerBase", "RunnerIter"]
12 |
--------------------------------------------------------------------------------
/models/lavis/runners/runner_iter.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import logging
10 | import os
11 | import time
12 |
13 | import torch
14 | import torch.distributed as dist
15 | import webdataset as wds
16 | from lavis.common.dist_utils import download_cached_file, is_main_process, main_process
17 | from lavis.common.registry import registry
18 | from lavis.common.utils import is_url
19 | from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
20 | from lavis.runners.runner_base import RunnerBase
21 | from torch.utils.data.dataset import ChainDataset
22 |
23 |
24 | @registry.register_runner("runner_iter")
25 | class RunnerIter(RunnerBase):
26 | """
27 | Run training based on the number of iterations. This is common when
28 |     the training dataset size is large. Under the hood, the logic is similar to
29 |     epoch-based training, treating every #iters_per_inner_epoch steps as an
30 | inner epoch.
31 |
32 | In iter-based runner, after every #iters_per_inner_epoch steps, we
33 |
34 | 1) do a validation epoch;
35 | 2) schedule the learning rate;
36 | 3) save the checkpoint.
37 |
38 |     We refer to every #iters_per_inner_epoch steps as an inner epoch.
39 | """
40 |
41 | def __init__(self, cfg, task, model, datasets, job_id):
42 | super().__init__(cfg, task, model, datasets, job_id)
43 |
44 | self.start_iters = 0
45 |
46 | self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
47 | assert self.max_iters > 0, "max_iters must be greater than 0."
48 |
49 | self.iters_per_inner_epoch = int(
50 | self.config.run_cfg.get("iters_per_inner_epoch", -1)
51 | )
52 | assert (
53 | self.iters_per_inner_epoch > 0
54 | ), "iters_per_inner_epoch must be greater than 0."
55 |
56 | @property
57 | def max_epoch(self):
58 | return int(self.max_iters / self.iters_per_inner_epoch)
59 |
60 | @property
61 | def cur_epoch(self):
62 | try:
63 | return self.train_loader.epoch
64 | except AttributeError:
65 |             # pipeline data (e.g. LAION) is streaming and has no concept of an epoch
66 | return 0
67 |
68 | def _progress(self, cur_iters):
69 | return "{}_iters={}".format(self.cur_epoch, cur_iters)
70 |
71 | def train(self):
72 | start_time = time.time()
73 | best_agg_metric = 0
74 | best_iters = 0
75 |
76 | self.log_config()
77 |
78 | # resume from checkpoint if specified
79 | if not self.evaluate_only and self.resume_ckpt_path is not None:
80 | self._load_checkpoint(self.resume_ckpt_path)
81 |
82 | for start_iters in range(
83 | self.start_iters, self.max_iters, self.iters_per_inner_epoch
84 | ):
85 | end_iters = start_iters + self.iters_per_inner_epoch
86 |
87 | # training phase
88 | if not self.evaluate_only:
89 | logging.info(
90 | "Start training, max_iters={}, in total {} inner epochs.".format(
91 | self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
92 | )
93 | )
94 | if start_iters == self.start_iters:
95 | self.task.before_training(
96 | model=self.unwrap_dist_model(self.model),
97 | dataset=self.datasets,
98 | )
99 | train_stats = self.train_iters(self.cur_epoch, start_iters)
100 | self.log_stats(split_name="train", stats=train_stats)
101 |
102 | # evaluation phase
103 | if len(self.valid_splits) > 0 and (self.evaluate_only or (end_iters//self.iters_per_inner_epoch)%self.val_freq == 0):
104 | for split_name in self.valid_splits:
105 | logging.info("Evaluating on {}.".format(split_name))
106 |
107 | val_log = self.eval_epoch(
108 | split_name=split_name, cur_epoch=self._progress(end_iters)
109 | )
110 | if val_log is not None:
111 | if is_main_process():
112 | assert (
113 | "agg_metrics" in val_log
114 | ), "No agg_metrics found in validation log."
115 |
116 | agg_metrics = val_log["agg_metrics"]
117 | if agg_metrics > best_agg_metric and split_name == "val":
118 | best_iters, best_agg_metric = end_iters, agg_metrics
119 |
120 | self._save_checkpoint(end_iters, is_best=True)
121 |
122 | val_log.update({"best_iters": best_iters})
123 | self.log_stats(val_log, split_name)
124 |
125 | else:
126 | # if no validation split is provided, we just save the checkpoint at the end of each inner epoch.
127 | if not self.evaluate_only:
128 | self._save_checkpoint(end_iters, is_best=False)
129 |
130 | if self.evaluate_only:
131 | break
132 |
133 | # save checkpoint according to save freq
134 | # if self.save_freq>0 and (end_iters//self.iters_per_inner_epoch)%self.save_freq == 0:
135 | self._save_checkpoint(end_iters, is_best=False)
136 |
137 | dist.barrier()
138 |
139 | # save last checkpoint
140 | if self.save_last and not self.evaluate_only:
141 | self._save_checkpoint(end_iters, is_best=False)
142 |
143 | # testing phase
144 | self.evaluate(cur_epoch=self.cur_epoch)
145 |
146 | total_time = time.time() - start_time
147 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
148 | logging.info("Training time {}".format(total_time_str))
149 |
150 | def train_iters(self, epoch, start_iters):
151 | # train by iterations
152 | self.model.train()
153 |
154 | return self.task.train_iters(
155 | epoch=epoch,
156 | start_iters=start_iters,
157 | iters_per_inner_epoch=self.iters_per_inner_epoch,
158 | model=self.model,
159 | data_loader=self.train_loader,
160 | optimizer=self.optimizer,
161 | scaler=self.scaler,
162 | lr_scheduler=self.lr_scheduler,
163 | cuda_enabled=self.cuda_enabled,
164 | log_freq=self.log_freq,
165 | accum_grad_iters=self.accum_grad_iters,
166 | )
167 |
168 | @main_process
169 | def _save_checkpoint(self, cur_iters, is_best=False, is_last=False):
170 | model_no_ddp = self.unwrap_dist_model(self.model)
171 | param_grad_dic = {
172 | k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
173 | }
174 |
175 | state_dict = model_no_ddp.state_dict()
176 | for k in list(state_dict.keys()):
177 | if k in param_grad_dic.keys() and not param_grad_dic[k]:
178 | # delete parameters that do not require gradient
179 | del state_dict[k]
180 |
181 | save_obj = {
182 | "model": state_dict,
183 | "optimizer": self.optimizer.state_dict(),
184 | "config": self.config.to_dict(),
185 | "scaler": self.scaler.state_dict() if self.scaler else None,
186 | "iters": cur_iters,
187 | }
188 | save_to = os.path.join(
189 | self.output_dir,
190 | "checkpoint_{}.pth".format("best" if is_best else cur_iters),
191 | )
192 | logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
193 | torch.save(save_obj, save_to)
194 |
195 | def _load_checkpoint(self, url_or_filename):
196 | """
197 | Resume from a checkpoint.
198 | """
199 | if is_url(url_or_filename):
200 | cached_file = download_cached_file(
201 | url_or_filename, check_hash=False, progress=True
202 | )
203 | checkpoint = torch.load(cached_file, map_location=self.device)
204 | elif os.path.isfile(url_or_filename):
205 | checkpoint = torch.load(url_or_filename, map_location=self.device)
206 | else:
207 | raise RuntimeError("checkpoint url or path is invalid")
208 |
209 | state_dict = checkpoint["model"]
210 | self.unwrap_dist_model(self.model).load_state_dict(state_dict)
211 |
212 | self.optimizer.load_state_dict(checkpoint["optimizer"])
213 | if self.scaler and "scaler" in checkpoint:
214 | self.scaler.load_state_dict(checkpoint["scaler"])
215 |
216 | self.start_iters = checkpoint["iters"] + 1
217 | logging.info("Resume checkpoint from {}".format(url_or_filename))
218 |
219 | @property
220 | def dataloaders(self) -> dict:
221 | """
222 |         A property that lazily creates and caches dataloaders for each split.
223 |
224 |         If no train_dataset_ratios is provided, map-style datasets are concatenated
225 |         and wds.DataPipeline datasets are chained separately. The training set becomes
226 |         a tuple (ConcatDataset, ChainDataset); both are optional, but at least one is
227 |         required. The resulting ConcatDataset and ChainDataset are sampled evenly.
228 |
229 |         If train_dataset_ratios is provided, a MultiIterLoader is created to sample
230 |         each dataset according to its ratio during training (see the sketch below).
231 |
232 |         Multiple datasets for the validation and test splits are currently not supported.
233 |
234 | Returns:
235 | dict: {split_name: (tuples of) dataloader}
236 | """
237 | if self._dataloaders is None:
238 |             # reorganize datasets by split and concatenate/chain if necessary
239 | dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)
240 |
241 | if dataset_ratios is None:
242 | # concatenate map-style datasets and chain wds.DataPipe datasets separately
243 | # training set becomes a tuple (ConcatDataset, ChainDataset), both are
244 | # optional but at least one of them is required. The resultant ConcatDataset
245 | # and ChainDataset will be sampled evenly.
246 | logging.info(
247 | "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
248 | )
249 |
250 | datasets = reorg_datasets_by_split(self.datasets)
251 | self.datasets = concat_datasets(datasets)
252 | else:
253 | # create multi-loader with the provided ratios, without concatenating or chaining
254 | missing_keys = [k for k in dataset_ratios if k not in self.datasets]
255 | if len(missing_keys) > 0:
256 | raise ValueError(
257 | "Datasets with the following split names are not found: {}".format(
258 | missing_keys
259 | )
260 | )
261 |
262 | unexpected_keys = [k for k in self.datasets if k not in dataset_ratios]
263 | if len(unexpected_keys) > 0:
264 | raise ValueError(
265 | "Datasets with the following split names are not expected: {}".format(
266 | unexpected_keys
267 | )
268 | )
269 |
270 | dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets]
271 | self.datasets = reorg_datasets_by_split(self.datasets)
272 | # to keep the same structure as return value of concat_datasets
273 | self.datasets = {
274 | k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
275 | }
276 |
277 | # print dataset statistics after concatenation/chaining
278 | for split_name in self.datasets:
279 | if isinstance(self.datasets[split_name], tuple) or isinstance(
280 | self.datasets[split_name], list
281 | ):
282 | # mixed wds.DataPipeline and torch.utils.data.Dataset
283 | num_records = sum(
284 | [
285 | len(d)
286 | if not type(d) in [wds.DataPipeline, ChainDataset]
287 | else 0
288 | for d in self.datasets[split_name]
289 | ]
290 | )
291 |
292 | else:
293 | try:
294 | # a single map-style dataset
295 | num_records = len(self.datasets[split_name])
296 | except TypeError:
297 | # a single wds.DataPipeline or ChainDataset
298 | num_records = -1
299 | logging.info(
300 | "Only a single wds.DataPipeline dataset, no __len__ attribute."
301 | )
302 |
303 | if num_records >= 0:
304 | logging.info(
305 | "Loaded {} records for {} split from the dataset.".format(
306 | num_records, split_name
307 | )
308 | )
309 |
310 | # create dataloaders
311 | split_names = sorted(self.datasets.keys())
312 |
313 | datasets = [self.datasets[split] for split in split_names]
314 | is_trains = [split in self.train_splits for split in split_names]
315 |
316 | batch_sizes = [
317 | self.config.run_cfg.batch_size_train
318 | if split == "train"
319 | else self.config.run_cfg.batch_size_eval
320 | for split in split_names
321 | ]
322 |
323 | collate_fns = []
324 | for dataset in datasets:
325 | if isinstance(dataset, tuple) or isinstance(dataset, list):
326 | collate_fns.append([getattr(d, "collater", None) for d in dataset])
327 | else:
328 | collate_fns.append(getattr(dataset, "collater", None))
329 |
330 | dataloaders = self.create_loaders(
331 | datasets=datasets,
332 | num_workers=self.config.run_cfg.num_workers,
333 | batch_sizes=batch_sizes,
334 | is_trains=is_trains,
335 | collate_fns=collate_fns,
336 | dataset_ratios=dataset_ratios,
337 | )
338 |
339 | self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
340 |
341 | return self._dataloaders
342 |
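The `dataloaders` property above delegates ratio-based mixing to `create_loaders` via `dataset_ratios`. As a rough sketch of the sampling idea only (the class below is a simplified stand-in, not LAVIS's `MultiIterLoader`; names and ratios are illustrative), each batch's source loader can be drawn with probability proportional to its ratio:

```python
import random
from typing import Iterator, Sequence


class RatioSampler:
    """Simplified sketch: yield items from several iterators, picking the
    source iterator for each draw with probability proportional to its ratio."""

    def __init__(self, iterators: Sequence[Iterator], ratios: Sequence[float]):
        assert len(iterators) == len(ratios)
        total = float(sum(ratios))
        self.iterators = list(iterators)
        self.probs = [r / total for r in ratios]

    def __iter__(self):
        return self

    def __next__(self):
        # choose a source according to the ratios, then pull one item from it
        src = random.choices(self.iterators, weights=self.probs, k=1)[0]
        return next(src)


# usage: mix two dummy "loaders" with a 3:1 sampling ratio
mixed = RatioSampler([iter(range(0, 100)), iter(range(100, 200))], ratios=[3.0, 1.0])
first_items = [next(mixed) for _ in range(8)]
```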
--------------------------------------------------------------------------------
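For reference, `_save_checkpoint` above writes a plain dict with `"model"`, `"optimizer"`, `"config"`, `"scaler"`, and `"iters"` entries, and `_load_checkpoint` reads the same keys back. A minimal sketch for inspecting such a file offline (the path is a placeholder; checkpoints are written into the runner's `output_dir`):

```python
import torch

# Placeholder path: _save_checkpoint names files "checkpoint_best.pth" or "checkpoint_<iters>.pth".
ckpt_path = "output/checkpoint_best.pth"

# map_location="cpu" keeps the inspection independent of GPU availability.
checkpoint = torch.load(ckpt_path, map_location="cpu")

print(sorted(checkpoint.keys()))          # ['config', 'iters', 'model', 'optimizer', 'scaler']
print("saved at iters:", checkpoint["iters"])
print("entries in model state_dict:", len(checkpoint["model"]))
```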
/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | from util.misc import NestedTensor
2 | from torch import nn
3 | import torch
4 | import math
5 |
6 | class PositionEmbeddingSine(nn.Module):
7 | """
8 | This is a more standard version of the position embedding, very similar to the one
9 | used by the Attention is all you need paper, generalized to work on images.
10 | """
11 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
12 | super().__init__()
13 | self.num_pos_feats = num_pos_feats
14 | self.temperature = temperature
15 | self.normalize = normalize
16 | if scale is not None and normalize is False:
17 | raise ValueError("normalize should be True if scale is passed")
18 | if scale is None:
19 | scale = 2 * math.pi
20 | self.scale = scale
21 |
22 | def forward(self, tensor_list: NestedTensor):
23 | x = tensor_list.tensors
24 | mask = tensor_list.mask
25 | assert mask is not None
26 | not_mask = ~mask
27 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
28 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
29 | if self.normalize:
30 | eps = 1e-6
31 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
32 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
33 |
34 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
35 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
36 |
37 | pos_x = x_embed[:, :, :, None] / dim_t
38 | pos_y = y_embed[:, :, :, None] / dim_t
39 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
40 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
41 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
42 | return pos
43 |
44 | class PositionEmbeddingLearned(nn.Module):
45 | """
46 | Absolute pos embedding, learned.
47 | """
48 | def __init__(self, num_pos_feats=256):
49 | super().__init__()
50 | self.row_embed = nn.Embedding(50, num_pos_feats)
51 | self.col_embed = nn.Embedding(50, num_pos_feats)
52 | self.reset_parameters()
53 |
54 | def reset_parameters(self):
55 | nn.init.uniform_(self.row_embed.weight)
56 | nn.init.uniform_(self.col_embed.weight)
57 |
58 | def forward(self, tensor_list: NestedTensor):
59 | if torch.is_tensor(tensor_list):
60 | x = tensor_list
61 | else:
62 | x = tensor_list.tensors
63 | h, w = x.shape[-2:]
64 | i = torch.arange(w, device=x.device)
65 | j = torch.arange(h, device=x.device)
66 | x_emb = self.col_embed(i)
67 | y_emb = self.row_embed(j)
68 | pos = torch.cat([
69 | x_emb.unsqueeze(0).repeat(h, 1, 1),
70 | y_emb.unsqueeze(1).repeat(1, w, 1),
71 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
72 | return pos
73 |
74 | def build_position_encoding(args):
75 | N_steps = args.hidden_dim // 2
76 | if args.position_embedding in ('v2', 'sine'):
77 | # TODO find a better way of exposing other arguments
78 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
79 | elif args.position_embedding in ('v3', 'learned'):
80 | position_embedding = PositionEmbeddingLearned(N_steps)
81 | else:
82 | raise ValueError(f"not supported {args.position_embedding}")
83 |
84 | return position_embedding
85 |
86 | class Positionalencoding1d(nn.Module):
87 | def __init__(self, d_model, length):
88 | super().__init__()
89 | self.d_model = d_model
90 | self.length = length
91 |
92 | def __call__(self):
93 | if self.d_model % 2 != 0:
94 | raise ValueError("Cannot use sin/cos positional encoding with "
95 | "odd dim (got dim={:d})".format(self.d_model))
96 | pe = torch.zeros(self.length, self.d_model)
97 | position = torch.arange(0, self.length).unsqueeze(1)
98 | div_term = torch.exp((torch.arange(0, self.d_model, 2, dtype=torch.float) *
99 | -(math.log(10000.0) / self.d_model)))
100 | pe[:, 0::2] = torch.sin(position.float() * div_term)
101 | pe[:, 1::2] = torch.cos(position.float() * div_term)
102 | return pe
103 |
104 | class LearnablePositionalEncoding1D(nn.Module):
105 | def __init__(self, input_size, max_len=1000):
106 | super(LearnablePositionalEncoding1D, self).__init__()
107 |
108 | # Embedding layer for positional encoding
109 | self.positional_encoding = nn.Embedding(max_len, input_size)
110 |
111 | def forward(self, x):
112 | """
113 | Args:
114 | x: Input tensor of shape (batch_size, seq_len, input_size)
115 | Returns:
116 |             Positional embeddings of shape (batch_size, seq_len, input_size), returned on their own (not added to x)
117 | """
118 | batch_size, seq_len, _ = x.size()
119 |
120 | # Generate positional indices
121 | positions = torch.arange(0, seq_len).unsqueeze(0).repeat(batch_size, 1).to(x.device)
122 |
123 | # Get positional embeddings
124 | positional_embeddings = self.positional_encoding(positions)
125 |
126 | return positional_embeddings
--------------------------------------------------------------------------------
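A minimal shape check for `PositionEmbeddingSine`: its `forward` only reads the `.tensors` and `.mask` attributes of its argument, so a small stand-in object can replace `util.misc.NestedTensor` here (sizes are arbitrary, and the import assumes the repository root is on `PYTHONPATH`):

```python
import torch
from types import SimpleNamespace

from models.position_encoding import PositionEmbeddingSine

B, C, H, W = 2, 3, 16, 16
features = torch.randn(B, C, H, W)
padding_mask = torch.zeros(B, H, W, dtype=torch.bool)          # False = valid pixel
sample = SimpleNamespace(tensors=features, mask=padding_mask)  # stands in for NestedTensor

# build_position_encoding uses N_steps = hidden_dim // 2, e.g. 128 for hidden_dim=256.
pos_embed = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
pos = pos_embed(sample)
print(pos.shape)  # torch.Size([2, 256, 16, 16]) -> (B, 2 * num_pos_feats, H, W)
```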
/models/transformer.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn, Tensor
3 | from typing import Optional
4 | import copy
5 | import torch
6 |
7 | class Transformer(nn.Module):
8 |
9 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=0,
10 | num_decoder_layers=3, dim_feedforward=2048, dropout=0.1,
11 | activation="relu", normalize_before=False,
12 | return_intermediate_dec=False, args=None):
13 | super().__init__()
14 |
15 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
16 | dropout, activation, normalize_before)
17 |
18 | decoder_norm = nn.LayerNorm(d_model)
19 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec)
20 | self._reset_parameters()
21 | self.d_model = d_model
22 | self.nhead = nhead
23 |
24 | def _reset_parameters(self):
25 | for p in self.parameters():
26 | if p.dim() > 1:
27 | nn.init.xavier_uniform_(p)
28 |
29 | def forward(self, src, query_embed, pos_embed, init_query=None, mask=None, tgt_mask=None):
30 | bs, c, h, w = src.shape
31 | src = src.flatten(2).permute(2, 0, 1)
32 | query_embed = query_embed.permute(1,0,2)
33 | tgt = torch.zeros_like(query_embed)
34 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
35 | hs = self.decoder(tgt, src, tgt_mask=tgt_mask, memory_key_padding_mask=mask,
36 | pos=pos_embed, query_pos=query_embed)
37 | return hs.transpose(1, 2), src.permute(1, 2, 0).view(bs, c, h, w)
38 |
39 |
40 |
41 | class TransformerEncoder(nn.Module):
42 |
43 | def __init__(self, encoder_layer, num_layers, norm=None):
44 | super().__init__()
45 | self.layers = _get_clones(encoder_layer, num_layers)
46 | self.num_layers = num_layers
47 | self.norm = norm
48 |
49 | def forward(self, src,
50 | mask: Optional[Tensor] = None,
51 | src_key_padding_mask: Optional[Tensor] = None,
52 | pos: Optional[Tensor] = None):
53 | output = src
54 |
55 | for layer in self.layers:
56 | output = layer(output, src_mask=mask,
57 | src_key_padding_mask=src_key_padding_mask, pos=pos)
58 |
59 | if self.norm is not None:
60 | output = self.norm(output)
61 |
62 | return output
63 |
64 |
65 | class TransformerDecoder(nn.Module):
66 |
67 | def __init__(self, decoder_layer, num_layers, norm=None, decoder_last_norm=None, return_intermediate=False):
68 | super().__init__()
69 | self.layers = _get_clones(decoder_layer, num_layers)
70 | self.num_layers = num_layers
71 | self.norm = norm
72 | self.return_intermediate = return_intermediate
73 |
74 | def forward(self, tgt, memory,
75 | tgt_mask: Optional[Tensor] = None,
76 | memory_mask: Optional[Tensor] = None,
77 | tgt_key_padding_mask: Optional[Tensor] = None,
78 | memory_key_padding_mask: Optional[Tensor] = None,
79 | pos: Optional[Tensor] = None,
80 | query_pos: Optional[Tensor] = None):
81 | output = tgt
82 | intermediate = []
83 | for layer in self.layers:
84 | output = layer(output, memory, tgt_mask=tgt_mask,
85 | memory_mask=memory_mask,
86 | tgt_key_padding_mask=tgt_key_padding_mask,
87 | memory_key_padding_mask=memory_key_padding_mask,
88 | pos=pos, query_pos=query_pos)
89 |
90 | if self.return_intermediate:
91 | intermediate.append(self.norm(output))
92 |
93 | if self.norm is not None:
94 | output = self.norm(output)
95 | if self.return_intermediate:
96 | intermediate.pop()
97 | intermediate.append(output)
98 |
99 | if self.return_intermediate:
100 | return torch.stack(intermediate)
101 |
102 | return output.unsqueeze(0)
103 |
104 |
105 | class TransformerEncoderLayer(nn.Module):
106 |
107 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
108 | activation="relu", normalize_before=False):
109 | super().__init__()
110 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
111 |
112 | self.linear1 = nn.Linear(d_model, dim_feedforward)
113 | self.dropout = nn.Dropout(dropout)
114 | self.linear2 = nn.Linear(dim_feedforward, d_model)
115 |
116 | self.norm1 = nn.LayerNorm(d_model)
117 | self.norm2 = nn.LayerNorm(d_model)
118 | self.dropout1 = nn.Dropout(dropout)
119 | self.dropout2 = nn.Dropout(dropout)
120 |
121 | self.activation = _get_activation_fn(activation)
122 | self.normalize_before = normalize_before
123 |
124 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
125 | return tensor if pos is None else tensor + pos
126 |
127 | def forward_post(self,
128 | src,
129 | src_mask: Optional[Tensor] = None,
130 | src_key_padding_mask: Optional[Tensor] = None,
131 | pos: Optional[Tensor] = None):
132 | q = k = self.with_pos_embed(src, pos)
133 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
134 | key_padding_mask=src_key_padding_mask)[0]
135 | src = src + self.dropout1(src2)
136 | src = self.norm1(src)
137 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
138 | src = src + self.dropout2(src2)
139 | src = self.norm2(src)
140 | return src
141 |
142 | def forward_pre(self, src,
143 | src_mask: Optional[Tensor] = None,
144 | src_key_padding_mask: Optional[Tensor] = None,
145 | pos: Optional[Tensor] = None):
146 | src2 = self.norm1(src)
147 | q = k = self.with_pos_embed(src2, pos)
148 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
149 | key_padding_mask=src_key_padding_mask)[0]
150 | src = src + self.dropout1(src2)
151 | src2 = self.norm2(src)
152 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
153 | src = src + self.dropout2(src2)
154 | return src
155 |
156 | def forward(self, src,
157 | src_mask: Optional[Tensor] = None,
158 | src_key_padding_mask: Optional[Tensor] = None,
159 | pos: Optional[Tensor] = None):
160 |
161 | if self.normalize_before:
162 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
163 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
164 |
165 |
166 |
167 |
168 | class TransformerDecoderLayer(nn.Module):
169 |
170 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
171 | activation="relu", normalize_before=False):
172 | super().__init__()
173 |
174 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
175 | # Implementation of Feedforward model
176 | self.linear1 = nn.Linear(d_model, dim_feedforward)
177 | self.dropout = nn.Dropout(dropout)
178 | self.linear2 = nn.Linear(dim_feedforward, d_model)
179 |
180 | self.norm1 = nn.LayerNorm(d_model)
181 | self.norm2 = nn.LayerNorm(d_model)
182 | self.norm3 = nn.LayerNorm(d_model)
183 |
184 | self.dropout1 = nn.Dropout(dropout)
185 | self.dropout2 = nn.Dropout(dropout)
186 | self.dropout3 = nn.Dropout(dropout)
187 |
188 | self.activation = _get_activation_fn(activation)
189 | self.normalize_before = normalize_before
190 |
191 |
192 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
193 | return tensor if pos is None else tensor + pos
194 |
195 | def forward_post(self, tgt, memory,
196 | tgt_mask: Optional[Tensor] = None,
197 | memory_mask: Optional[Tensor] = None,
198 | tgt_key_padding_mask: Optional[Tensor] = None,
199 | memory_key_padding_mask: Optional[Tensor] = None,
200 | pos: Optional[Tensor] = None,
201 | query_pos: Optional[Tensor] = None):
202 |
203 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
204 | key=self.with_pos_embed(memory, pos),
205 | value=memory, attn_mask=memory_mask,
206 | key_padding_mask=memory_key_padding_mask)[0]
207 | tgt = tgt + self.dropout2(tgt2)
208 | tgt = self.norm2(tgt)
209 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
210 | tgt = tgt + self.dropout3(tgt2)
211 | tgt = self.norm3(tgt)
212 | return tgt
213 |
214 | def forward_pre(self, tgt, memory,
215 | tgt_mask: Optional[Tensor] = None,
216 | memory_mask: Optional[Tensor] = None,
217 | tgt_key_padding_mask: Optional[Tensor] = None,
218 | memory_key_padding_mask: Optional[Tensor] = None,
219 | pos: Optional[Tensor] = None,
220 | query_pos: Optional[Tensor] = None):
221 |         # Pre-norm variant of forward_post. This decoder layer is cross-attention-only
222 |         # (no self-attention module is defined in __init__), so the query tokens attend
223 |         # to the memory features and then pass through the feed-forward network.
224 |         # norm1 and dropout1 from __init__ are not used by this layer.
225 |         tgt2 = self.norm2(tgt)
226 |         tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
227 |                                    key=self.with_pos_embed(memory, pos),
228 |                                    value=memory, attn_mask=memory_mask,
229 |                                    key_padding_mask=memory_key_padding_mask)[0]
230 |         tgt = tgt + self.dropout2(tgt2)
231 |         tgt2 = self.norm3(tgt)
232 |         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
233 |         tgt = tgt + self.dropout3(tgt2)
234 |         return tgt
235 |
236 |
237 | def forward(self, tgt, memory,
238 | tgt_mask: Optional[Tensor] = None,
239 | memory_mask: Optional[Tensor] = None,
240 | tgt_key_padding_mask: Optional[Tensor] = None,
241 | memory_key_padding_mask: Optional[Tensor] = None,
242 | pos: Optional[Tensor] = None,
243 | query_pos: Optional[Tensor] = None):
244 | if self.normalize_before:
245 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
246 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
247 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
248 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
249 |
250 |
251 | def _get_clones(module, N):
252 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
253 |
254 |
255 | def build_transformer(args):
256 | return Transformer(
257 | d_model=args.hidden_dim,
258 | dropout=args.dropout,
259 | nhead=args.nheads,
260 | dim_feedforward=args.dim_feedforward,
261 | num_encoder_layers=args.enc_layers,
262 | num_decoder_layers=args.dec_layers,
263 | normalize_before=args.pre_norm,
264 | return_intermediate_dec=True,
265 | args=args
266 | )
267 |
268 |
269 | def _get_activation_fn(activation):
270 | """Return an activation function given a string"""
271 | if activation == "relu":
272 | return F.relu
273 | if activation == "gelu":
274 | return F.gelu
275 | if activation == "glu":
276 | return F.glu
277 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
278 |
--------------------------------------------------------------------------------
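The `Transformer` above is decoder-only: the queries cross-attend to the flattened feature map, and with `return_intermediate_dec=True` (as in `build_transformer`) the output stacks every decoder layer's result. A small shape sketch with arbitrary dimensions (the import assumes the repository root is on `PYTHONPATH`):

```python
import torch

from models.transformer import Transformer

B, H, W = 2, 16, 16
d_model, num_queries = 256, 20

model = Transformer(d_model=d_model, nhead=8, num_decoder_layers=3,
                    return_intermediate_dec=True)

src = torch.randn(B, d_model, H, W)                  # image feature map (random here)
query_embed = torch.randn(B, num_queries, d_model)   # per-image query embeddings
pos_embed = torch.randn(B, d_model, H, W)            # e.g. a PositionEmbeddingSine output

hs, memory = model(src, query_embed, pos_embed)
print(hs.shape)      # (num_decoder_layers, B, num_queries, d_model) = (3, 2, 20, 256)
print(memory.shape)  # (B, d_model, H, W) = (2, 256, 16, 16)
```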
/requirements.txt:
--------------------------------------------------------------------------------
1 | salesforce-lavis
2 | scikit-learn
3 | numpy==1.24.4
4 |
--------------------------------------------------------------------------------
/tools/launch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # --------------------------------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py
7 | # --------------------------------------------------------------------------------------------------------------------------
8 |
9 | r"""
10 | `torch.distributed.launch` is a module that spawns up multiple distributed
11 | training processes on each of the training nodes.
12 | The utility can be used for single-node distributed training, in which one or
13 | more processes per node will be spawned. The utility can be used for either
14 | CPU training or GPU training. If the utility is used for GPU training,
15 | each distributed process will be operating on a single GPU. This can achieve
16 | well-improved single-node training performance. It can also be used in
17 | multi-node distributed training, by spawning up multiple processes on each node
18 | for well-improved multi-node distributed training performance as well.
19 | This will especially be benefitial for systems with multiple Infiniband
20 | interfaces that have direct-GPU support, since all of them can be utilized for
21 | aggregated communication bandwidth.
22 | In both cases of single-node distributed training or multi-node distributed
23 | training, this utility will launch the given number of processes per node
24 | (``--nproc_per_node``). If used for GPU training, this number needs to be less
25 | than or equal to the number of GPUs on the current system (``nproc_per_node``),
26 | and each process will be operating on a single GPU from *GPU 0 to
27 | GPU (nproc_per_node - 1)*.
28 | **How to use this module:**
29 | 1. Single-Node multi-process distributed training
30 | ::
31 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
32 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
33 | arguments of your training script)
34 | 2. Multi-Node multi-process distributed training: (e.g. two nodes)
35 | Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*
36 | ::
37 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
38 | --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
39 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
40 | and all other arguments of your training script)
41 | Node 2:
42 | ::
43 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
44 | --nnodes=2 --node_rank=1 --master_addr="192.168.1.1"
45 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
46 | and all other arguments of your training script)
47 | 3. To look up what optional arguments this module offers:
48 | ::
49 | >>> python -m torch.distributed.launch --help
50 | **Important Notices:**
51 | 1. This utility and multi-process distributed (single-node or
52 | multi-node) GPU training currently only achieves the best performance using
53 | the NCCL distributed backend. Thus NCCL backend is the recommended backend to
54 | use for GPU training.
55 | 2. In your training program, you must parse the command-line argument:
56 | ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
57 | If your training program uses GPUs, you should ensure that your code only
58 | runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:
59 | Parsing the local_rank argument
60 | ::
61 | >>> import argparse
62 | >>> parser = argparse.ArgumentParser()
63 | >>> parser.add_argument("--local_rank", type=int)
64 | >>> args = parser.parse_args()
65 | Set your device to local rank using either
66 | ::
67 |     >>> torch.cuda.set_device(args.local_rank)  # before your code runs
68 | or
69 | ::
70 |     >>> with torch.cuda.device(args.local_rank):
71 | >>> # your code to run
72 | 3. In your training program, you are supposed to call the following function
73 | at the beginning to start the distributed backend. You need to make sure that
74 | the init_method uses ``env://``, which is the only supported ``init_method``
75 | by this module.
76 | ::
77 | torch.distributed.init_process_group(backend='YOUR BACKEND',
78 | init_method='env://')
79 | 4. In your training program, you can either use regular distributed functions
80 | or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
81 | training program uses GPUs for training and you would like to use
82 | :func:`torch.nn.parallel.DistributedDataParallel` module,
83 | here is how to configure it.
84 | ::
85 | model = torch.nn.parallel.DistributedDataParallel(model,
86 |                                                       device_ids=[args.local_rank],
87 |                                                       output_device=args.local_rank)
88 | Please ensure that ``device_ids`` argument is set to be the only GPU device id
89 | that your code will be operating on. This is generally the local rank of the
90 | process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``,
91 | and ``output_device`` needs to be ``args.local_rank`` in order to use this
92 | utility.
93 | 5. Another way to pass ``local_rank`` to the subprocesses is via the environment variable
94 | ``LOCAL_RANK``. This behavior is enabled when you launch the script with
95 | ``--use_env=True``. You must adjust the subprocess example above to replace
96 | ``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
97 | will not pass ``--local_rank`` when you specify this flag.
98 | .. warning::
99 | ``local_rank`` is NOT globally unique: it is only unique per process
100 | on a machine. Thus, don't use it to decide if you should, e.g.,
101 | write to a networked filesystem. See
102 | https://github.com/pytorch/pytorch/issues/12042 for an example of
103 | how things can go wrong if you don't do this correctly.
104 | """
105 |
106 |
107 | import sys
108 | import subprocess
109 | import os
110 | import socket
111 | from argparse import ArgumentParser, REMAINDER
112 |
113 | import torch
114 |
115 |
116 | def parse_args():
117 | """
118 | Helper function parsing the command line options
119 | @retval ArgumentParser
120 | """
121 | parser = ArgumentParser(description="PyTorch distributed training launch "
122 |                                         "helper utility that will spawn up "
123 | "multiple distributed processes")
124 |
125 | # Optional arguments for the launch helper
126 | parser.add_argument("--nnodes", type=int, default=1,
127 | help="The number of nodes to use for distributed "
128 | "training")
129 | parser.add_argument("--node_rank", type=int, default=0,
130 | help="The rank of the node for multi-node distributed "
131 | "training")
132 | parser.add_argument("--nproc_per_node", type=int, default=1,
133 | help="The number of processes to launch on each node, "
134 | "for GPU training, this is recommended to be set "
135 | "to the number of GPUs in your system so that "
136 | "each process can be bound to a single GPU.")
137 | parser.add_argument("--master_addr", default="127.0.0.1", type=str,
138 | help="Master node (rank 0)'s address, should be either "
139 | "the IP address or the hostname of node 0, for "
140 | "single node multi-proc training, the "
141 | "--master_addr can simply be 127.0.0.1")
142 | parser.add_argument("--master_port", default=29500, type=int,
143 | help="Master node (rank 0)'s free port that needs to "
144 |                              "be used for communication during distributed "
145 | "training")
146 |
147 | # positional
148 | parser.add_argument("training_script", type=str,
149 | help="The full path to the single GPU training "
150 | "program/script to be launched in parallel, "
151 | "followed by all the arguments for the "
152 | "training script")
153 |
154 | # rest from the training program
155 | parser.add_argument('training_script_args', nargs=REMAINDER)
156 | return parser.parse_args()
157 |
158 |
159 | def main():
160 | args = parse_args()
161 |
162 | # world size in terms of number of processes
163 | dist_world_size = args.nproc_per_node * args.nnodes
164 |
165 | # set PyTorch distributed related environmental variables
166 | current_env = os.environ.copy()
167 | current_env["MASTER_ADDR"] = args.master_addr
168 | current_env["MASTER_PORT"] = str(args.master_port)
169 | current_env["WORLD_SIZE"] = str(dist_world_size)
170 |
171 | processes = []
172 |
173 | for local_rank in range(0, args.nproc_per_node):
174 | # each process's rank
175 | dist_rank = args.nproc_per_node * args.node_rank + local_rank
176 | current_env["RANK"] = str(dist_rank)
177 | current_env["LOCAL_RANK"] = str(local_rank)
178 |
179 | cmd = [args.training_script] + args.training_script_args
180 |
181 |
182 | process = subprocess.Popen(cmd, env=current_env)
183 |
184 | processes.append(process)
185 | for process in processes:
186 | process.wait()
187 | if process.returncode != 0:
188 | raise subprocess.CalledProcessError(returncode=process.returncode,
189 | cmd=process.args)
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
--------------------------------------------------------------------------------
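`main()` above exports `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` to each spawned process and forwards the training command unchanged (no `--local_rank` flag is appended). A minimal, hypothetical entry point compatible with this launcher (a sketch only, independent of the repository's actual `main.py`):

```python
import os

import torch
import torch.distributed as dist


def init_distributed():
    # launch.py exports RANK, WORLD_SIZE and LOCAL_RANK for every spawned
    # process, so env:// initialization needs no extra command-line flags.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    return local_rank


if __name__ == "__main__":
    local_rank = init_distributed()
    model = torch.nn.Linear(10, 10).cuda(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank
    )
    # ... training loop ...
    dist.destroy_process_group()
```

With `run_dist_launch.sh`, the full command given after the GPU count (e.g. `python your_script.py ...`) is forwarded to this launcher as the training script and its arguments.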
/tools/run_dist_launch.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | GPUS=$1
3 | RUN_COMMAND=${@:2}
4 | if [ $GPUS -lt 8 ]; then
5 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
6 | else
7 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
8 | fi
9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
10 |
11 | # Pick a pseudo-random master port: 20000 + RANDOM % 40000 (RANDOM <= 32767, so the range is 20000-52767)
12 | MASTER_PORT=$((20000 + RANDOM % 40000))
13 | NODE_RANK=${NODE_RANK:-0}
14 |
15 | let "NNODES=GPUS/GPUS_PER_NODE"
16 |
17 | python ./tools/launch.py \
18 | --nnodes ${NNODES} \
19 | --node_rank ${NODE_RANK} \
20 | --master_addr ${MASTER_ADDR} \
21 | --master_port ${MASTER_PORT} \
22 | --nproc_per_node ${GPUS_PER_NODE} \
23 | ${RUN_COMMAND}
24 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/SugaFormer/4c9219ed2b05a159751fc0390e599107f6f7f07e/util/__init__.py
--------------------------------------------------------------------------------