├── LICENSE
├── README.md
├── dataset
│   ├── data.sh
│   ├── data_loader.py
│   ├── data_process.py
│   ├── datascript.py
│   └── refer.py
├── engine
│   └── engine.py
├── model
│   ├── model.py
│   ├── modules.py
│   ├── position_encoding.py
│   └── transformer.py
├── requirements.txt
├── test.py
├── test.sh
├── train.py
├── train.sh
└── utils
    ├── __init__.py
    ├── checkpoint.py
    ├── logger.py
    ├── losses.py
    ├── misc_utils.py
    ├── parsing_metrics.py
    ├── transforms.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Pengfei Yue
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adaptive Selection based Referring Image Segmentation
2 | This is an official PyTorch implementation of ASDA (accepted by ACMMM 2024).
3 | ## News
4 | - [July 16, 2024] The paper was accepted by ACMMM 2024 🎉.
5 | - [Oct 22, 2024] The PyTorch implementation of ASDA is released.
6 | 
7 | ## Main Results
8 | 
9 | Main results on RefCOCO
10 | 
11 | Model | Backbone | val | test A | test B |
12 | --- | ---- |:-------------:| :-----:|:-----:|
13 | CRIS| ResNet101 | 70.47 | 73.18 | 66.10 |
14 | ASDA| ViT-B | 75.06 | 77.14 | 71.36 |
15 | 
16 | Main results on RefCOCO+
17 | 
18 | Model | Backbone | val | test A | test B |
19 | --- | ---- |:-------------:| :-----:|:-----:|
20 | CRIS| ResNet101 | 62.27 | 68.08 | 53.68 |
21 | ASDA| ViT-B | 66.84 | 71.13 | 57.83 |
22 | 
23 | Main results on G-Ref
24 | 
25 | Model | Backbone | val(U) | test(U) | val(G) |
26 | --- | ---- |:-------------:| :-----:| :-----:|
27 | CRIS| ResNet101 | 59.87 | 60.36 | - |
28 | ASDA| ViT-B | 65.73 | 66.45 | 63.55 |
29 | 
30 | ## Quick Start
31 | ### Environment preparation
32 | ```bash
33 | conda create -n ASDA python=3.6 -y
34 | conda activate ASDA
35 | # install pytorch according to your cuda version
36 | # do not change the torch version, otherwise dependency conflicts may occur
37 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
38 | 
39 | pip install -r requirements.txt
40 | ```
41 | 
42 | ### Dataset Preparation
43 | 
44 | #### 1. Download the COCO train2014 images to ASDA/ln_data/images.
45 | ```bash
46 | wget https://pjreddie.com/media/files/train2014.zip
47 | ```
48 | 
49 | #### 2. Download the RefCOCO, RefCOCO+ and RefCOCOg annotations to ASDA/ln_data.
50 | ```bash
51 | mkdir ln_data && cd ln_data
52 | # The original link bvisionweb1.cs.unc.edu/licheng/referit/data/refclef.zip is no longer valid, so we have uploaded the data to Google Drive (https://drive.google.com/file/d/1AnNBSL1gc9uG1zcdPIMg4d9e0y4dDSho/view?usp=sharing)
53 | wget 'https://drive.usercontent.google.com/download?id=1AnNBSL1gc9uG1zcdPIMg4d9e0y4dDSho&export=download&authuser=0&confirm=t&uuid=be656478-9669-4b58-ab23-39f196f88c07&at=AN_67v3n4xwkPBdEQ9pMlwonmhrH%3A1729591897703' -O refcoco_all.zip
54 | unzip refcoco_all.zip
55 | ```
56 | 
57 | #### 3. Run data.sh to generate the annotations.
58 | ```bash
59 | cd dataset
60 | bash data.sh
61 | ```
62 | 
63 | ### Training & Testing
64 | ```bash
65 | bash train.sh 0,1
66 | bash test.sh 0
67 | ```
68 | ## License
69 | 
70 | This project is under the MIT license. See [LICENSE](LICENSE) for details.
71 | 
72 | ## Acknowledgement
73 | This project builds on code from [CRIS](https://github.com/DerrickWang005/CRIS.pytorch), [VLT](https://github.com/henghuiding/Vision-Language-Transformer) and [ViTDet](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet). Thanks for their great work.
74 | 
75 | ## Citation
76 | If you find our work useful in your research, please consider citing:
77 | ```
78 | @inproceedings{yue2024adaptive,
79 |   title={Adaptive Selection based Referring Image Segmentation},
80 |   author={Yue, Pengfei and Lin, Jianghang and Zhang, Shengchuan and Hu, Jie and Lu, Yilin and Niu, Hongwei and Ding, Haixin and Zhang, Yan and JIANG, GUANNAN and Cao, Liujuan and others},
81 |   booktitle={ACM Multimedia 2024}
82 | }
83 | ```
--------------------------------------------------------------------------------
/dataset/data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # data process
3 | python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcoco --split unc --generate_mask
4 | python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcoco+ --split unc --generate_mask
5 | python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcocog --split google --generate_mask
6 | python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcocog --split umd --generate_mask
7 | 
8 | # datascript
9 | python datascript.py --dataset refcoco
10 | python datascript.py --dataset refcoco+
11 | python datascript.py --dataset refcocog_google
12 | python datascript.py --dataset refcocog_umd
13 | 
--------------------------------------------------------------------------------
/dataset/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """
4 | refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
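
Added note: each entry loaded from the data/<dataset>/<dataset>_<split>.pth files produced by
dataset/datascript.py is expected to be a tuple (img_file, seg_id, bbox, phrase); pull_item()
below unpacks exactly this format.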
5 | """ 6 | import sys 7 | import cv2 8 | import torch 9 | import random 10 | import numpy as np 11 | import os.path as osp 12 | import torch.utils.data as data 13 | sys.path.append('.') 14 | import utils 15 | import re 16 | 17 | from pytorch_pretrained_bert.tokenization import BertTokenizer 18 | from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase 19 | import copy 20 | 21 | import clip 22 | 23 | sys.modules['utils'] = utils 24 | cv2.setNumThreads(0) 25 | 26 | def read_examples(input_line, unique_id): 27 | """Read a list of `InputExample`s from an input file.""" 28 | examples = [] 29 | # unique_id = 0 30 | line = input_line #reader.readline() 31 | # if not line: 32 | # break 33 | line = line.strip() 34 | text_a = None 35 | text_b = None 36 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 37 | if m is None: 38 | text_a = line 39 | else: 40 | text_a = m.group(1) #'man in black' 41 | text_b = m.group(2) 42 | 43 | examples.append( 44 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 45 | # unique_id += 1 46 | return examples 47 | 48 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 49 | while True: 50 | total_length = len(tokens_a) + len(tokens_b) 51 | if total_length <= max_length: 52 | break 53 | if len(tokens_a) > len(tokens_b): 54 | tokens_a.pop() 55 | else: 56 | tokens_b.pop() 57 | 58 | ## Bert text encoding 59 | class InputExample(object): 60 | def __init__(self, unique_id, text_a, text_b): 61 | self.unique_id = unique_id 62 | self.text_a = text_a 63 | self.text_b = text_b 64 | 65 | class InputFeatures(object): 66 | """A single set of features of data.""" 67 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 68 | self.unique_id = unique_id 69 | self.tokens = tokens 70 | self.input_ids = input_ids 71 | self.input_mask = input_mask 72 | self.input_type_ids = input_type_ids 73 | 74 | def convert_examples_to_features(examples, seq_length, tokenizer): 75 | """Loads a data file into a list of `InputBatch`s.""" 76 | features = [] 77 | for (ex_index, example) in enumerate(examples): 78 | tokens_a = tokenizer.tokenize(example.text_a) # ['far', 'left', 'vase'] 79 | 80 | tokens_b = None 81 | if example.text_b: 82 | tokens_b = tokenizer.tokenize(example.text_b) 83 | 84 | if tokens_b: 85 | # Modifies `tokens_a` and `tokens_b` in place so that the total 86 | # length is less than the specified length. 87 | # Account for [CLS], [SEP], [SEP] with "- 3" 88 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 89 | else: 90 | # Account for [CLS] and [SEP] with "- 2" 91 | if len(tokens_a) > seq_length - 2: 92 | tokens_a = tokens_a[0:(seq_length - 2)] 93 | tokens = [] 94 | input_type_ids = [] 95 | tokens.append("[CLS]") 96 | input_type_ids.append(0) 97 | for token in tokens_a: 98 | tokens.append(token) 99 | input_type_ids.append(0) 100 | tokens.append("[SEP]") 101 | input_type_ids.append(0) 102 | 103 | if tokens_b: 104 | for token in tokens_b: 105 | tokens.append(token) 106 | input_type_ids.append(1) 107 | tokens.append("[SEP]") 108 | input_type_ids.append(1) 109 | 110 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 111 | 112 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 113 | # tokens are attended to. 114 | input_mask = [1] * len(input_ids) 115 | 116 | # Zero-pad up to the sequence length. 
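        # Added illustration (token ids below are schematic, not real BERT vocabulary ids):
        # with seq_length = 20 and text_a = "man in black", after the padding loop that follows
        #   tokens         -> ['[CLS]', 'man', 'in', 'black', '[SEP]']
        #   input_ids      -> [id_cls, id_man, id_in, id_black, id_sep, 0, 0, ..., 0]   # length 20
        #   input_mask     -> [1, 1, 1, 1, 1, 0, 0, ..., 0]                             # real tokens vs. padding
        #   input_type_ids -> [0] * 20                                                  # single-sentence input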
117 | while len(input_ids) < seq_length: 118 | input_ids.append(0) 119 | input_mask.append(0) 120 | input_type_ids.append(0) 121 | 122 | assert len(input_ids) == seq_length 123 | assert len(input_mask) == seq_length 124 | assert len(input_type_ids) == seq_length 125 | features.append( 126 | InputFeatures( 127 | unique_id=example.unique_id, 128 | tokens=tokens, 129 | input_ids=input_ids, 130 | input_mask=input_mask, 131 | input_type_ids=input_type_ids)) 132 | return features 133 | 134 | class DatasetNotFoundError(Exception): 135 | pass 136 | 137 | class ReferDataset(data.Dataset): 138 | SUPPORTED_DATASETS = { 139 | 'refcoco': { 140 | 'splits': ('train', 'val', 'testA', 'testB'), 141 | 'params': {'dataset': 'refcoco', 'split_by': 'unc'} 142 | }, 143 | 'refcoco+': { 144 | 'splits': ('train', 'val', 'testA', 'testB'), 145 | 'params': {'dataset': 'refcoco+', 'split_by': 'unc'} 146 | }, 147 | 'refcocog': { 148 | 'splits': ('train', 'val', 'test'), 149 | 'params': {'dataset': 'refcocog', 'split_by': 'unc'} 150 | }, 151 | 'refcocog_g': { 152 | 'splits': ('train', 'val'), 153 | 'params': {'dataset': 'refcocog', 'split_by': 'google'} 154 | }, 155 | 'refcocog_u': { 156 | 'splits': ('train', 'val', 'test'), 157 | 'params': {'dataset': 'refcocog', 'split_by': 'unc'} 158 | }, 159 | 'grefcoco': { 160 | 'splits': ('train', 'val', 'testA', 'testB'), 161 | 'params': {'dataset': 'grefcoco', 'split_by': 'unc'} 162 | } 163 | } 164 | 165 | def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, 166 | transform=None, augment=False, split='train', max_query_len=128, 167 | bert_model='bert-base-uncased'): 168 | self.images = [] 169 | self.data_root = data_root 170 | self.split_root = split_root 171 | self.dataset = dataset 172 | self.imsize = imsize 173 | self.query_len = max_query_len 174 | self.transform = transform 175 | self.split = split 176 | self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True) # should be true for English 177 | self.augment=augment 178 | 179 | valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits'] 180 | 181 | if split not in valid_splits: 182 | raise ValueError( 183 | 'Dataset {0} does not have split {1}'.format( 184 | self.dataset, split)) 185 | 186 | self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt') 187 | self.mask_root = osp.join(self.data_root, 'masks', self.dataset) 188 | self.im_dir = osp.join(self.data_root, 'images', 'train2014') 189 | 190 | dataset_path = osp.join(self.split_root, self.dataset) 191 | splits = [split] 192 | for split in splits: 193 | imgset_file = '{0}_{1}.pth'.format(self.dataset, split) 194 | imgset_path = osp.join(dataset_path, imgset_file) 195 | self.images += torch.load(imgset_path) 196 | 197 | def exists_dataset(self): 198 | return osp.exists(osp.join(self.split_root, self.dataset)) 199 | 200 | def pull_item(self, idx): 201 | img_file, seg_id, bbox, phrase = self.images[idx] 202 | bbox = np.array(bbox, dtype=int) # x1y1x2y2 203 | 204 | img_path = osp.join(self.im_dir, img_file) 205 | img = cv2.imread(img_path) # BGR [512, 640, 3] 206 | ## duplicate channel if gray image 207 | if img.shape[-1] > 1: 208 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #RGB 209 | else: 210 | img = np.stack([img] * 3) 211 | 212 | ## seg map 213 | seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy')) # [512, 640] 214 | seg_map = np.array(seg_map).astype(np.float32) 215 | return img, phrase, bbox, seg_map 216 | 217 | def __len__(self): 218 | return len(self.images) 219 | 220 | def 
__getitem__(self, idx): 221 | img, phrase, bbox, seg_map = self.pull_item(idx) 222 | phrase = phrase.lower() 223 | if self.augment: 224 | augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \ 225 | True, True, True, False, False, False 226 | 227 | ## seems a bug in torch transformation resize, so separate in advance 228 | h,w = img.shape[0], img.shape[1] 229 | # print("img.shape", img.shape) 230 | if self.augment: 231 | ## random horizontal flip 232 | if augment_flip and random.random() > 0.5: 233 | img = cv2.flip(img, 1) 234 | seg_map = cv2.flip(seg_map, 1) 235 | bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1 236 | phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left') 237 | 238 | ## random copy and add left or right 239 | if augment_copy: 240 | img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox) 241 | 242 | ## random erase for occluded 243 | if augment_erase: 244 | img, seg_map = random_erase(img, seg_map) 245 | 246 | ## random padding and crop 247 | if augment_crop: 248 | img, seg_map = random_crop(img, seg_map, 40, h, w) 249 | 250 | ## random intensity, saturation change 251 | if augment_hsv: 252 | fraction = 0.50 253 | img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV) 254 | S = img_hsv[:, :, 1].astype(np.float32) 255 | V = img_hsv[:, :, 2].astype(np.float32) 256 | a = (random.random() * 2 - 1) * fraction + 1 257 | if a > 1: 258 | np.clip(S, a_min=0, a_max=255, out=S) 259 | a = (random.random() * 2 - 1) * fraction + 1 260 | V *= a 261 | if a > 1: 262 | np.clip(V, a_min=0, a_max=255, out=V) 263 | 264 | img_hsv[:, :, 1] = S.astype(np.uint8) 265 | img_hsv[:, :, 2] = V.astype(np.uint8) 266 | img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB) 267 | 268 | img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize) 269 | bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw 270 | bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh 271 | 272 | ## random affine transformation 273 | if augment_affine: 274 | img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \ 275 | degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) # 255 white fill 276 | 277 | else: ## should be inference, or specified training 278 | img, _, ratio, dw, dh = letterbox(img, None, self.imsize) 279 | bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw 280 | bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh 281 | 282 | draw_img = copy.deepcopy(img) 283 | # Norm, to tensor 284 | if self.transform is not None: 285 | img = self.transform(img) 286 | 287 | ## encode phrase to clip input 288 | word_id = clip.tokenize(phrase, 17, truncate=True) 289 | word_mask = ~ (word_id == 0) 290 | 291 | if self.augment: # train 292 | seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2),interpolation=cv2.INTER_NEAREST) # (208, 208) 293 | seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]]) 294 | return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \ 295 | np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32) 296 | else: 297 | seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]]) 298 | return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \ 299 | np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \ 300 | np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], 
self.images[idx][3], np.array(draw_img, dtype=np.uint8) 301 | -------------------------------------------------------------------------------- /dataset/data_process.py: -------------------------------------------------------------------------------- 1 | # encoding=utf8 2 | # %matplotlib inline 3 | import numpy as np 4 | import os 5 | from refer import REFER 6 | import os.path as osp 7 | import cv2 8 | import argparse 9 | parser = argparse.ArgumentParser(description='Data preparation') 10 | parser.add_argument('--data_root', type=str) # contains refclef, refcoco, refcoco+, refcocog and images 11 | parser.add_argument('--output_dir', type=str) 12 | parser.add_argument('--dataset', type=str, choices=['refcoco', 'refcoco+','refcocog'], default='refcoco') 13 | parser.add_argument('--split', type=str,default='umd') 14 | parser.add_argument('--generate_mask', action='store_true') 15 | args = parser.parse_args() 16 | # data_root # contains refclef, refcoco, refcoco+, refcocog and images 17 | refer = REFER(args.data_root, args.dataset, args.split) 18 | 19 | print ('dataset [%s_%s] contains: ' % (args.dataset, args.split)) 20 | ref_ids = refer.getRefIds() 21 | image_ids = refer.getImgIds() 22 | print ('%s expressions for %s refs in %s images.' % (len(refer.Sents), len(ref_ids), len(image_ids))) 23 | 24 | print('\nAmong them:') 25 | if args.dataset == 'refclef': 26 | if args.split == 'unc': 27 | splits = ['train', 'val', 'testA','testB','testC'] 28 | else: 29 | splits = ['train', 'val', 'test'] 30 | elif args.dataset == 'refcoco': 31 | splits = ['train', 'val', 'testA', 'testB'] 32 | elif args.dataset == 'refcoco+': 33 | splits = ['train', 'val', 'testA', 'testB'] 34 | elif args.dataset == 'grefcoco': 35 | splits = ['train', 'val', 'testA', 'testB'] 36 | elif args.dataset == 'refcocog': 37 | splits = ['train', 'val', 'test'] # we don't have test split for refcocog right now. 38 | 39 | 40 | 41 | # split data as a type in splits list 42 | for split in splits: 43 | ref_ids = refer.getRefIds(split=split) 44 | print('%s refs are in split [%s].' % (len(ref_ids), split)) 45 | 46 | 47 | # show a batch data with bounding box,cat,sentences 48 | def show_a_batch(batch_size): 49 | split='train' 50 | # batch_size=32 51 | ref_ids = refer.getRefIds(split=split) 52 | print(split+'_size:',np.alen(ref_ids)) 53 | batch_index=list(np.random.choice(np.alen(ref_ids),batch_size)) 54 | 55 | # print(refer.Refs) 56 | ref_id = [ref_ids[i] for i in batch_index] 57 | refs = [refer.Refs[i] for i in ref_id] 58 | bboxs=[refer.getRefBox(i) for i in ref_id] 59 | sentences=[ref['sentences'] for ref in refs] 60 | image_urls=[refer.loadImgs(image_ids=ref['image_id']) for ref in refs] 61 | cats = [refer.loadCats(cat_ids=ref['category_id']) for ref in refs] 62 | # plt.figure() 63 | # plt.subplot(batch_size) 64 | grid_width = 2 65 | grid_height = int(batch_size / grid_width) 66 | # fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*10, 10*grid_height)) 67 | for i in range(batch_size): 68 | print('bbox for batch[{}]:'.format(i),bboxs[i]) 69 | print('sentences for batch[{}]:'.format(i)) 70 | for sid, sent in enumerate(sentences[i]): 71 | print('%s. 
%s' % (sid+1, sent['sent'])) 72 | print('cats for batch[{}]:'.format(i), cats[i]) 73 | 74 | image_url=image_urls[i][0] 75 | image=cv2.imread(osp.join(refer.IMAGE_DIR, image_url['file_name'])) 76 | print(image.shape) 77 | # print(bboxs[i][0]) 78 | cv2.rectangle(image,(int(bboxs[i][0]), int(bboxs[i][1])), (int(bboxs[i][0]+bboxs[i][2]),int(bboxs[i][1]+ bboxs[i][3])),255,3) 79 | cv2.putText(image, 80 | str(sent['sent']), 81 | (20, 20), 82 | cv2.FONT_HERSHEY_SIMPLEX, 83 | .9,(0,255,0), 2) 84 | os.mkdir('debug_vis') 85 | cv2.imwrite('./debug_vis/'+image_url['file_name'], image) 86 | cv2.imwrite('./debug_vis/mask'+image_url['file_name'], refer.getMask(refs[i])['mask']*255) 87 | # ax.imshow(image) 88 | # plt.show() 89 | 90 | def cat_process(cat): 91 | if cat >= 1 and cat <= 11: 92 | cat = cat - 1 93 | elif cat >= 13 and cat <= 25: 94 | cat = cat - 2 95 | elif cat >= 27 and cat <= 28: 96 | cat = cat - 3 97 | elif cat >= 31 and cat <= 44: 98 | cat = cat - 5 99 | elif cat >= 46 and cat <= 65: 100 | cat = cat - 6 101 | elif cat == 67: 102 | cat = cat - 7 103 | elif cat == 70: 104 | cat = cat - 9 105 | elif cat >= 72 and cat <= 82: 106 | cat = cat - 10 107 | elif cat >= 84 and cat <= 90: 108 | cat = cat - 11 109 | return cat 110 | 111 | def bbox_process(bbox,cat,segement_id): 112 | x_min = int(bbox[0]) 113 | y_min = int(bbox[1]) 114 | x_max = x_min + int(bbox[2]) 115 | y_max = y_min + int(bbox[3]) 116 | box_info = " %d,%d,%d,%d,%d,%d" % (int(x_min), int(y_min), int(x_max), int(y_max), int(cat),int(segement_id)) 117 | return box_info 118 | 119 | def prepare_dataset(dataset,splits,output_dir,generate_mask=False): 120 | # split_type='train' 121 | # splits=[split_type] 122 | # batch_size=32 123 | if dataset == 'refcocog': 124 | dataset = 'refcocog_' + args.split 125 | if not os.path.exists(os.path.join(output_dir,'anns',dataset)): 126 | os.makedirs(os.path.join(output_dir,'anns',dataset)) 127 | if not os.path.exists(os.path.join(output_dir,'masks',dataset)): 128 | os.makedirs(os.path.join(output_dir,'masks',dataset)) 129 | for split in splits: 130 | f = open(os.path.join(output_dir,'anns', dataset, split + '.txt'), 'w', encoding='utf-8') 131 | # print(split) 132 | split_num=0 133 | ll=0 134 | ref_ids = refer.getRefIds(split=split) 135 | print(split+'_size:',np.alen(ref_ids)) 136 | for i in ref_ids: 137 | # ref_id = ref_ids[i] 138 | refs = refer.Refs[i] 139 | bboxs=refer.getRefBox(i) 140 | print("bboxs", bboxs) 141 | sentences=refs['sentences'] 142 | image_urls=refer.loadImgs(image_ids=refs['image_id'])[0] 143 | 144 | # grefcoco中的category_id是一个list 145 | cat = refs['category_id'] 146 | if type(cat) == list: 147 | for j in range(len(cat)): 148 | cat[j] = cat_process(cat[j]) 149 | else: 150 | cat = cat_process(cat) 151 | 152 | image_urls=image_urls['file_name'] 153 | if dataset=='refclef' and image_urls in ['19579.jpg', '17975.jpg', '19575.jpg']: 154 | continue 155 | # RES中box信息和cat信息使用不到 156 | if type(bboxs[0]) == list: 157 | box_info = bbox_process(bboxs[0], cat[0], i) # add segement id 158 | else: 159 | box_info=bbox_process(bboxs,cat,i) #add segement id 160 | f.write(image_urls) 161 | f.write(box_info) 162 | # f.write(' '+str(i)) 163 | if generate_mask: 164 | if dataset == 'grefcoco': 165 | np.save(os.path.join(output_dir,'masks',dataset,str(i)+'.npy'),refer.getMaskByRef(refs, merge=True)['mask']) 166 | else: 167 | np.save(os.path.join(output_dir,'masks',dataset,str(i)+'.npy'),refer.getMask(refs)['mask']) #if need seg mask ,set it! 
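            # Added note: each line written to anns/<dataset>/<split>.txt has the form
            #   <file_name> <x_min>,<y_min>,<x_max>,<y_max>,<category>,<seg_id> ~ <sentence 1> ~ <sentence 2> ...
            # For example (values are illustrative only):
            #   COCO_train2014_000000000123.jpg 10,20,150,200,0,42 ~ man in black ~ guy on the left
            # datascript.py later parses these lines back into (img_name, seg_id, box, sentence) tuples.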
168 | for sentence in sentences: 169 | f.write(' ~ ') 170 | # print(sentence['sent'].encode('UTF-8')) 171 | f.write(sentence['sent']) 172 | if llmax_len: 216 | # max_len=len(line[sent_stop:i]) 217 | sent_stop = i + 1 218 | for i in range(50): 219 | if word_l_count[i]>0: 220 | print('length:%d'%i,',count:%d'%word_l_count[i]) 221 | # print('max_len:',max_len) 222 | # print(len(lines)) 223 | 224 | 225 | prepare_dataset(args.dataset,splits,args.output_dir,args.generate_mask) -------------------------------------------------------------------------------- /dataset/datascript.py: -------------------------------------------------------------------------------- 1 | # generate **.pth 2 | import os 3 | import sys 4 | import torch 5 | sys.path.append('.') 6 | 7 | import argparse 8 | parser = argparse.ArgumentParser(description='Data preparation') 9 | parser.add_argument('--dataset', type=str, choices=['refcoco', 'refcoco+','refcocog_google', 'refcocog_umd'], default='refcoco') 10 | args = parser.parse_args() 11 | 12 | def main(args): 13 | dataset = args.dataset 14 | input_txt_list = os.listdir(f'../ln_data/anns/{dataset}') 15 | if not os.path.exists(f'../data/{dataset}'): 16 | os.makedirs(f'../data/{dataset}') 17 | for input_txt in input_txt_list: 18 | split = input_txt.split('_')[-1].split('.')[0] 19 | input_txt = os.path.join('../ln_data/anns', dataset, input_txt) 20 | res = [] 21 | with open(input_txt, encoding='utf-8') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | line = line.split() 25 | stop = len(line) 26 | img_name = line[0] 27 | for i in range(1,len(line)): 28 | if (line[i]=='~'): 29 | stop=i 30 | break 31 | box_ = [list(map(int,box.split(','))) for box in line[1:stop]] 32 | box = box_[0][:4] 33 | seg_id=box_[0][-1] 34 | 35 | sent_stop=stop+1 36 | for i in range(stop+1,len(line)): 37 | if line[i]=='~': 38 | des = '' 39 | for word in line[sent_stop:i]: 40 | des = des + word + ' ' 41 | sent_stop=i+1 42 | des = des.rstrip(' ') 43 | res.append((img_name, seg_id, box, des)) 44 | des = '' 45 | for word in line[sent_stop:len(line)]: 46 | des = des + word + ' ' 47 | des = des.rstrip(' ') 48 | res.append((img_name, seg_id, box, des)) 49 | # print(res) 50 | 51 | imgset_path = '{0}_{1}.pth'.format(dataset, split) 52 | images = torch.save(res, os.path.join("../data", dataset, imgset_path)) 53 | print(dataset, " done") 54 | 55 | if __name__ == "__main__": 56 | main(args) 57 | -------------------------------------------------------------------------------- /dataset/refer.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | 3 | """ 4 | This interface provides access to four datasets: 5 | 1) refclef 6 | 2) refcoco 7 | 3) refcoco+ 8 | 4) refcocog 9 | split by unc and google 10 | The following API functions are defined: 11 | REFER - REFER api class 12 | getRefIds - get ref ids that satisfy given filter conditions. 13 | getAnnIds - get ann ids that satisfy given filter conditions. 14 | getImgIds - get image ids that satisfy given filter conditions. 15 | getCatIds - get category ids that satisfy given filter conditions. 16 | loadRefs - load refs with the specified ref ids. 17 | loadAnns - load anns with the specified ann ids. 18 | loadImgs - load images with the specified image ids. 19 | loadCats - load category names with the specified category ids. 
20 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id 21 | showRef - show image, segmentation or box of the referred object with the ref 22 | getMask - get mask and area of the referred object given ref 23 | showMask - show mask of the referred object given ref 24 | """ 25 | 26 | import sys 27 | import os.path as osp 28 | import os 29 | import json 30 | # import _pickle as pickle 31 | import pickle 32 | import time 33 | import itertools 34 | import skimage.io as io 35 | import matplotlib.pyplot as plt 36 | from matplotlib.collections import PatchCollection 37 | from matplotlib.patches import Polygon, Rectangle 38 | from pprint import pprint 39 | import numpy as np 40 | from pycocotools import mask 41 | import cv2 42 | # from skimage.measure import label, regionprops 43 | 44 | class REFER: 45 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'): 46 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog 47 | # also provide dataset name and splitBy information 48 | # e.g., dataset = 'refcoco', splitBy = 'unc' 49 | print('loading dataset %s into memory...' % dataset) 50 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) 51 | self.DATA_DIR = osp.join(data_root, dataset) 52 | if dataset in ['refcoco', 'refcoco+', 'refcocog']: 53 | self.IMAGE_DIR = osp.join(data_root, 'images/train2014') 54 | elif dataset == 'refclef': 55 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12') 56 | else: 57 | print('No refer dataset is called [%s]' % dataset) 58 | sys.exit() 59 | 60 | # load refs from data/dataset/refs(dataset).json 61 | tic = time.time() 62 | ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p') 63 | self.data = {} 64 | self.data['dataset'] = dataset 65 | 66 | self.data['refs'] = pickle.load(open(ref_file, 'rb'),fix_imports=True) 67 | 68 | # load annotations from data/dataset/instances.json 69 | instances_file = osp.join(self.DATA_DIR, 'instances.json') 70 | instances = json.load(open(instances_file, 'r')) 71 | self.data['images'] = instances['images'] 72 | self.data['annotations'] = instances['annotations'] 73 | self.data['categories'] = instances['categories'] 74 | 75 | # create index 76 | self.createIndex() 77 | print('DONE (t=%.2fs)' % (time.time()-tic)) 78 | 79 | def createIndex(self): 80 | # create sets of mapping 81 | # 1) Refs: {ref_id: ref} 82 | # 2) Anns: {ann_id: ann} 83 | # 3) Imgs: {image_id: image} 84 | # 4) Cats: {category_id: category_name} 85 | # 5) Sents: {sent_id: sent} 86 | # 6) imgToRefs: {image_id: refs} 87 | # 7) imgToAnns: {image_id: anns} 88 | # 8) refToAnn: {ref_id: ann} 89 | # 9) annToRef: {ann_id: ref} 90 | # 10) catToRefs: {category_id: refs} 91 | # 11) sentToRef: {sent_id: ref} 92 | # 12) sentToTokens: {sent_id: tokens} 93 | print('creating index...') 94 | # fetch info from instances 95 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 96 | for ann in self.data['annotations']: 97 | Anns[ann['id']] = ann 98 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann] 99 | for img in self.data['images']: 100 | Imgs[img['id']] = img 101 | for cat in self.data['categories']: 102 | Cats[cat['id']] = cat['name'] 103 | 104 | # fetch info from refs 105 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 106 | Sents, sentToRef, sentToTokens = {}, {}, {} 107 | for ref in self.data['refs']: 108 | # ids 109 | ref_id = ref['ref_id'] 110 | ann_id = ref['ann_id'] 111 | category_id = ref['category_id'] 112 | image_id = ref['image_id'] 113 | 114 | # add mapping related to ref 115 | 
Refs[ref_id] = ref 116 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 117 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 118 | refToAnn[ref_id] = Anns[ann_id] 119 | annToRef[ann_id] = ref 120 | 121 | # add mapping of sent 122 | for sent in ref['sentences']: 123 | Sents[sent['sent_id']] = sent 124 | sentToRef[sent['sent_id']] = ref 125 | sentToTokens[sent['sent_id']] = sent['tokens'] 126 | 127 | # create class members 128 | self.Refs = Refs 129 | self.Anns = Anns 130 | self.Imgs = Imgs 131 | self.Cats = Cats 132 | self.Sents = Sents 133 | self.imgToRefs = imgToRefs 134 | self.imgToAnns = imgToAnns 135 | self.refToAnn = refToAnn 136 | self.annToRef = annToRef 137 | self.catToRefs = catToRefs 138 | self.sentToRef = sentToRef 139 | self.sentToTokens = sentToTokens 140 | print('index created.') 141 | 142 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 143 | image_ids = image_ids if type(image_ids) == list else [image_ids] 144 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 145 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 146 | 147 | if len(image_ids)==len(cat_ids)==len(ref_ids)==len(split)==0: 148 | refs = self.data['refs'] 149 | else: 150 | if not len(image_ids) == 0: 151 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 152 | else: 153 | refs = self.data['refs'] 154 | if not len(cat_ids) == 0: 155 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 156 | if not len(ref_ids) == 0: 157 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 158 | if not len(split) == 0: 159 | if split in ['testA', 'testB', 'testC']: 160 | refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ... 161 | elif split in ['testAB', 'testBC', 'testAC']: 162 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess... 
163 | elif split == 'test': 164 | refs = [ref for ref in refs if 'test' in ref['split']] 165 | elif split == 'train' or split == 'val': 166 | refs = [ref for ref in refs if ref['split'] == split] 167 | else: 168 | print('No such split [%s]' % split) 169 | sys.exit() 170 | ref_ids = [ref['ref_id'] for ref in refs] 171 | return ref_ids 172 | 173 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): 174 | image_ids = image_ids if type(image_ids) == list else [image_ids] 175 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 176 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 177 | 178 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 179 | ann_ids = [ann['id'] for ann in self.data['annotations']] 180 | else: 181 | if not len(image_ids) == 0: 182 | lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns] 183 | anns = list(itertools.chain.from_iterable(lists)) 184 | else: 185 | anns = self.data['annotations'] 186 | if not len(cat_ids) == 0: 187 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 188 | ann_ids = [ann['id'] for ann in anns] 189 | if not len(ref_ids) == 0: 190 | ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) 191 | return ann_ids 192 | 193 | def getImgIds(self, ref_ids=[]): 194 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 195 | 196 | if not len(ref_ids) == 0: 197 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 198 | else: 199 | image_ids = self.Imgs.keys() 200 | return image_ids 201 | 202 | def getCatIds(self): 203 | return self.Cats.keys() 204 | 205 | def loadRefs(self, ref_ids=[]): 206 | if type(ref_ids) == list: 207 | return [self.Refs[ref_id] for ref_id in ref_ids] 208 | elif type(ref_ids) == int: 209 | return [self.Refs[ref_ids]] 210 | 211 | def loadAnns(self, ann_ids=[]): 212 | if type(ann_ids) == list: 213 | return [self.Anns[ann_id] for ann_id in ann_ids] 214 | elif type(ann_ids) == int or type(ann_ids) == unicode: 215 | return [self.Anns[ann_ids]] 216 | 217 | def loadImgs(self, image_ids=[]): 218 | if type(image_ids) == list: 219 | return [self.Imgs[image_id] for image_id in image_ids] 220 | elif type(image_ids) == int: 221 | return [self.Imgs[image_ids]] 222 | 223 | def loadCats(self, cat_ids=[]): 224 | if type(cat_ids) == list: 225 | return [self.Cats[cat_id] for cat_id in cat_ids] 226 | elif type(cat_ids) == int: 227 | return [self.Cats[cat_ids]] 228 | 229 | def getRefBox(self, ref_id): 230 | ref = self.Refs[ref_id] 231 | ann = self.refToAnn[ref_id] 232 | return ann['bbox'] # [x, y, w, h] 233 | 234 | def showRef(self, ref, seg_box='seg'): 235 | ax = plt.gca() 236 | # show image 237 | image = self.Imgs[ref['image_id']] 238 | I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) 239 | ax.imshow(I) 240 | # show refer expression 241 | for sid, sent in enumerate(ref['sentences']): 242 | print('%s. 
%s' % (sid+1, sent['sent'])) 243 | # show segmentations 244 | if seg_box == 'seg': 245 | ann_id = ref['ann_id'] 246 | ann = self.Anns[ann_id] 247 | polygons = [] 248 | color = [] 249 | c = 'none' 250 | if type(ann['segmentation'][0]) == list: 251 | # polygon used for refcoco* 252 | for seg in ann['segmentation']: 253 | poly = np.array(seg).reshape((len(seg)//2, 2)) 254 | polygons.append(Polygon(poly, True, alpha=0.4)) 255 | color.append(c) 256 | p = PatchCollection(polygons, facecolors=color, edgecolors=(1,1,0,0), linewidths=3, alpha=1) 257 | ax.add_collection(p) # thick yellow polygon 258 | p = PatchCollection(polygons, facecolors=color, edgecolors=(1,0,0,0), linewidths=1, alpha=1) 259 | ax.add_collection(p) # thin red polygon 260 | else: 261 | # mask used for refclef 262 | rle = ann['segmentation'] 263 | m = mask.decode(rle) 264 | img = np.ones( (m.shape[0], m.shape[1], 3) ) 265 | color_mask = np.array([2.0,166.0,101.0])/255 266 | for i in range(3): 267 | img[:,:,i] = color_mask[i] 268 | ax.imshow(np.dstack( (img, m*0.5) )) 269 | # show bounding-box 270 | elif seg_box == 'box': 271 | ann_id = ref['ann_id'] 272 | print(ann_id) 273 | ann = self.Anns[ann_id] 274 | bbox = self.getRefBox(ref['ref_id']) 275 | box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3) 276 | ax.add_patch(box_plot) 277 | 278 | def getMask(self, ref): 279 | # return mask, area and mask-center 280 | ann = self.refToAnn[ref['ref_id']] 281 | print(ann) 282 | image = self.Imgs[ref['image_id']] 283 | if type(ann['segmentation'][0]) == list: # polygon 284 | rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width']) 285 | else: 286 | rle = ann['segmentation'] 287 | 288 | # for i in range(len(rle['counts'])): 289 | # print(rle) 290 | m = mask.decode(rle) 291 | m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) 292 | m = m.astype(np.uint8) # convert to np.uint8 293 | # compute area 294 | area = sum(mask.area(rle)) # should be close to ann['area'] 295 | return {'mask': m, 'area': area} 296 | # # position 297 | # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style) 298 | # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style) 299 | # # mass position (if there were multiple regions, we use the largest one.) 300 | # label_m = label(m, connectivity=m.ndim) 301 | # regions = regionprops(label_m) 302 | # if len(regions) > 0: 303 | # largest_id = np.argmax(np.array([props.filled_area for props in regions])) 304 | # largest_props = regions[largest_id] 305 | # mass_y, mass_x = largest_props.centroid 306 | # else: 307 | # mass_x, mass_y = position_x, position_y 308 | # # if centroid is not in mask, we find the closest point to it from mask 309 | # if m[mass_y, mass_x] != 1: 310 | # print 'Finding closes mask point ...' 
311 | # kernel = np.ones((10, 10),np.uint8) 312 | # me = cv2.erode(m, kernel, iterations = 1) 313 | # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style 314 | # points = np.array(points) 315 | # dist = np.sum((points - (mass_y, mass_x))**2, axis=1) 316 | # id = np.argsort(dist)[0] 317 | # mass_y, mass_x = points[id] 318 | # # return 319 | # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y} 320 | # # show image and mask 321 | # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) 322 | # plt.figure() 323 | # plt.imshow(I) 324 | # ax = plt.gca() 325 | # img = np.ones( (m.shape[0], m.shape[1], 3) ) 326 | # color_mask = np.array([2.0,166.0,101.0])/255 327 | # for i in range(3): 328 | # img[:,:,i] = color_mask[i] 329 | # ax.imshow(np.dstack( (img, m*0.5) )) 330 | # plt.show() 331 | 332 | def showMask(self, ref): 333 | M = self.getMask(ref) 334 | msk = M['mask'] 335 | ax = plt.gca() 336 | ax.imshow(msk) 337 | 338 | 339 | if __name__ == '__main__': 340 | refer = REFER(data_root="/home/ypf/workspace/code/BKINet/ln_data", dataset='refcoco', splitBy='unc') 341 | save_path = "./visualization/" 342 | ref_ids = refer.getRefIds() 343 | print(len(ref_ids)) 344 | 345 | print(len(refer.Imgs)) 346 | print(len(refer.imgToRefs)) 347 | print(refer.Cats) 348 | 349 | ref_ids = refer.getRefIds(split='train') 350 | print('There are %s training referred objects.' % len(ref_ids)) 351 | 352 | img_ids = [8936, 52563] 353 | # ref_ids = refer.getRefIds(image_ids=img_ids) 354 | 355 | # refs = refer.loadRefs(ref_ids) 356 | 357 | def custom_vis1(image, mask_): 358 | # 将mask应用到蓝色图层 359 | # 创建一个蓝色图层 360 | blue_layer = np.zeros_like(image) 361 | blue_layer[:, :, 0] = 255 # 对于OpenCV,蓝色通道是第一个 362 | blue_mask = cv2.bitwise_and(blue_layer, blue_layer, mask=mask_) 363 | 364 | # 将蓝色mask以一定的透明度覆盖到原图上 365 | alpha = 0.1 # 透明度 366 | cv2.addWeighted(blue_mask, alpha, image, 1 - alpha, 0, image) 367 | 368 | def custom_vis2(image, mask_): 369 | # 创建蓝色图层 370 | blue_layer = np.zeros_like(image) 371 | blue_layer[:, :, 0] = 255 # 对于OpenCV,蓝色通道是第一个 372 | 373 | # 将mask应用到蓝色图层 374 | blue_mask = cv2.bitwise_and(blue_layer, blue_layer, mask=mask_) 375 | 376 | # alpha值定义了mask图层和原图的融合程度 377 | alpha = 0.5 # 透明度 378 | 379 | # 创建一个完全透明的图层 380 | transparent_layer = np.zeros_like(image) 381 | 382 | # 我们只在mask的区域上应用蓝色图层,并调整alpha值来控制透明度 383 | for i in range(3): # 只处理RGB三个通道 384 | transparent_layer[:, :, i] = cv2.addWeighted(blue_mask[:, :, i], alpha, image[:, :, i], 1 - alpha, 0) 385 | 386 | # 在mask区域外使用原图 387 | transparent_layer[mask_ == 0] = image[mask_ == 0] 388 | 389 | return transparent_layer 390 | 391 | def custom_vis3(image, mask_): 392 | """ 393 | 直接在原图上修改指定mask区域的颜色为蓝色 394 | 不改变其他区域的亮度或色彩 395 | """ 396 | image[mask_ != 0] = [255, 0, 0] # OpenCV中的颜色顺序是BGR 397 | 398 | def custom_vis4(image, mask_, alpha=0.4): 399 | """ 400 | 在原图上以指定的透明度应用蓝色遮罩。 401 | alpha: 遮罩的透明度,范围从0(完全透明)到1(完全不透明)。 402 | """ 403 | # 将原图从BGR转换为RGBA以添加Alpha通道 404 | image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA) 405 | # 创建一个同样大小的全蓝色图层 406 | blue_mask = np.zeros_like(image_rgba) 407 | blue_mask[:, :, 0] = 255 # B 408 | blue_mask[:, :, 3] = 255 # Alpha设置为不透明 409 | 410 | # 应用透明度到mask区域 411 | blue_mask[mask_ != 0, 3] = int(alpha * 255) 412 | 413 | # 将蓝色遮罩叠加到原图 414 | image_rgba = cv2.addWeighted(image_rgba, 1, blue_mask, alpha, 0) 415 | return image_rgba 416 | 417 | 418 | 419 | 420 | for i, img_id in enumerate(img_ids): 421 | ref = refer.imgToRefs[img_id][0] 422 | 
print(ref) 423 | mask_ = refer.getMask(ref)['mask'] 424 | # sentence = ref['sentences'][0]['sent'] 425 | 426 | img = refer.Imgs[img_id] 427 | # I = io.imread(osp.join(refer.IMAGE_DIR, img['file_name'])) 428 | # 假设`image_path`是原始图像的路径,`mask`是一个与原图像相同大小的二值数组 429 | image_path = osp.join(refer.IMAGE_DIR, img['file_name']) 430 | image = cv2.imread(image_path) 431 | # mask = np.zeros(image.shape[:2], dtype=np.uint8) # 这里你需要有一个实际的mask 432 | 433 | # custom_vis1(image, mask_) 434 | 435 | image = custom_vis2(image, mask_) 436 | 437 | # custom_vis3(image, mask_) 438 | 439 | # image = custom_vis4(image=image, mask_=mask_, alpha=0.4) 440 | 441 | 442 | 443 | 444 | # 保存结果图像到指定路径 445 | image_dir = osp.join(save_path, str(img_id)) 446 | osp.exists(image_dir) or os.makedirs(image_dir) 447 | # 复制原图 448 | I = io.imread(osp.join(refer.IMAGE_DIR, img['file_name'])) 449 | io.imsave(osp.join(image_dir, img['file_name']), I) 450 | 451 | 452 | cv2.imwrite(osp.join(image_dir, str(img_id)+".png"), image) 453 | 454 | # 将json格式的ref保存 455 | with open(osp.join(image_dir, str(img_id)+".json"), "w") as f: 456 | json.dump(ref, f) 457 | 458 | 459 | 460 | 461 | 462 | # i = 0 463 | # for ref_id in ref_ids: 464 | # i += 1 465 | # ref = refer.loadRefs(ref_id)[0] 466 | # if len(ref['sentences']) < 2: 467 | # continue 468 | 469 | # print(ref) 470 | # print('The label is %s.' % refer.Cats[ref['category_id']]) 471 | # plt.figure() 472 | # # refer.getMask(ref) 473 | # refer.showMask(ref) 474 | 475 | # # refer.showRef(ref, seg_box='seg') 476 | 477 | # plt.show() 478 | # if i == 0: 479 | # break 480 | # # save 481 | # plt.savefig('tmp.png') 482 | 483 | # plt.figure() 484 | # refer.showMask(ref) 485 | # plt.show() -------------------------------------------------------------------------------- /engine/engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | import matplotlib as mpl 3 | mpl.use('Agg') 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.parallel 8 | import torch.optim 9 | from torch.autograd import Variable 10 | from torch.cuda.amp import autocast as autocast 11 | 12 | from model.model import * 13 | from dataset.data_loader import * 14 | from utils.losses import * 15 | from utils.parsing_metrics import * 16 | from utils.utils import * 17 | from utils.utils import dice_loss, sigmoid_focal_loss 18 | 19 | use_cuda = torch.cuda.is_available() 20 | print("use_cuda, ", use_cuda) 21 | 22 | 23 | def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger): 24 | print('train at epoch %d'%epoch) 25 | batch_time = AverageMeter() 26 | losses = AverageMeter() 27 | dice_losses = AverageMeter() 28 | sigmoid_focal_losses = AverageMeter() 29 | cos_losses = AverageMeter() 30 | model.train() 31 | end = time.time() 32 | 33 | for batch_idx, (imgs, word_id, word_mask, bbox, seg_map) in enumerate(train_loader): 34 | imgs = imgs.cuda(rank, non_blocking=True) 35 | word_id = word_id.cuda(rank, non_blocking=True) 36 | word_mask = word_mask.cuda(rank, non_blocking=True) 37 | seg_map = seg_map.cuda(rank, non_blocking=True) 38 | image = Variable(imgs) 39 | word_id = Variable(word_id) 40 | word_mask = Variable(word_mask) 41 | seg_map = Variable(seg_map) 42 | 43 | with autocast(): 44 | mask_out = model(image, word_id, word_mask) 45 | loss = 0. 
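            # Added note: dice_loss and sigmoid_focal_loss are imported from utils.utils (not shown in
            # this listing); they are assumed to follow the usual definitions on the sigmoid of the logits:
            #   dice  = 1 - (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps)   with p = sigmoid(mask_out), g = seg_map
            #   focal = mean(alpha_t * (1 - p_t) ** gamma * BCE_with_logits(mask_out, seg_map))
            # The training objective accumulated below is simply the sum of these two terms.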
46 | 47 | mask_out_np = mask_out.data.cpu().numpy() # [bs, 1, 208, 208] 48 | seg_map_np = seg_map.cpu().numpy() # [bs, 1, 208, 208] 49 | seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh) 50 | 51 | dice_loss_ = dice_loss(mask_out, seg_map) 52 | sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map) 53 | 54 | loss += dice_loss_ + sigmoid_focal_loss_ 55 | 56 | optimizer.zero_grad() 57 | scaler.scale(loss).backward() 58 | scaler.step(optimizer) 59 | scaler.update() 60 | 61 | losses.update(loss.item(), imgs.size(0)) 62 | dice_losses.update(dice_loss_.item(), imgs.size(0)) 63 | sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), imgs.size(0)) 64 | cos_losses.update(seg_iou.mean().item(), imgs.size(0)) 65 | 66 | # measure elapsed time 67 | batch_time.update(time.time() - end) 68 | end = time.time() 69 | 70 | if rank == 0 and batch_idx % args.print_freq == 0: 71 | print_str = 'Epoch: [{0}][{1}/{2}]\t' \ 72 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \ 73 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \ 74 | 'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \ 75 | 'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \ 76 | 'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \ 77 | .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses) 78 | print(print_str) 79 | logger.info(print_str) 80 | 81 | return losses.avg 82 | 83 | def validate_epoch(args, val_loader, model, logger, mode='val'): 84 | print('begin test') 85 | batch_time = AverageMeter() 86 | miou = AverageMeter() 87 | miou_seg = AverageMeter() 88 | 89 | prec=dict() 90 | thresholds = np.arange(0.5, 1, 0.05) 91 | 92 | for thresh in thresholds: 93 | prec[thresh]= AverageMeter() 94 | 95 | model.eval() 96 | end = time.time() 97 | idx = 0 98 | 99 | t_all = [] 100 | 101 | for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader): 102 | 103 | imgs = imgs.cuda(0) 104 | word_id = word_id.cuda(0) 105 | word_mask = word_mask.cuda(0) 106 | seg_map = seg_map.cuda(0) 107 | image = Variable(imgs) 108 | word_id = Variable(word_id) 109 | word_mask = Variable(word_mask) 110 | seg_map = Variable(seg_map) 111 | 112 | t1 = time.time() 113 | with torch.no_grad(): 114 | mask_out = model(image, word_id, word_mask) 115 | mask_out = mask_out.sigmoid() 116 | 117 | t2 = time.time() 118 | t_all.append(t2-t1) 119 | 120 | ## test: convert pred, gt box to original scale with meta-info 121 | ih = seg_map.shape[-2] 122 | iw = seg_map.shape[-1] 123 | nh = int(ih * ratio) 124 | nw = int(iw * ratio) 125 | top, bottom = int(dh[0]), nh + int(dh[0]) 126 | left, right = int(dw[0]), nw + int(dw[0]) 127 | ratio = float(ratio) 128 | new_shape = (iw, ih) 129 | 130 | ## revert image for visualization 131 | seg_map_np = seg_map[0,:,:,:].data.cpu().numpy().transpose(1,2,0) 132 | seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC) 133 | img_np = imgs[0,:,top:bottom,left:right].data.cpu().numpy().transpose(1,2,0) 134 | img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC) 135 | 136 | img_np = Variable(torch.from_numpy(img_np.transpose(2,0,1)).cuda().unsqueeze(0)) 137 | 138 | # seg 139 | mask_out = mask_out[0].data.cpu().numpy().transpose(1,2,0) 140 | mask_out = cv2.resize(mask_out, (args.size, args.size)) 141 | mask_out_np = mask_out[top:bottom, left:right] 142 | mask_out_np = 
cv2.resize(mask_out_np, new_shape) 143 | seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh) 144 | miou_seg.update(seg_iou, imgs.size(0)) 145 | for thresh in thresholds: 146 | prec[thresh].update(seg_prec[thresh], imgs.size(0)) 147 | 148 | # measure elapsed time 149 | batch_time.update(time.time() - end) 150 | end = time.time() 151 | if batch_idx % 1000 == 0: 152 | print_str = '[{0}/{1}]\t' \ 153 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \ 154 | 'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \ 155 | .format( \ 156 | batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg) 157 | print(print_str) 158 | logger.info(print_str) 159 | idx = idx + 1 160 | 161 | print(miou_seg.avg) 162 | for thresh in thresholds: 163 | print("prec@%f: %f"%(thresh,float(prec[thresh].avg))) 164 | logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg))) 165 | logger.info("%f,%f"%(float(miou.avg), miou_seg.avg)) 166 | return miou_seg.avg, prec 167 | 168 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from .modules import ConvBatchNormReLU, SFA 7 | from .modules import * 8 | from .position_encoding import * 9 | 10 | import clip 11 | import math 12 | import sys 13 | 14 | sys.path.append('../') 15 | from utils.utils import * 16 | 17 | 18 | class Simple_fusion(nn.Module): 19 | def __init__(self, visual_dim=1024, text_dim=768, proj_dim=1024, jemb_drop_out=0.1, leaky=True): 20 | super(Simple_fusion, self).__init__() 21 | self.proj_dim = proj_dim 22 | self.mapping_visu = ConvBatchNormReLU(visual_dim, proj_dim, 1, 1, 0, 1, leaky=leaky) 23 | self.lang_attn = nn.Sequential( 24 | nn.Linear(text_dim, text_dim), 25 | nn.Tanh(), 26 | nn.Dropout(jemb_drop_out), 27 | nn.Softmax(dim=1)) 28 | 29 | self.lang_proj = nn.Sequential( 30 | nn.Linear(text_dim, proj_dim), 31 | nn.BatchNorm1d(proj_dim), 32 | nn.LeakyReLU(0.1)) 33 | 34 | self.fusion = nn.Sequential( 35 | nn.BatchNorm2d(proj_dim), 36 | nn.LeakyReLU(0.1)) 37 | 38 | def forward(self, visual_feat, lang_feat): 39 | # visual proj 40 | visual_feat_proj = self.mapping_visu(visual_feat) # [bt, 1024, 13, 13] 41 | 42 | """ 43 | # lang attn 44 | lang_feat_attn = self.lang_attn(lang_feat) #[bt, 15, 768] 45 | lang_feat_new = lang_feat * lang_feat_attn 46 | lang_feat_new = lang_feat_new.sum(dim=1) #[bt, 768] 47 | """ 48 | 49 | lang_feat = lang_feat.squeeze(1) 50 | # lang proj 51 | #lang_feat_new = self.lang_proj(lang_feat_new) #[bt, 1024] 52 | lang_feat_new = self.lang_proj(lang_feat) #[bt, 1024] 53 | 54 | # fusion 55 | h, w = visual_feat.shape[-2], visual_feat.shape[-1] 56 | lang_feat_new_tile = lang_feat_new.view(-1, self.proj_dim, 1, 1).repeat(1, 1, h, w) # [bt, 1024, 13, 13] 57 | fusion_feat = lang_feat_new_tile * visual_feat_proj 58 | fusion_feat = self.fusion(fusion_feat) 59 | return fusion_feat 60 | 61 | class up_proj_cat_proj(nn.Module): 62 | def __init__(self, input_1, input_2, do=512, leaky=True): 63 | super(up_proj_cat_proj, self).__init__() 64 | self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky) 65 | self.proj2 = ConvBatchNormReLU(input_1+input_2, do, 1, 1, 0, 1, leaky=leaky) 66 | 67 | def forward(self, x, y): 68 | x = F.interpolate(x, scale_factor=2, mode='nearest') 69 | y = self.proj1(y) 70 | out = torch.cat([x,y], dim=1) 71 | out = self.proj2(out) 72 | return out 73 | 74 | class 
pool_proj_cat_proj(nn.Module): 75 | def __init__(self, input_1, input_2, do=512, leaky=True): 76 | super(pool_proj_cat_proj, self).__init__() 77 | self.downsample = nn.AvgPool2d(2, 2) 78 | self.proj1 = ConvBatchNormReLU(input_2, do // 2, 1, 1, 0, 1, leaky=leaky) 79 | self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky) 80 | self.proj3 = ConvBatchNormReLU(input_1+do, do, 1, 1, 0, 1, leaky=leaky) 81 | 82 | def forward(self, x, y): 83 | y = self.downsample(y) 84 | y = self.proj1(y) 85 | y = self.proj2(y) 86 | output = self.proj3(torch.cat([x,y], dim=1)) 87 | return output 88 | 89 | class proj_cat_proj(nn.Module): 90 | def __init__(self, input_1, input_2, do=512, leaky=True): 91 | super(proj_cat_proj, self).__init__() 92 | self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky) 93 | self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky) 94 | 95 | def forward(self, x, y): 96 | y = self.proj1(y) 97 | out = torch.cat([x, y], dim=1) 98 | out = self.proj2(out) 99 | return out 100 | 101 | class proj_cat(nn.Module): 102 | def __init__(self, input_1, input_2, do=512, leaky=True): 103 | super(proj_cat, self).__init__() 104 | self.proj1 = ConvBatchNormReLU(input_1, do // 2, 1, 1, 0, 1, leaky=leaky) 105 | self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky) 106 | 107 | def forward(self, x, y): 108 | x = self.proj1(x) 109 | x = self.proj2(x) 110 | output = torch.cat([x,y], dim=1) 111 | return output 112 | 113 | class mask_decoder(nn.Module): 114 | def __init__(self, input_1, seg_out_stride=2, leaky=True): 115 | super(mask_decoder, self).__init__() 116 | self.proj1 = ConvBatchNormReLU(input_1, input_1//2, 3, 1, 1, 1, leaky=leaky) 117 | self.proj2 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky) 118 | 119 | self.proj3 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky) 120 | self.proj4 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky) 121 | self.proj5 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky) 122 | #self.proj = nn.Conv2d(input_1, 1, 3, 1, 1, 1) 123 | self.proj = nn.Conv2d(input_1//2, 32, 3, 1, 1, 1) 124 | 125 | def forward(self, x, seg_out_stride): 126 | x = self.proj1(x) 127 | x = self.proj2(x) 128 | 129 | 130 | if seg_out_stride <= 8: 131 | x = F.interpolate(x, scale_factor=2, mode='nearest') 132 | x = self.proj3(x) 133 | 134 | if seg_out_stride <= 4: 135 | x = F.interpolate(x, scale_factor=2, mode='nearest') 136 | x = self.proj4(x) 137 | 138 | if seg_out_stride <= 2: 139 | x = F.interpolate(x, scale_factor=2, mode='nearest') 140 | x = self.proj5(x) 141 | 142 | x = self.proj(x) 143 | 144 | return x 145 | 146 | 147 | # class FeatureSelector(nn.Module): 148 | # def __init__(self, img_feature_dim, text_feature_dim, output_dim): 149 | # super(FeatureSelector, self).__init__() 150 | # # 使用nn.Sequential来简化MLP的构建 151 | # self.mlp = nn.Sequential( 152 | # nn.Linear(img_feature_dim * 3 + text_feature_dim * 3, 1024), 153 | # nn.ReLU(), 154 | # nn.Linear(1024, 256), 155 | # nn.ReLU(), 156 | # nn.Linear(256, output_dim) 157 | # ) 158 | 159 | # def forward(self, img_features, text_feature): 160 | # # 将图像特征和文本特征拼接 161 | # combined_features = torch.cat(img_features + text_feature, dim=1) # 162 | # # 通过MLP得到输出得分 163 | # scores = self.mlp(combined_features) 164 | # return scores 165 | 166 | 167 | class QuickGELU(nn.Module): 168 | def forward(self, x: torch.Tensor): 169 | return x * torch.sigmoid(1.702 * x) 170 | 171 | class ResidualAttentionblk(nn.Module): 172 | def 
__init__(self, clip_module): 173 | super().__init__() 174 | 175 | self.clip_module = clip_module 176 | 177 | self.selected_tokens = int(676 * 0.8) 178 | 179 | #self.norm = nn.LayerNorm(768) 180 | 181 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, lang_tokens=None, index=0): 182 | 183 | 184 | if lang_tokens is None: 185 | x = x + self.clip_module.attention(self.clip_module.ln_1(x)) 186 | else: 187 | 188 | #if index >= 4 and index <= 7: 189 | # self.selected_tokens = int (676 * 0.8) 190 | #elif index>=8 and index <=11: 191 | # self.selected_tokens = int (676 * 0.5) 192 | #print(index) 193 | #print(self.selected_tokens) 194 | 195 | N, B, C = x.shape # N x B x C 196 | cls_x = x[:1, :, :] # 1 x B x C 197 | x = x[1:, :, :] # M x B x C 198 | 199 | ###img_cls text_cls 200 | #x = torch.mul(x, cls_x) 201 | #x = self.norm(x.reshape((N-1)*B, C)) 202 | #x = x.reshape(N-1, B, C) 203 | 204 | ### text eos token 205 | #score = torch.bmm(x.transpose(0,1), lang_tokens).squeeze(-1) 206 | 207 | ### text features mean 208 | score = torch.bmm(x.transpose(0, 1), lang_tokens.permute(1, 2, 0)).mean(dim=-1) # B x N 209 | score = score.transpose(0, 1) # N x B 210 | 211 | sorted_scores, sorted_indices = torch.sort(score, descending=True, dim=0) 212 | 213 | # high_mask = sorted_scores > sorted_scores[self.selected_tokens:self.selected_tokens+1, :] 214 | high_mask = torch.ones_like(sorted_scores) 215 | for i in range(B): 216 | high_mask[sorted_indices[self.selected_tokens:, i], i] = 0 217 | high_mask = high_mask > 0.5 218 | 219 | delta_x = x[high_mask].reshape(-1, B, C) # M x B x C 220 | low_x = x[~high_mask].reshape(-1, B, C) # N-M x B x C 221 | low_score = score[~high_mask].reshape(-1, B, 1) # N-M x B x 1 222 | 223 | low_x = low_x * torch.softmax(low_score, dim=0) # N-M x B x C 224 | low_x = low_x.sum(dim=0, keepdim=True) # 1 x B x C 225 | 226 | delta_x = torch.cat([cls_x, delta_x, low_x], dim=0) # M+1 x B x C 227 | delta_x = self.clip_module.attention(self.clip_module.ln_1(delta_x)) 228 | 229 | # for i in range(B): 230 | # x[high_mask[:, i], i, :] += delta_x[1:-1, i, :] 231 | # x[~high_mask[:, i], i, :] += delta_x[-1:, i, :] 232 | # cls_x[:, i] += delta_x[:1, i, :] 233 | temple = torch.zeros_like(x).type(delta_x.type()) 234 | temple[high_mask] = delta_x[1:-1, :, :].reshape(-1, C) 235 | temple[~high_mask] = delta_x[-1:, :, :].reshape(-1, 1, C).repeat(1, 676 - self.selected_tokens, 1).reshape(-1, C) 236 | x = x + temple 237 | cls_x = cls_x + delta_x[:1, :, :] 238 | 239 | x = torch.cat([cls_x, x], dim=0) 240 | 241 | x = x + self.clip_module.mlp(self.clip_module.ln_2(x)) 242 | return x 243 | 244 | class Model(nn.Module): 245 | def __init__(self, clip_model='RN50', tunelang=False, fusion_dim=2048, num_query=16, do=512, leaky=True, length=17): 246 | super(Model, self).__init__() 247 | 248 | self.tunelang = tunelang 249 | self.length = length 250 | 251 | ## Init Encoders 252 | clip_models = clip.load(clip_model, jit=False, device=torch.device("cpu"))[0].cuda() 253 | 254 | self.visumodel = clip_models.visual 255 | self.visu_dim = 768 256 | 257 | self.cut_list = [] 258 | self.visu_resblocks = nn.ModuleList([ResidualAttentionblk(self.visumodel.transformer.resblocks[i]) for i in range(12)]) 259 | self.visu_proj = nn.ModuleList([nn.Linear(do, self.visu_dim) for _ in range(len(self.cut_list))]) 260 | 261 | self.positional_embedding = nn.Parameter(torch.FloatTensor(1, 26 ** 2 + 1, 768)) 262 | v = self.resize_pos_embed(self.visumodel.positional_embedding.data.unsqueeze(0), self.positional_embedding, 26, 26) 
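        # Added note: resize_pos_embed (defined below) bilinearly interpolates the pretrained CLIP ViT
        # positional-embedding grid to 26 x 26, matching the 26*26 + 1 = 677 visual tokens used here
        # (this assumes 416 x 416 input images with a ViT-B/16 patch size of 16, i.e. 416 / 16 = 26).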
263 | self.positional_embedding.data.copy_(v) 264 | 265 | self.textmodel = clip_models.transformer 266 | self.textmodel_token_embedding = clip_models.token_embedding 267 | self.textmodel_pos_embed = nn.Parameter(clip_models.positional_embedding[:self.length, :].unsqueeze(0)) 268 | self.textmodel_ln_final = clip_models.ln_final 269 | self.textdim = self.textmodel_pos_embed.shape[-1] 270 | for module in self.textmodel.resblocks: 271 | module.attn_mask = self.build_attention_mask() 272 | 273 | # vis select 274 | self.vis_select = nn.Linear(self.visu_dim, do, bias=False) 275 | 276 | ## Fusion 277 | 278 | # fusion with x12 279 | self.fusion = Simple_fusion(visual_dim=self.visu_dim, text_dim=self.textdim, proj_dim=fusion_dim) 280 | 281 | # fusion with x6 282 | self.up_proj_cat_proj_1 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=fusion_dim) 283 | self.pool_proj_cat_proj_2 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=do) 284 | 285 | # fusion with x9 286 | self.proj_cat = proj_cat(input_1=fusion_dim, input_2=do, do=do) 287 | self.up_proj_cat_2 = proj_cat_proj(input_1=fusion_dim, input_2=do * 2, do=do) 288 | self.proj_0 = ConvBatchNormReLU(do, do, 1, 1, 0, 1, leaky=leaky) 289 | 290 | self.fpn = SFA(in_channels=self.visu_dim, out_channels=do) 291 | 292 | ## Align dim 293 | f_dim = 512 294 | self.fc_2 = nn.Linear(f_dim, f_dim, bias=False) 295 | self.norm1 = nn.LayerNorm(f_dim) 296 | self.norm2 = nn.LayerNorm(f_dim) 297 | 298 | # visual branch 299 | self.pos_embedding = PositionEmbeddingSine(f_dim) 300 | encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim, 301 | dropout=0.1, activation='relu', normalize_before=False) 302 | self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim)) 303 | 304 | ## Decoder 305 | self.mask_decoder = mask_decoder(f_dim, seg_out_stride=2) 306 | 307 | # text branch 308 | 309 | ## coef 310 | self.lang_tf_enc = lang_tf_enc(do, do, do, head_num=8) 311 | self.proj1 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky) 312 | self.proj2 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky) 313 | self.proj3 = nn.Conv2d(do, 32, 3, 1, 1, 1) 314 | self.projout = nn.Linear(26*26*32, 32, bias=False) 315 | 316 | 317 | self.feature_selector_l = nn.Linear(do, 1, bias=True) 318 | self.feature_selector_m = nn.Linear(do, 1, bias=True) 319 | 320 | def resize_pos_embed(self, posemb, posemb_new, hight, width): 321 | ntok_new = posemb_new.shape[1] 322 | 323 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 324 | ntok_new -= 1 325 | 326 | gs_old = int(math.sqrt(len(posemb_grid))) 327 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 328 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 329 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 330 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 331 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 332 | return posemb 333 | 334 | 335 | def build_attention_mask(self): 336 | # lazily create causal attention mask, with full attention between the vision tokens 337 | # pytorch uses additive attention mask; fill with -inf 338 | mask = torch.empty(self.length, self.length) 339 | mask.fill_(float("-inf")) 340 | mask.triu_(1) # zero out the lower diagonal 341 | return mask 342 | 343 | def forward(self, image, word_id, word_mask): 344 | ## Visual Module 345 | 346 | batch_size = 
image.size(0) 347 | 348 | # Extract features from vision 349 | x = self.visumodel.conv1(image) 350 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 351 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 352 | x = torch.cat([self.visumodel.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 353 | x = x + self.positional_embedding.to(x.dtype) 354 | x = self.visumodel.ln_pre(x) 355 | x = x.permute(1, 0, 2) # NLD -> LND 356 | 357 | raw_fword = self.textmodel_token_embedding(word_id).squeeze(1) 358 | raw_fword = raw_fword + self.textmodel_pos_embed 359 | raw_fword = raw_fword.permute(1, 0, 2) # NLD -> LND 360 | 361 | visu_list_l = [] 362 | visu_list_m = [] 363 | 364 | scores_l = [] 365 | scores_m = [] 366 | 367 | for i, [blk_visu, blk_lang] in enumerate(zip(self.visu_resblocks, self.textmodel.resblocks)): 368 | x = blk_visu(x) # [677, bs, 768] 369 | raw_fword = blk_lang(raw_fword) 370 | 371 | img_cls = self.vis_select(x[0, :, :]) # [B, C] 372 | tex_cls = raw_fword[word_id.argmax(dim=-1).reshape(-1), torch.arange(raw_fword.shape[1]), :] # [B, C] 373 | score = img_cls * tex_cls # [B, C] 374 | score = score.unsqueeze(1) # [B, 1, C] 375 | 376 | if i >=3 and i <= 5: 377 | visu_list_l.append(x) 378 | scores_l.append(score) 379 | 380 | if i>=6 and i <=8: 381 | visu_list_m.append(x) 382 | scores_m.append(score) 383 | 384 | 385 | scores_l = torch.cat(scores_l, dim=1) # [B, 3, C] 386 | scores_m = torch.cat(scores_m, dim=1) # [B, 3, C] 387 | 388 | scores_l = self.feature_selector_l(scores_l).squeeze(-1) # [B, 3] 389 | scores_l = F.softmax(scores_l, dim=-1) 390 | scores_m = self.feature_selector_m(scores_m).squeeze(-1) # [B, 3] 391 | scores_m = F.softmax(scores_m, dim=-1) 392 | 393 | visu_list_l = torch.cat(visu_list_l, dim=0).reshape(len(visu_list_l), -1, batch_size, self.visu_dim).permute(0,2,1,3) 394 | visu_list_m = torch.cat(visu_list_m, dim=0).reshape(len(visu_list_m), -1, batch_size, self.visu_dim).permute(0,2,1,3) 395 | 396 | 397 | x6 = visu_list_l[scores_l.argmax(dim=-1).reshape(-1), torch.arange(visu_list_l.shape[1]), :, :].permute(1,0,2) 398 | x9 = visu_list_m[scores_m.argmax(dim=-1).reshape(-1), torch.arange(visu_list_m.shape[1]), :, :].permute(1,0,2) 399 | 400 | 401 | x6 = x6.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2) 402 | x9 = x9.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2) 403 | x12 = x.permute(1, 0, 2)[:, 1:, :] 404 | x12 = x12.reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2) # [bs, 768, 26, 26] 405 | 406 | 407 | raw_fword = raw_fword.permute(1, 0, 2) 408 | raw_fword = self.textmodel_ln_final(raw_fword) 409 | 410 | if not self.tunelang: 411 | raw_fword = raw_fword.detach() 412 | 413 | eos_token = raw_fword[torch.arange(raw_fword.shape[0]), word_id.argmax(dim=-1).reshape(-1), :] 414 | 415 | F_g = self.fusion(x12, eos_token) 416 | F_tf = self.fpn([F_g, x9, x6]) 417 | 418 | # Main body 419 | b, c, h, w = F_tf.shape 420 | 421 | flatten_length = h*w 422 | visu_feat = F_tf.reshape(b, c, flatten_length) 423 | visu_feat = F.relu(visu_feat) 424 | lang_feat = F.relu(self.fc_2(raw_fword)) 425 | 426 | visu_feat = visu_feat.permute(0, 2, 1) 427 | pos_embed = self.pos_embedding(visu_feat) 428 | visu_feat = visu_feat.transpose(0, 1) 429 | pos_embed = pos_embed.transpose(0, 1) 430 | visu_feat = self.encoder(visu_feat, pos=pos_embed) 431 | #[HW B C] 432 | 433 | visu_feat_ = visu_feat.permute(1,0,2) 
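        # Descriptive note on the two branches below: mask_decoder upsamples the
        # encoded visual features (seg_out_stride=2, i.e. 416/2 = 208) into 32
        # prototype masks, later flattened to [B, 208*208, 32]; the
        # language-conditioned branch produces a tanh-squashed coefficient vector
        # of shape [B, 32, 1]. The final prediction is their linear combination
        # via torch.bmm, reshaped to [B, 1, 208, 208].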
434 | 435 | # mask decoder 436 | visu_feat = visu_feat.reshape(h, w, b, c) 437 | visu_feat = visu_feat.permute(2,3,0,1) 438 | proto_masks = self.mask_decoder(visu_feat, 2) 439 | 440 | #[B C H W] 441 | proto_masks = F.relu(proto_masks) 442 | 443 | # coef 444 | coef = self.lang_tf_enc(visu_feat_, lang_feat) 445 | coef = coef.view(b, h, w, c) 446 | coef = coef.permute(0, 3, 1, 2) 447 | 448 | coef = self.proj1(coef) 449 | coef = self.proj2(coef) 450 | coef = self.proj3(coef) 451 | coef = coef.permute(0, 2, 3, 1) 452 | coef = coef.contiguous().view(b, h*w*32) 453 | # [b, 1, 32] 454 | coef = self.projout(coef).unsqueeze(-1) 455 | coef = F.tanh(coef) 456 | 457 | # mask assemble 458 | proto_masks = proto_masks.permute(0, 2, 3, 1) 459 | proto_masks = proto_masks.view(b, -1, 32) 460 | #[B HW N] [32 208*208 32] 461 | 462 | mask_out = torch.bmm(proto_masks, coef, out=None) 463 | mask_out = mask_out.view(b, 208, 208, 1) 464 | mask_out = mask_out.permute(0, 3, 1, 2) 465 | return mask_out 466 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .transformer import lang_tf_enc, TransformerEncoderLayer, TransformerEncoder 6 | from .position_encoding import PositionEmbeddingSine 7 | 8 | class SFA(nn.Module): 9 | def __init__(self, in_channels, out_channels, scale_factors = [1, 2, 4], fuse_type="sum"): 10 | super(SFA, self).__init__() 11 | self.stages = [] 12 | for idx, scale in enumerate(scale_factors): 13 | out_dim = out_channels 14 | if scale == 4.0: 15 | layers = [ 16 | nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2), 17 | nn.BatchNorm2d( 18 | num_features=in_channels // 2, eps=1e-5, momentum=0.999, affine=True), 19 | nn.GELU(), 20 | nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2), 21 | ] 22 | out_dim = in_channels // 4 23 | elif scale == 2.0: 24 | layers = [nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)] 25 | out_dim = in_channels // 2 26 | elif scale == 1.0: 27 | layers = [] 28 | out_dim = in_channels 29 | elif scale == 0.5: 30 | layers = [nn.MaxPool2d(kernel_size=2, stride=2)] 31 | else: 32 | raise NotImplementedError(f"scale_factor={scale} is not supported yet.") 33 | 34 | layers.extend( 35 | [ 36 | ConvBatchNormReLU(out_dim, out_channels, 1, 1, 0, 1, leaky=True), 37 | ConvBatchNormReLU(out_channels, out_channels, 3, 1, 1, 1, leaky=True), 38 | ] 39 | ) 40 | layers = nn.Sequential(*layers) 41 | self.stages.append(layers) 42 | 43 | self.stages = nn.ModuleList(self.stages) 44 | 45 | # 假设所有输入特征图的通道数相同 46 | self.lateral_convs = nn.ModuleList([ 47 | ConvBatchNormReLU(out_channels, out_channels, 1, 1, 0, 1, leaky=True) for _ in range(3) 48 | ]) 49 | 50 | self.output_convs = nn.ModuleList([ 51 | ConvBatchNormReLU(out_channels, out_channels, 3, 1, 1, 1, leaky=True) for _ in range(3) 52 | ]) 53 | 54 | self._fuse_type = fuse_type # or "avg" 55 | 56 | self.downsample = nn.MaxPool2d(kernel_size=4, stride=4, padding=0) 57 | 58 | def forward(self, x): 59 | ''' 60 | Args: 61 | x: list[Tensor], T个特征图,每个特征图的尺寸和通道数相同,[x12, x9, x6] 62 | ''' 63 | # 模拟bottom-up, 获取多尺度特征图 64 | mutil_scale_features = [] 65 | for idx, stage in enumerate(self.stages): 66 | mutil_scale_features.append(stage(x[idx])) 67 | 68 | # top-down 69 | results = [] 70 | prev_features = self.lateral_convs[0](mutil_scale_features[0]) 71 | 72 | for idx, 
(lateral_conv, output_conv) in enumerate( 73 | zip(self.lateral_convs, self.output_convs) 74 | ): 75 | # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336 76 | # Therefore we loop over all modules but skip the first one 77 | if idx > 0: 78 | features = mutil_scale_features[idx] 79 | top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest") 80 | lateral_features = lateral_conv(features) # 1x1卷积 81 | prev_features = lateral_features + top_down_features 82 | if self._fuse_type == "avg": 83 | prev_features /= 2 84 | results.insert(0, output_conv(prev_features)) 85 | 86 | fused_features = self.downsample(results[0]) # 1/4分辨率,需要转换为1/16分辨率 87 | 88 | return fused_features 89 | 90 | class ConvBatchNormReLU(nn.Module): 91 | def __init__( 92 | self, 93 | in_channels, 94 | out_channels, 95 | kernel_size, 96 | stride, 97 | padding, 98 | dilation, 99 | leaky=False, 100 | relu=True, 101 | instance=False, 102 | ): 103 | super(ConvBatchNormReLU, self).__init__() 104 | self.conv = nn.Conv2d( 105 | in_channels=in_channels, 106 | out_channels=out_channels, 107 | kernel_size=kernel_size, 108 | stride=stride, 109 | padding=padding, 110 | dilation=dilation, 111 | bias=False) 112 | # nn.init.kaiming_normal_(self.conv.weight, mode="fan_out", nonlinearity="leaky_relu" if leaky else "relu") 113 | 114 | if instance: 115 | self.bn = nn.InstanceNorm2d(num_features=out_channels) 116 | else: 117 | self.bn = nn.BatchNorm2d( 118 | num_features=out_channels, eps=1e-5, momentum=0.999, affine=True 119 | ) 120 | 121 | if leaky: 122 | self.relu = nn.LeakyReLU(0.1) 123 | elif relu: 124 | self.relu = nn.ReLU() 125 | def forward(self, x): 126 | x = self.conv(x) 127 | x = self.bn(x) 128 | x = self.relu(x) 129 | return x 130 | 131 | # class ConvBatchNormReLU(nn.Sequential): 132 | # def __init__( 133 | # self, 134 | # in_channels, 135 | # out_channels, 136 | # kernel_size, 137 | # stride, 138 | # padding, 139 | # dilation, 140 | # leaky=False, 141 | # relu=True, 142 | # instance=False, 143 | # ): 144 | # super(ConvBatchNormReLU, self).__init__() 145 | 146 | # conv = nn.Conv2d( 147 | # in_channels=in_channels, 148 | # out_channels=out_channels, 149 | # kernel_size=kernel_size, 150 | # stride=stride, 151 | # padding=padding, 152 | # dilation=dilation, 153 | # bias=False, 154 | # ) 155 | # nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="leaky_relu" if leaky else "relu") 156 | 157 | # self.add_module( 158 | # "conv", conv 159 | # ) 160 | 161 | # if instance: 162 | # self.add_module( 163 | # "bn", 164 | # nn.InstanceNorm2d(num_features=out_channels), 165 | # ) 166 | # else: 167 | # self.add_module( 168 | # "bn", 169 | # nn.BatchNorm2d( 170 | # num_features=out_channels, eps=1e-5, momentum=0.999, affine=True 171 | # ), 172 | # ) 173 | 174 | # if leaky: 175 | # self.add_module("relu", nn.LeakyReLU(0.1)) 176 | # elif relu: 177 | # self.add_module("relu", nn.ReLU()) 178 | 179 | # def forward(self, x): 180 | # return super(ConvBatchNormReLU, self).forward(x) 181 | 182 | 183 | def concat_coord(x): 184 | ins_feat = x # [bt, c, h, w] [512, 26, 26] 185 | batch_size, c, h, w = x.size() 186 | 187 | float_h = float(h) 188 | float_w = float(w) 189 | 190 | y_range = torch.arange(0., float_h, dtype=torch.float32) 191 | y_range = 2.0 * y_range / (float_h - 1.0) - 1.0 192 | x_range = torch.arange(0., float_w, dtype=torch.float32) 193 | x_range = 2.0 * x_range / (float_w - 1.0) - 1.0 194 | x_range = x_range[None, :] 195 | y_range = y_range[:, None] 196 | x = x_range.repeat(h, 
1) 197 | y = y_range.repeat(1, w) 198 | 199 | x = x[None, None, :, :] 200 | y = y[None, None, :, :] 201 | x = x.repeat(batch_size, 1, 1, 1) 202 | y = y.repeat(batch_size, 1, 1, 1) 203 | x = x.cuda() 204 | y = y.cuda() 205 | 206 | ins_feat_out = torch.cat((ins_feat, x, x, x, y, y, y), 1) 207 | 208 | return ins_feat_out 209 | 210 | 211 | class query_generator(nn.Module): 212 | def __init__(self, input, output, leaky=True): 213 | super(query_generator, self).__init__() 214 | self.proj1 = ConvBatchNormReLU(input+6, input+6, 3, 1, 1, 1, leaky=leaky) 215 | self.proj2 = ConvBatchNormReLU(input+6, input+6, 3, 1, 1, 1, leaky=leaky) 216 | self.proj3 = ConvBatchNormReLU(input+6, input+6, 3, 1, 1, 1, leaky=leaky) 217 | self.proj = nn.Conv2d(input+6, output, 1, 1, 0, 1) 218 | 219 | def forward(self, x): 220 | x = concat_coord(x) 221 | x = x + self.proj1(x) 222 | x = x + self.proj2(x) 223 | x = x + self.proj3(x) 224 | x = self.proj(x) 225 | return x 226 | 227 | 228 | class KLM(nn.Module): 229 | def __init__(self, f_dim, feat_dim): 230 | super(KLM, self).__init__() 231 | self.lang_tf_enc = lang_tf_enc(f_dim, f_dim, f_dim, head_num=8) 232 | 233 | self.pos_embedding = PositionEmbeddingSine(f_dim) 234 | encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim, 235 | dropout=0.1, activation='relu', normalize_before=False) 236 | self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim)) 237 | 238 | # self.catproj = nn.Linear(f_dim * 2, f_dim) 239 | 240 | self.fc_ker = nn.Linear(f_dim, feat_dim + feat_dim) 241 | self.fc_vis = nn.Linear(f_dim, feat_dim + feat_dim) 242 | self.ker_norm = nn.LayerNorm(feat_dim) 243 | self.vis_norm = nn.LayerNorm(feat_dim) 244 | 245 | self.channel_fc = nn.Linear(feat_dim, feat_dim) 246 | self.channel_norm = nn.LayerNorm(feat_dim) 247 | 248 | self.spatial_fc = nn.Linear(feat_dim, feat_dim) 249 | self.spatial_norm = nn.LayerNorm(feat_dim) 250 | 251 | self.out_fc = nn.Linear(feat_dim, f_dim) 252 | self.out_norm = nn.LayerNorm(f_dim) 253 | 254 | self.d_model = f_dim 255 | self.feat_dim = feat_dim 256 | self.resolution_size = 26 257 | 258 | def forward(self, kernel, lang_feat, visu_feat): 259 | # kernel B x N x C 260 | # lang_feat B x T x C 261 | # visu_feat B x C x HW 262 | kernel = self.lang_tf_enc(kernel, lang_feat) 263 | # B x N x C 264 | bs, c, hw = visu_feat.shape 265 | bq, nq, cq = kernel.shape 266 | bl, ll, cl = lang_feat.shape 267 | 268 | # Image Attention 269 | visu_feat = visu_feat.permute(0, 2, 1) 270 | # B x HW x C 271 | pos_embed = self.pos_embedding(visu_feat) 272 | # B x HW x C 273 | 274 | visu_feat = visu_feat.transpose(0, 1) 275 | pos_embed = pos_embed.transpose(0, 1) 276 | visu_feat_ = self.encoder(visu_feat, pos=pos_embed) # HW x B x C 277 | visu_feat_ = visu_feat_.transpose(0, 1) # B x HW x C 278 | 279 | # repeat visual feats 280 | visu_feat = visu_feat_.unsqueeze(dim=1) # B x 1 x HW x C 281 | kernel = kernel.unsqueeze(dim=2) # B x N x 1 x C 282 | lang_feat = lang_feat.unsqueeze(dim=2) # B x Q x 1 x C 283 | 284 | kernel_in = self.fc_ker(kernel) 285 | kernel_out = kernel_in[:, :, :, self.feat_dim:] 286 | kernel_in = kernel_in[:, :, :, :self.feat_dim] 287 | 288 | vis_in = self.fc_vis(visu_feat) 289 | vis_out = vis_in[:, :, :, self.feat_dim:] 290 | vis_in = vis_in[:, :, :, :self.feat_dim] 291 | 292 | gate_feat = self.ker_norm(kernel_in) * self.vis_norm(vis_in) 293 | #[B N HW 64] 294 | 295 | channel_gate = self.channel_norm(self.channel_fc(gate_feat)) 296 | channel_gate = channel_gate.mean(2, keepdim=True) 297 | 
channel_gate = torch.sigmoid(channel_gate) 298 | # B x N x 1 x C 299 | 300 | spatial_gate = self.spatial_norm(self.spatial_fc(gate_feat)) 301 | # spatial_gate = spatial_gate.mean(3, keepdim=True) 302 | spatial_gate = torch.sigmoid(spatial_gate) 303 | # B x N x HW x C 304 | 305 | channel_gate = (1 + channel_gate) * kernel_out # B x N x 1 x C 306 | channel_gate = channel_gate.squeeze(2) # B x N x C 307 | 308 | spatial_gate = (1 + spatial_gate) * vis_out # B x N x HW x C 309 | spatial_gate = spatial_gate.mean(2) # B x N x C 310 | 311 | gate_feat = (channel_gate + spatial_gate) / 2 312 | # [B N 64] 313 | gate_feat = self.out_fc(gate_feat) 314 | gate_feat = self.out_norm(gate_feat) 315 | gate_feat = F.relu(gate_feat) 316 | #[B N C] 317 | 318 | #visu_feat_.transpose(1, 2) [B C HW] 319 | return gate_feat, visu_feat_.transpose(1, 2) 320 | 321 | 322 | class KAM(nn.Module): 323 | def __init__(self, f_dim, num_query): 324 | super(KAM, self).__init__() 325 | 326 | self.k_size = 1 327 | 328 | self.proj = nn.Linear(26*26, f_dim) 329 | 330 | self.fc_k = nn.Linear(f_dim, f_dim) 331 | self.fc_m = nn.Linear(f_dim, f_dim) 332 | self.fc_fus = nn.Linear(f_dim * 2, f_dim) 333 | self.fc_out = nn.Linear(f_dim, 1) 334 | 335 | self.outproj = ConvBatchNormReLU(num_query, f_dim, 3, 1, 1, 1, leaky=True) 336 | self.maskproj = nn.Conv2d(f_dim, 1, 3, 1, 1, 1) 337 | 338 | self.bn = nn.BatchNorm2d(f_dim) 339 | 340 | self.mask_fcs = [] 341 | for _ in range(3): 342 | self.mask_fcs.append(nn.Linear(f_dim, f_dim, bias=False)) 343 | self.mask_fcs.append(nn.LayerNorm(f_dim)) 344 | self.mask_fcs.append(nn.ReLU()) 345 | self.mask_fcs = nn.Sequential(*self.mask_fcs) 346 | 347 | 348 | def forward(self, kernel, visu_feat): 349 | # kernel [B N C] 350 | # visu_feat [B C HW] 351 | kernel = self.mask_fcs(kernel) 352 | 353 | B, N, C = kernel.shape 354 | kernel_ = kernel 355 | kernel = kernel.reshape(B, N, -1, C).permute(0, 1, 3, 2) # B x N x C x 1 356 | kernel = kernel.reshape(B, N, C, self.k_size, self.k_size) # B x N x C x 1 x 1 357 | #[B N C K K] 358 | visu_feat_ = visu_feat 359 | visu_feat = visu_feat.reshape(B, C, 26, 26) # B x C x H x W 360 | 361 | masks = [] 362 | for i in range(B): 363 | masks.append(F.conv2d(visu_feat[i: i+1], kernel[i], padding=int(self.k_size // 2))) # 1 x N x H x W 364 | masks = torch.cat(masks, dim=0) # B x N x H x W 365 | 366 | feats = masks.reshape(B, N, -1) # B x N x HW 367 | feats = self.proj(feats) # B x N x C 368 | 369 | weights_kern = F.relu(self.fc_k(kernel_)) 370 | weights_mask = F.relu(self.fc_m(feats)) 371 | 372 | weights = torch.cat([weights_kern, weights_mask], dim=-1) # B x N x 2C 373 | weights = F.relu(self.fc_fus(weights)) # B x N x C 374 | weights = self.fc_out(weights) # B x N x 1 375 | weights = F.softmax(weights, dim=1) # B x N x 1 376 | 377 | weights = weights.unsqueeze(-1) # B x N x 1 x 1 378 | 379 | mask = weights * masks # B x N x H x W 380 | mask = self.outproj(mask) # B x C x H x W 381 | mask = self.maskproj(mask) 382 | mask = F.sigmoid(mask) # B x 1 x H x W 383 | 384 | visu_feat = visu_feat * mask # B x C x H x W 385 | 386 | visu_feat = self.bn(visu_feat) 387 | visu_feat = visu_feat.reshape(B, C, -1) + visu_feat_ 388 | visu_feat = F.relu(visu_feat) 389 | return visu_feat 390 | 391 | -------------------------------------------------------------------------------- /model/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | 5 | Various positional encodings for the transformer. 6 | 7 | """ 8 | 9 | import math 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class PositionEmbeddingSine(nn.Module): 15 | """ 16 | This is a more standard version of the position embedding, very similar to the one 17 | used by the Attention is all you need paper, generalized to work on images. 18 | """ 19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 20 | super().__init__() 21 | self.num_pos_feats = num_pos_feats // 2 22 | self.temperature = temperature 23 | self.normalize = normalize 24 | if scale is not None and normalize is False: 25 | raise ValueError("normalize should be True if scale is passed") 26 | if scale is None: 27 | scale = 2 * math.pi 28 | self.scale = scale 29 | 30 | def forward(self, f_s): 31 | not_mask = torch.ones_like(f_s[:, :, 0].reshape(-1, 26, 26)) 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=f_s.device) 40 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2,rounding_mode = 'floor')/ self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).reshape_as(f_s) 47 | return pos 48 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 
4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | """ 9 | import copy 10 | from typing import Optional 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | from .position_encoding import * 15 | 16 | 17 | class lang_tf_enc(nn.Module): 18 | 19 | def __init__(self, input_1, input_2, hidden_dim, head_num, dropout=0.1): 20 | super(lang_tf_enc, self).__init__() 21 | self.pos_embedding_1 = PositionEmbeddingSine(input_2, normalize=True) 22 | self.pos_embedding_2 = PositionEmbeddingSine(input_1, normalize=True) 23 | self.dense_q = nn.Linear(input_1, hidden_dim) 24 | self.dense_k = nn.Linear(input_2, hidden_dim) 25 | self.dense_v = nn.Linear(input_2, hidden_dim) 26 | self.self_attn = nn.MultiheadAttention(hidden_dim, head_num, dropout=dropout) 27 | 28 | self.forward_dim = 2048 29 | self.norm1 = nn.LayerNorm(hidden_dim) 30 | self.norm2 = nn.LayerNorm(hidden_dim) 31 | self.linear1 = nn.Linear(hidden_dim, self.forward_dim) 32 | self.linear2 = nn.Linear(self.forward_dim, hidden_dim) 33 | self.activation = _get_activation("relu") 34 | self.dropout = nn.Dropout(dropout) 35 | 36 | # @get_local("weights") 37 | def forward(self, vision_input, lang_input): 38 | decoder_embed_lang = lang_input 39 | decoder_embed_vis = vision_input 40 | q_inp = F.relu(self.dense_q(decoder_embed_vis).permute(1, 0, 2)) 41 | k_inp = F.relu(self.dense_k(decoder_embed_lang).permute(1, 0, 2)) 42 | v_inp = F.relu(self.dense_v(decoder_embed_lang).permute(1, 0, 2)) 43 | lang_input = lang_input.permute(1, 0, 2) 44 | decoded_layer, weights = self.self_attn(q_inp, k_inp, v_inp) 45 | 46 | decoded_layer = decoded_layer.permute(1, 0, 2) 47 | add_layer = decoded_layer + vision_input 48 | 49 | add_layer = self.norm1(add_layer) 50 | add_layer2 = self.linear2(self.dropout(self.activation(self.linear1(add_layer)))) 51 | add_layer = add_layer + self.dropout(add_layer2) 52 | add_layer = self.norm2(add_layer) 53 | 54 | return add_layer 55 | 56 | 57 | def _get_clones(module, N): 58 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 59 | 60 | def _get_activation(activation): 61 | 62 | if activation == "relu": 63 | return F.relu 64 | if activation == "gelu": 65 | return F.gelu 66 | if activation == "glu": 67 | return F.glu 68 | raise RuntimeError(F"activation shuld be relu/gelu, not {activation}.") 69 | 70 | 71 | class TransformerEncoder(nn.Module): 72 | 73 | def __init__(self, encoder_layer, num_layers, norm=None): 74 | super().__init__() 75 | self.layers = _get_clones(encoder_layer, num_layers) 76 | self.num_layers = num_layers 77 | self.norm = norm 78 | 79 | def forward(self, src, pos: Optional[Tensor] = None): 80 | output = src 81 | 82 | for layer in self.layers: 83 | output = layer(output, pos=pos) 84 | 85 | if self.norm is not None: 86 | output = self.norm(output) 87 | 88 | return output 89 | 90 | 91 | class TransformerDecoder(nn.Module): 92 | 93 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 94 | super().__init__() 95 | self.layers = _get_clones(decoder_layer, num_layers) 96 | self.num_layers = num_layers 97 | self.norm = norm 98 | self.return_intermediate = return_intermediate 99 | 100 | def forward(self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): 101 | output = tgt 102 | 103 | intermediate = [] 104 | 105 | for layer in self.layers: 
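            # Each decoder layer refines the queries with self-attention and
            # cross-attention into the encoder memory; when return_intermediate
            # is set, the normalized output of every layer is collected and
            # stacked at the end.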
106 | output = layer(output, memory, pos=pos, query_pos=query_pos) 107 | if self.return_intermediate: 108 | intermediate.append(self.norm(output)) 109 | 110 | 111 | if self.norm is not None: 112 | output = self.norm(output) 113 | if self.return_intermediate: 114 | intermediate.pop() 115 | intermediate.append(output) 116 | 117 | 118 | if self.return_intermediate: 119 | return torch.stack(intermediate) 120 | 121 | return output 122 | 123 | 124 | class TransformerEncoderLayer(nn.Module): 125 | 126 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 127 | activation="relu", normalize_before=False): 128 | super().__init__() 129 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 130 | # Implementation of Feedforward model 131 | self.linear1 = nn.Linear(d_model, dim_feedforward) 132 | self.dropout = nn.Dropout(dropout) 133 | self.linear2 = nn.Linear(dim_feedforward, d_model) 134 | 135 | self.norm1 = nn.LayerNorm(d_model) 136 | self.norm2 = nn.LayerNorm(d_model) 137 | self.dropout1 = nn.Dropout(dropout) 138 | self.dropout2 = nn.Dropout(dropout) 139 | 140 | self.activation = _get_activation_fn(activation) 141 | self.normalize_before = normalize_before 142 | 143 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 144 | return tensor if pos is None else tensor + pos 145 | 146 | # @get_local("weights") 147 | def forward_post(self, src, pos: Optional[Tensor] = None): 148 | q = k = self.with_pos_embed(src, pos) 149 | src2, weights = self.self_attn(q, k, value=src, need_weights=False) 150 | 151 | src = src + self.dropout1(src2) 152 | src = self.norm1(src) 153 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 154 | src = src + self.dropout2(src2) 155 | src = self.norm2(src) 156 | return src 157 | 158 | def forward_pre(self, src, pos: Optional[Tensor] = None): 159 | src2 = self.norm1(src) 160 | q = k = self.with_pos_embed(src2, pos) 161 | src2, weights = self.self_attn(q, k, value=src2) 162 | src = src + self.dropout1(src2) 163 | src2 = self.norm2(src) 164 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 165 | src = src + self.dropout2(src2) 166 | return src 167 | 168 | def forward(self, src, pos: Optional[Tensor] = None): 169 | if self.normalize_before: 170 | return self.forward_pre(src, pos) 171 | return self.forward_post(src, pos) 172 | 173 | 174 | class TransformerDecoderLayer(nn.Module): 175 | 176 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 177 | activation="relu", normalize_before=False): 178 | super().__init__() 179 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 180 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 181 | # Implementation of Feedforward model 182 | self.linear1 = nn.Linear(d_model, dim_feedforward) 183 | self.dropout = nn.Dropout(dropout) 184 | self.linear2 = nn.Linear(dim_feedforward, d_model) 185 | 186 | self.norm1 = nn.LayerNorm(d_model) 187 | self.norm2 = nn.LayerNorm(d_model) 188 | self.norm3 = nn.LayerNorm(d_model) 189 | self.dropout1 = nn.Dropout(dropout) 190 | self.dropout2 = nn.Dropout(dropout) 191 | self.dropout3 = nn.Dropout(dropout) 192 | 193 | self.activation = _get_activation_fn(activation) 194 | self.normalize_before = normalize_before 195 | 196 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 197 | return tensor if pos is None else tensor + pos 198 | 199 | def forward_post(self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): 200 | q = k = 
self.with_pos_embed(tgt, query_pos) 201 | tgt2, weights = self.self_attn(q, k, value=tgt) 202 | tgt = tgt + self.dropout1(tgt2) 203 | tgt = self.norm1(tgt) 204 | tgt2, weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 205 | key=self.with_pos_embed(memory, pos), 206 | value=memory) 207 | tgt = tgt + self.dropout2(tgt2) 208 | tgt = self.norm2(tgt) 209 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 210 | tgt = tgt + self.dropout3(tgt2) 211 | tgt = self.norm3(tgt) 212 | return tgt 213 | 214 | def forward_pre(self, tgt, memory, pos: Optional[Tensor] = None, 215 | query_pos: Optional[Tensor] = None): 216 | tgt2 = self.norm1(tgt) 217 | q = k = self.with_pos_embed(tgt2, query_pos) 218 | tgt2, weights = self.self_attn(q, k, value=tgt2) 219 | tgt = tgt + self.dropout1(tgt2) 220 | tgt2 = self.norm2(tgt) 221 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 222 | key=self.with_pos_embed(memory, pos), 223 | value=memory) 224 | tgt = tgt + self.dropout2(tgt2) 225 | tgt2 = self.norm3(tgt) 226 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 227 | tgt = tgt + self.dropout3(tgt2) 228 | return tgt 229 | 230 | def forward(self, tgt, memory, pos: Optional[Tensor] = None, 231 | query_pos: Optional[Tensor] = None): 232 | if self.normalize_before: 233 | return self.forward_pre(tgt, memory, pos, query_pos) 234 | return self.forward_post(tgt, memory, pos, query_pos) 235 | 236 | 237 | def _get_clones(module, N): 238 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 239 | 240 | 241 | def _get_activation_fn(activation): 242 | 243 | """Return an activation function given a string""" 244 | if activation == "relu": 245 | return F.relu 246 | if activation == "gelu": 247 | return F.gelu 248 | if activation == "glu": 249 | return F.glu 250 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 251 | 252 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image 2 | pycocotools 3 | numpy 4 | scipy 5 | matplotlib 6 | opencv-python # 4.3.0.38 7 | tqdm 8 | pytorch_pretrained_bert 9 | tensorboardX 10 | termcolor 11 | git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import random 5 | import datetime 6 | import matplotlib as mpl 7 | mpl.use('Agg') 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.utils.data.distributed 16 | from torch.utils.data import DataLoader 17 | from torchvision.transforms import Compose, ToTensor, Normalize 18 | 19 | import torch.distributed as dist 20 | import torch.multiprocessing as mp 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | import torch.utils.data.distributed 23 | 24 | #import apex.amp as amp 25 | from torch.cuda.amp import autocast as autocast 26 | 27 | from model.model import * 28 | from engine.engine import * 29 | 30 | from dataset.data_loader import * 31 | from utils.losses import * 32 | from utils.parsing_metrics import * 33 | from utils.utils import * 34 | from utils.checkpoint import load_pretrain, load_resume 35 | from utils.logger import 
setup_logger 36 | 37 | def get_args(): 38 | parser = argparse.ArgumentParser(description='Dataloader test') 39 | parser.add_argument('--gpu', default='2', help='gpu id') 40 | parser.add_argument('--ngpu', default=2, type=int, help='gpu num') 41 | parser.add_argument('--workers', default=4, type=int, help='num workers for data loading') 42 | parser.add_argument('--seed', default=0, type=int, help='random seed') 43 | 44 | parser.add_argument('--clip_model', default='ViT-B/16', type=str, help='clip model RN50 RN101 ViT-B/32') 45 | parser.add_argument('--nb_epoch', default=32, type=int, help='training epoch') 46 | parser.add_argument('--lr', default=0.000025, type=float, help='batch size 16 learning rate') 47 | parser.add_argument('--power', default=0.1, type=float, help='lr poly power') 48 | parser.add_argument('--steps', default=[15, 28], type=list, help='in which step lr decay by power') 49 | parser.add_argument('--batch_size', default=16, type=int, help='batch size') 50 | parser.add_argument('--size', default=416, type=int, help='image size') 51 | parser.add_argument('--dataset', default='refcoco', type=str, 52 | help='refcoco/refcoco+/refcocog/grefcoco') 53 | 54 | parser.add_argument('--num_query', default=16, type=int, help='the number of query') 55 | parser.add_argument('--w_seg', default=0.1, type=float, help='weight of the seg loss') 56 | parser.add_argument('--w_coord', default=5, type=float, help='weight of the reg loss') 57 | parser.add_argument('--tunelang', dest='tunelang', default=True, action='store_true', help='if finetune language model') 58 | parser.add_argument('--anchor_imsize', default=416, type=int, 59 | help='scale used to calculate anchors defined in model cfg file') 60 | parser.add_argument('--data_root', type=str, default='./ln_data', 61 | help='path to ReferIt splits data folder') 62 | parser.add_argument('--split_root', type=str, default='./data', 63 | help='location of pre-parsed dataset info') 64 | parser.add_argument('--time', default=15, type=int, 65 | help='maximum time steps (lang length) per batch') 66 | parser.add_argument('--log_dir', type=str, default='./logs', 67 | help='path to ReferIt splits data folder') 68 | 69 | parser.add_argument('--fusion_dim', default=768, type=int, 70 | help='fusion module embedding dimensions') 71 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 72 | help='path to latest checkpoint (default: none)') 73 | parser.add_argument('--pretrain', default='', type=str, metavar='PATH', 74 | help='pretrain support load state_dict that are not identical, while have no loss saved as resume') 75 | parser.add_argument('--print_freq', '-p', default=100, type=int, 76 | metavar='N', help='print frequency (default: 1e3)') 77 | parser.add_argument('--savename', default='default', type=str, help='Name head for saved model') 78 | 79 | parser.add_argument('--seg_thresh', default=0.35, type=float, help='seg score above this value means foreground') 80 | parser.add_argument('--seg_out_stride', default=2, type=int, help='the seg out stride') 81 | parser.add_argument('--best_iou', default=-float('Inf'), type=int, help='the best accu') 82 | 83 | global args, anchors_full, writer, logger 84 | args = parser.parse_args() 85 | args.gsize = 32 86 | args.date = datetime.datetime.now().strftime('%Y%m%d') 87 | if args.savename=='default': 88 | args.savename = 'model_v1_%s_batch%d_%s'%(args.dataset, args.batch_size, args.date) 89 | os.makedirs(args.log_dir, exist_ok=True) 90 | args.lr = args.lr * (args.batch_size * args.ngpu // 16) 91 | 92 | 
print('----------------------------------------------------------------------') 93 | print(sys.argv[0]) 94 | print(args) 95 | print('----------------------------------------------------------------------') 96 | 97 | return args 98 | 99 | def main(args): 100 | os.environ['MASTER_ADDR'] = 'localhost' 101 | os.environ['MASTER_PORT'] = '12367' 102 | 103 | if(torch.cuda.is_available()): 104 | n_gpus = torch.cuda.device_count() 105 | print("Running DDP with {} GPUs".format(n_gpus)) 106 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, args,)) 107 | else: 108 | print("Please use GPU for training") 109 | 110 | def run(rank, n_gpus, args): 111 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 112 | torch.cuda.set_device(rank) 113 | 114 | ## fix seed 115 | cudnn.benchmark = False 116 | cudnn.deterministic = True 117 | random.seed(args.seed) 118 | np.random.seed(args.seed+1) 119 | torch.manual_seed(args.seed+2) 120 | torch.cuda.manual_seed_all(args.seed+3) 121 | 122 | ## save logs 123 | logger = setup_logger(output=os.path.join(args.log_dir, args.savename), distributed_rank=rank, color=False, name="model-v1") 124 | logger.info(str(sys.argv)) 125 | logger.info(str(args)) 126 | 127 | input_transform = Compose([ 128 | ToTensor(), 129 | Normalize( 130 | mean=[0.48145466, 0.4578275, 0.40821073], 131 | std=[0.26862954, 0.26130258, 0.27577711] 132 | ) 133 | ]) 134 | 135 | 136 | val_dataset = ReferDataset(data_root=args.data_root, 137 | dataset=args.dataset, 138 | split_root=args.split_root, 139 | split='val', 140 | imsize = args.size, 141 | transform=input_transform, 142 | max_query_len=args.time) 143 | 144 | 145 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, 146 | pin_memory=True, drop_last=True, num_workers=args.workers) 147 | 148 | if args.dataset == 'refcocog_u': 149 | test_dataset = ReferDataset(data_root=args.data_root, 150 | dataset=args.dataset, 151 | split_root=args.split_root, 152 | split='test', 153 | imsize = args.size, 154 | transform=input_transform, 155 | max_query_len=args.time) 156 | 157 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, 158 | pin_memory=True, drop_last=True, num_workers=args.workers) 159 | elif args.dataset == 'refcocog_g': 160 | pass 161 | else: 162 | testA_dataset = ReferDataset(data_root=args.data_root, 163 | dataset=args.dataset, 164 | split_root=args.split_root, 165 | split='testA', 166 | imsize = args.size, 167 | transform=input_transform, 168 | max_query_len=args.time) 169 | testB_dataset = ReferDataset(data_root=args.data_root, 170 | dataset=args.dataset, 171 | split_root=args.split_root, 172 | split='testB', 173 | imsize = args.size, 174 | transform=input_transform, 175 | max_query_len=args.time) 176 | 177 | 178 | testA_loader = DataLoader(testA_dataset, batch_size=1, shuffle=False, 179 | pin_memory=True, drop_last=True, num_workers=args.workers) 180 | testB_loader = DataLoader(testB_dataset, batch_size=1, shuffle=False, 181 | pin_memory=True, drop_last=True, num_workers=args.workers) 182 | 183 | 184 | ## Model 185 | model = Model(clip_model=args.clip_model, tunelang=args.tunelang, num_query=args.num_query, fusion_dim=args.fusion_dim).cuda(rank) 186 | model = DDP(model, device_ids=[rank], find_unused_parameters=True) 187 | model_without_ddp = model.module 188 | 189 | args.start_epoch = 0 190 | if args.pretrain and os.path.isfile(args.pretrain): 191 | model=load_pretrain(model, args, logger, rank) 192 | model.to(rank) 193 | 194 | visu_param = [param for name, param in 
model_without_ddp.named_parameters() if 'visumodel' in name] 195 | text_param = [param for name, param in model_without_ddp.named_parameters() if 'textmodel' in name] 196 | rest_param = [param for name, param in model_without_ddp.named_parameters() if 'textmodel' not in name and 'visumodel' not in name] 197 | 198 | 199 | ## optimizer; adam default 200 | if args.tunelang: 201 | optimizer = torch.optim.Adam([{'params': rest_param, 'lr': args.lr}, 202 | {'params': visu_param, 'lr': args.lr / 10.}, 203 | {'params': text_param, 'lr': args.lr / 10.}]) 204 | else: 205 | optimizer = torch.optim.Adam([{'params': rest_param}, 206 | {'params': visu_param, 'lr': args.lr / 10.}], lr=args.lr) 207 | 208 | 209 | 210 | best_miou_seg = -float('Inf') 211 | if args.resume: 212 | model = load_resume(model, optimizer, args, logger, rank) 213 | model.to(rank) 214 | best_miou_seg = args.best_iou 215 | print(best_miou_seg) 216 | 217 | if args.dataset == 'refcocog_u': 218 | print('\nTest testing:') 219 | miou_seg, prec = validate_epoch(args, test_loader, model, logger, 'test') 220 | 221 | elif args.dataset == 'refcocog_g': 222 | pass 223 | else: 224 | print('\nTestA testing:') 225 | miou_seg, prec = validate_epoch(args, testA_loader, model, logger, 'testA') 226 | print('\nTestB testing:') 227 | miou_seg, prec = validate_epoch(args, testB_loader, model, logger, 'testB') 228 | miou_seg, prec = validate_epoch(args, val_loader, model, logger, 'val') 229 | 230 | if __name__ == "__main__": 231 | args = get_args() 232 | main(args) 233 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | ngpu=$1 2 | CUDA_VISIBLE_DEVICES=${ngpu} python test.py --dataset refcoco --savename savename --resume 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import random 5 | import datetime 6 | import matplotlib as mpl 7 | mpl.use('Agg') 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.utils.data.distributed 16 | from torch.utils.data import DataLoader 17 | from torchvision.transforms import Compose, ToTensor, Normalize 18 | 19 | import torch.distributed as dist 20 | import torch.multiprocessing as mp 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | import torch.utils.data.distributed 23 | 24 | from tensorboardX import SummaryWriter 25 | 26 | #import apex.amp as amp 27 | from torch.cuda.amp import autocast as autocast, GradScaler 28 | 29 | from model.model import * 30 | from engine.engine import * 31 | 32 | from dataset.data_loader import * 33 | from utils.losses import * 34 | from utils.parsing_metrics import * 35 | from utils.utils import * 36 | from utils.checkpoint import save_checkpoint, load_pretrain, load_resume 37 | from utils.logger import setup_logger 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser(description='Dataloader test') 41 | parser.add_argument('--gpu', default='2', help='gpu id') 42 | parser.add_argument('--ngpu', default=2, type=int, help='gpu num') 43 | parser.add_argument('--workers', default=4, type=int, help='num workers for data loading') 44 | parser.add_argument('--seed', default=0, type=int, help='random seed') 45 | 46 | 
parser.add_argument('--clip_model', default='ViT-B/16', type=str, help='clip model RN50 RN101 ViT-B/32') 47 | parser.add_argument('--nb_epoch', default=32, type=int, help='training epoch') 48 | parser.add_argument('--lr', default=0.000025, type=float, help='batch size 16 learning rate') 49 | parser.add_argument('--power', default=0.1, type=float, help='lr poly power') 50 | parser.add_argument('--steps', default=[18, 28], type=list, help='in which step lr decay by power') 51 | parser.add_argument('--batch_size', default=16, type=int, help='batch size') 52 | parser.add_argument('--size', default=416, type=int, help='image size') 53 | parser.add_argument('--dataset', default='grefcoco', type=str, 54 | help='refcoco/refcoco+/refcocog/grefcoco') 55 | 56 | parser.add_argument('--num_query', default=16, type=int, help='the number of query') 57 | parser.add_argument('--w_seg', default=0.1, type=float, help='weight of the seg loss') 58 | parser.add_argument('--w_coord', default=5, type=float, help='weight of the reg loss') 59 | parser.add_argument('--tunelang', dest='tunelang', default=True, action='store_true', help='if finetune language model') 60 | parser.add_argument('--anchor_imsize', default=416, type=int, 61 | help='scale used to calculate anchors defined in model cfg file') 62 | parser.add_argument('--data_root', type=str, default='./ln_data', 63 | help='path to ReferIt splits data folder') 64 | parser.add_argument('--split_root', type=str, default='./data', 65 | help='location of pre-parsed dataset info') 66 | parser.add_argument('--time', default=17, type=int, 67 | help='maximum time steps (lang length) per batch') 68 | parser.add_argument('--log_dir', type=str, default='./logs', 69 | help='path to ReferIt splits data folder') 70 | 71 | parser.add_argument('--fusion_dim', default=768, type=int, 72 | help='fusion module embedding dimensions') 73 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 74 | help='path to latest checkpoint (default: none)') 75 | parser.add_argument('--pretrain', default='', type=str, metavar='PATH', 76 | help='pretrain support load state_dict that are not identical, while have no loss saved as resume') 77 | parser.add_argument('--print_freq', '-p', default=100, type=int, 78 | metavar='N', help='print frequency (default: 1e3)') 79 | parser.add_argument('--savename', default='default', type=str, help='Name head for saved model') 80 | 81 | parser.add_argument('--seg_thresh', default=0.35, type=float, help='seg score above this value means foreground') 82 | parser.add_argument('--seg_out_stride', default=2, type=int, help='the seg out stride') 83 | parser.add_argument('--best_iou', default=-float('Inf'), type=int, help='the best accu') 84 | 85 | global args, anchors_full, writer, logger 86 | args = parser.parse_args() 87 | args.gsize = 32 88 | args.date = datetime.datetime.now().strftime('%Y%m%d') 89 | if args.savename=='default': 90 | args.savename = 'model_v1_%s_batch%d_%s'%(args.dataset, args.batch_size, args.date) 91 | os.makedirs(args.log_dir, exist_ok=True) 92 | args.lr = round(args.lr * (args.batch_size * args.ngpu / 16), 6) 93 | print('----------------------------------------------------------------------') 94 | print(sys.argv[0]) 95 | print(args) 96 | print('----------------------------------------------------------------------') 97 | 98 | return args 99 | 100 | def main(args): 101 | os.environ['MASTER_ADDR'] = 'localhost' 102 | os.environ['MASTER_PORT'] = '12356' 103 | 104 | if(torch.cuda.is_available()): 105 | n_gpus = 
torch.cuda.device_count() 106 | print("Running DDP with {} GPUs".format(n_gpus)) 107 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, args,)) 108 | else: 109 | print("Please use GPU for training") 110 | 111 | def run(rank, n_gpus, args): 112 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 113 | torch.cuda.set_device(rank) 114 | 115 | ## fix seed 116 | cudnn.benchmark = False 117 | cudnn.deterministic = True 118 | random.seed(args.seed) 119 | np.random.seed(args.seed+1) 120 | torch.manual_seed(args.seed+2) 121 | torch.cuda.manual_seed_all(args.seed+3) 122 | 123 | ## save logs 124 | logger = setup_logger(output=os.path.join(args.log_dir, args.savename), distributed_rank=rank, color=False, name="model-v1") 125 | 126 | logger.info(str(sys.argv)) 127 | logger.info(str(args)) 128 | if rank == 0: 129 | writer = SummaryWriter(comment=args.savename) 130 | 131 | input_transform = Compose([ 132 | ToTensor(), 133 | Normalize( 134 | mean=[0.48145466, 0.4578275, 0.40821073], 135 | std=[0.26862954, 0.26130258, 0.27577711] 136 | ) 137 | ]) 138 | 139 | train_dataset = ReferDataset(data_root=args.data_root, 140 | dataset=args.dataset, 141 | split_root=args.split_root, 142 | split='train', 143 | imsize = args.size, 144 | transform=input_transform, 145 | max_query_len=args.time, 146 | augment=True) 147 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=n_gpus, rank=rank, shuffle=True) 148 | 149 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, 150 | pin_memory=True, drop_last=True, num_workers=args.workers, sampler=train_sampler) 151 | 152 | 153 | if rank == 0: 154 | val_dataset = ReferDataset(data_root=args.data_root, 155 | dataset=args.dataset, 156 | split_root=args.split_root, 157 | split='val', 158 | imsize = args.size, 159 | transform=input_transform, 160 | max_query_len=args.time) 161 | 162 | 163 | 164 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, 165 | pin_memory=True, drop_last=True, num_workers=args.workers) 166 | 167 | 168 | ## Model 169 | model = Model(clip_model=args.clip_model, tunelang=args.tunelang, num_query=args.num_query, fusion_dim=args.fusion_dim).cuda(rank) 170 | model = DDP(model, device_ids=[rank], find_unused_parameters=True) 171 | model_without_ddp = model.module 172 | 173 | args.start_epoch = 0 174 | if args.pretrain and os.path.isfile(args.pretrain): 175 | model=load_pretrain(model,args,logger, rank) 176 | model.to(rank) 177 | 178 | visu_param = [param for name, param in model_without_ddp.named_parameters() if 'visumodel' in name] 179 | text_param = [param for name, param in model_without_ddp.named_parameters() if 'textmodel' in name] 180 | rest_param = [param for name, param in model_without_ddp.named_parameters() if 'textmodel' not in name and 'visumodel' not in name] 181 | 182 | sum_visu = sum([param.nelement() for param in visu_param]) 183 | sum_text = sum([param.nelement() for param in text_param]) 184 | sum_fusion = sum([param.nelement() for param in rest_param]) 185 | if rank == 0: 186 | print('Num of parameters:', sum([param.nelement() for param in model_without_ddp.parameters()])) 187 | logger.info('Num of parameters:%d'%int(sum([param.nelement() for param in model_without_ddp.parameters()]))) 188 | print('visu, text, fusion module parameters:', sum_visu, sum_text, sum_fusion) 189 | 190 | ## optimizer; adam default 191 | if args.tunelang: 192 | optimizer = torch.optim.Adam([{'params': rest_param, 'lr': args.lr}, 193 | {'params': 
visu_param, 'lr': args.lr / 10.}, 194 | {'params': text_param, 'lr': args.lr / 10.}]) 195 | else: 196 | optimizer = torch.optim.Adam([{'params': rest_param}, 197 | {'params': visu_param, 'lr': args.lr / 10.}], lr=args.lr) 198 | 199 | # Initialization 200 | scaler = GradScaler() 201 | 202 | best_miou_seg = -float('Inf') 203 | if args.resume: 204 | model = load_resume(model, optimizer, args, logger, rank) 205 | model.to(rank) 206 | best_miou_seg = args.best_iou 207 | print(best_miou_seg) 208 | 209 | 210 | for epoch in range(args.start_epoch, args.nb_epoch): 211 | adjust_learning_rate(args, optimizer, epoch) 212 | loss = train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger) 213 | if rank == 0: 214 | writer.add_scalar('loss', loss, global_step=epoch) 215 | miou_seg = 0 216 | if epoch == 0 or epoch > 10: 217 | miou_seg, prec = validate_epoch(args, val_loader, model, logger, 'Val') 218 | writer.add_scalar('miou_seg', miou_seg, global_step=epoch) 219 | thresholds = np.arange(0.5, 1, 0.05) 220 | for thresh in thresholds: 221 | writer.add_scalar('prec@%f'%thresh, prec[thresh].avg, global_step=epoch) 222 | 223 | ## remember best accu and save checkpoint 224 | is_best = miou_seg > best_miou_seg 225 | best_miou_seg= max(miou_seg, best_miou_seg) 226 | save_checkpoint({ 227 | 'epoch': epoch + 1, 228 | 'state_dict': model.module.state_dict(), 229 | 'best_iou': best_miou_seg, 230 | 'optimizer' : optimizer.state_dict(), 231 | }, is_best, args, filename=args.savename) 232 | print('\nBest Accu: %f\n'%best_miou_seg) 233 | logger.info('\nBest Accu: %f\n'%best_miou_seg) 234 | 235 | if __name__ == "__main__": 236 | args = get_args() 237 | main(args) 238 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=${1} python train.py --dataset refcoco --ngpu 2 --batch_size 14 --time 17 --savename savename -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | 5 | def save_checkpoint(state, is_best, args, filename='default'): 6 | if filename=='default': 7 | filename = 'mcn_%s_batch%d'%(args.dataset,args.samples_per_gpu) 8 | print("=> saving checkpoint '{}'".format(filename)) 9 | if not os.path.exists('./saved_models'): 10 | os.makedirs('./saved_models') 11 | checkpoint_name = './saved_models/%s_checkpoint.pth.tar'%(filename) 12 | best_name = './saved_models/%s_model_best.pth.tar'%(filename) 13 | torch.save(state, checkpoint_name) 14 | if is_best: 15 | print("=> saving best model '{}'".format(best_name)) 16 | shutil.copyfile(checkpoint_name, best_name) 17 | 18 | def load_pretrain(model, args, logging, rank): 19 | if os.path.isfile(args.pretrain): 20 | checkpoint = torch.load(args.pretrain) 21 | pretrained_dict = checkpoint['state_dict'] 22 | if hasattr(model, 'module'): 23 | model_dict = model.module.state_dict() 24 | else: 25 | model_dict = model.state_dict() 26 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 27 | assert (len([k for k, v in pretrained_dict.items()])!=0) 28 | model_dict.update(pretrained_dict) 29 | if hasattr(model, 'module'): 30 | 
model.module.load_state_dict(model_dict) 31 | else: 32 | model.load_state_dict(model_dict) 33 | print("=> loaded pretrain model at {}" 34 | .format(args.pretrain)) 35 | if rank == 0: 36 | logging.info("=> loaded pretrain model at {}" 37 | .format(args.pretrain)) 38 | del checkpoint # dereference seems crucial 39 | torch.cuda.empty_cache() 40 | else: 41 | print(("=> no pretrained file found at '{}'".format(args.pretrain))) 42 | if rank == 0: 43 | logging.info("=> no pretrained file found at '{}'".format(args.pretrain)) 44 | return model 45 | 46 | def load_pretrain_ddp(model, args): 47 | if os.path.isfile(args.pretrain): 48 | checkpoint = torch.load(args.pretrain) 49 | pretrained_dict = checkpoint['state_dict'] 50 | model_dict = model.state_dict() 51 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 52 | assert (len([k for k, v in pretrained_dict.items()])!=0) 53 | model_dict.update(pretrained_dict) 54 | if hasattr(model, 'module'): 55 | state_dict = model.module.state_dict() 56 | model.module.load_state_dict(model_dict) 57 | else: 58 | state_dict = model.state_dict() 59 | model.load_state_dict(model_dict) 60 | print("load ") 61 | print("=> loaded pretrain model at {}" 62 | .format(args.pretrain)) 63 | del checkpoint # dereference seems crucial 64 | torch.cuda.empty_cache() 65 | else: 66 | print(("=> no pretrained file found at '{}'".format(args.pretrain))) 67 | return model 68 | 69 | 70 | def load_resume(model, optimizer, args, logging, rank): 71 | if os.path.isfile(args.resume): 72 | print(("=> loading checkpoint '{}'".format(args.resume))) 73 | if rank == 0: 74 | logging.info("=> loading checkpoint '{}'".format(args.resume)) 75 | checkpoint = torch.load(args.resume, map_location='cpu') 76 | args.start_epoch = checkpoint['epoch'] 77 | print("epoch: ", args.start_epoch) 78 | args.best_iou = checkpoint['best_iou'] 79 | print("best iou: ", args.best_iou) 80 | state_dict = checkpoint['state_dict'] 81 | 82 | if hasattr(model, 'module'): 83 | model_dict = model.module.state_dict() 84 | else: 85 | model_dict = model.state_dict() 86 | new_state_dict = {k:v for k,v in state_dict.items() if k in model_dict} 87 | model_dict.update(new_state_dict) 88 | 89 | 90 | if hasattr(model, 'module'): 91 | model.module.load_state_dict(model_dict) 92 | else: 93 | model.load_state_dict(model_dict) 94 | optimizer.load_state_dict(checkpoint['optimizer']) 95 | del checkpoint # dereference seems crucial 96 | torch.cuda.empty_cache() 97 | print("load successfully!") 98 | else: 99 | print(("=> no checkpoint found at '{}'".format(args.resume))) 100 | if rank == 0: 101 | logging.info(("=> no checkpoint found at '{}'".format(args.resume))) 102 | return model 103 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | class _ColorfulFormatter(logging.Formatter): 9 | def __init__(self, *args, **kwargs): 10 | self._root_name = kwargs.pop("root_name") + "." 11 | self._abbrev_name = kwargs.pop("abbrev_name", "") 12 | if len(self._abbrev_name): 13 | self._abbrev_name = self._abbrev_name + "." 
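        # The trailing dot lets formatMessage() replace the root module prefix in
        # record.name with this shorter abbreviation before colorizing WARNING
        # and ERROR records.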
14 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 15 | 16 | def formatMessage(self, record): 17 | record.name = record.name.replace(self._root_name, self._abbrev_name) 18 | log = super(_ColorfulFormatter, self).formatMessage(record) 19 | if record.levelno == logging.WARNING: 20 | prefix = colored("WARNING", "red", attrs=["blink"]) 21 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 22 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 23 | else: 24 | return log 25 | return prefix + " " + log 26 | 27 | 28 | # so that calling setup_logger multiple times won't add many handlers 29 | @functools.lru_cache() 30 | def setup_logger( 31 | output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None 32 | ): 33 | """ 34 | Initialize the detectron2 logger and set its verbosity level to "INFO". 35 | 36 | Args: 37 | output (str): a file name or a directory to save log. If None, will not save log file. 38 | If ends with ".txt" or ".log", assumed to be a file name. 39 | Otherwise, logs will be saved to `output/log.txt`. 40 | name (str): the root module name of this logger 41 | 42 | Returns: 43 | logging.Logger: a logger 44 | """ 45 | logger = logging.getLogger(name) 46 | logger.setLevel(logging.DEBUG) 47 | logger.propagate = False 48 | 49 | if abbrev_name is None: 50 | abbrev_name = name 51 | 52 | plain_formatter = logging.Formatter( 53 | '[%(asctime)s.%(msecs)03d]: %(message)s', 54 | datefmt='%m/%d %H:%M:%S' 55 | ) 56 | # stdout logging: master only 57 | if distributed_rank == 0: 58 | ch = logging.StreamHandler(stream=sys.stdout) 59 | ch.setLevel(logging.DEBUG) 60 | if color: 61 | formatter = _ColorfulFormatter( 62 | colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S", 64 | root_name=name, 65 | abbrev_name=str(abbrev_name), 66 | ) 67 | else: 68 | formatter = plain_formatter 69 | ch.setFormatter(formatter) 70 | logger.addHandler(ch) 71 | 72 | # file logging: all workers 73 | if output is not None: 74 | if output.endswith(".txt") or output.endswith(".log"): 75 | filename = output 76 | else: 77 | filename = os.path.join(output, "log.txt") 78 | if distributed_rank > 0: 79 | filename = filename + f".rank{distributed_rank}" 80 | os.makedirs(os.path.dirname(filename), exist_ok=True) 81 | 82 | fh = logging.StreamHandler(_cached_log_stream(filename)) 83 | fh.setLevel(logging.DEBUG) 84 | fh.setFormatter(plain_formatter) 85 | logger.addHandler(fh) 86 | 87 | return logger 88 | 89 | 90 | # cache the opened file object, so that different calls to `setup_logger` 91 | # with the same file name can safely write to the same file. 92 | @functools.lru_cache(maxsize=None) 93 | def _cached_log_stream(filename): 94 | return open(filename, "a") 95 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Custom loss function definitions. 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from utils.utils import * 13 | 14 | class IoULoss(nn.Module): 15 | """ 16 | Creates a criterion that computes the Intersection over Union (IoU) 17 | between a segmentation mask and its ground truth. 18 | 19 | Rahman, M.A. and Wang, Y: 20 | Optimizing Intersection-Over-Union in Deep Neural Networks for 21 | Image Segmentation. 
International Symposium on Visual Computing (2016) 22 | http://www.cs.umanitoba.ca/~ywang/papers/isvc16.pdf 23 | """ 24 | 25 | def __init__(self, size_average=True): 26 | super().__init__() 27 | self.size_average = size_average 28 | 29 | def forward(self, input, target): 30 | input = F.sigmoid(input) 31 | intersection = (input * target).sum() 32 | union = ((input + target) - (input * target)).sum() 33 | iou = intersection / union 34 | iou_dual = input.size(0) - iou 35 | if self.size_average: 36 | iou_dual = iou_dual / input.size(0) 37 | return iou_dual 38 | 39 |
40 | def yolo_loss(input, target, gi, gj, best_n_list, w_coord=5.): 41 | mseloss = torch.nn.MSELoss(size_average=True) 42 | celoss = torch.nn.CrossEntropyLoss(size_average=True) 43 | batch = input.size(0) 44 | 45 | pred_bbox = Variable(torch.zeros(batch,4).cuda()) 46 | gt_bbox = Variable(torch.zeros(batch,4).cuda()) 47 | for ii in range(batch): 48 | pred_bbox[ii, 0:2] = F.sigmoid(input[ii,best_n_list[ii],0:2,gj[ii],gi[ii]]) 49 | pred_bbox[ii, 2:4] = input[ii,best_n_list[ii],2:4,gj[ii],gi[ii]] 50 | gt_bbox[ii, :] = target[ii,best_n_list[ii],:4,gj[ii],gi[ii]] 51 | loss_x = mseloss(pred_bbox[:,0], gt_bbox[:,0]) 52 | loss_y = mseloss(pred_bbox[:,1], gt_bbox[:,1]) 53 | loss_w = mseloss(pred_bbox[:,2], gt_bbox[:,2]) 54 | loss_h = mseloss(pred_bbox[:,3], gt_bbox[:,3]) 55 | 56 | pred_conf_list, gt_conf_list = [], [] 57 | pred_conf_list.append(input[:,:,4,:,:].contiguous().view(batch,-1)) 58 | gt_conf_list.append(target[:,:,4,:,:].contiguous().view(batch,-1)) 59 | pred_conf = torch.cat(pred_conf_list, dim=1) 60 | gt_conf = torch.cat(gt_conf_list, dim=1) 61 | loss_conf = celoss(pred_conf, gt_conf.max(1)[1]) 62 | return (loss_x+loss_y+loss_w+loss_h)*w_coord + loss_conf 63 |
64 | def build_target(raw_coord, anchors, args): 65 | coord = Variable(torch.zeros(raw_coord.size(0), raw_coord.size(1)).cuda()) 66 | batch, grid = raw_coord.size(0), args.size//args.gsize 67 | coord[:,0] = (raw_coord[:,0] + raw_coord[:,2])/(2*args.size) # x, normalized relative to the original image 68 | coord[:,1] = (raw_coord[:,1] + raw_coord[:,3])/(2*args.size) # y 69 | coord[:,2] = (raw_coord[:,2] - raw_coord[:,0])/(args.size) # w 70 | coord[:,3] = (raw_coord[:,3] - raw_coord[:,1])/(args.size) # h 71 | coord = coord * grid 72 | bbox=torch.zeros(coord.size(0),len(anchors),5,grid,grid) 73 | 74 | best_n_list, best_gi, best_gj = [],[],[] 75 | 76 | for ii in range(batch): 77 | gi = coord[ii,0].long() 78 | gj = coord[ii,1].long() 79 | tx = coord[ii,0] - gi.float() 80 | ty = coord[ii,1] - gj.float() 81 | gw = coord[ii,2] 82 | gh = coord[ii,3] 83 | 84 | scaled_anchors = [ (x[0] / (args.anchor_imsize/grid), \ 85 | x[1] / (args.anchor_imsize/grid)) for x in anchors] 86 | 87 | ## Get shape of gt box 88 | gt_box = torch.FloatTensor(np.array([0, 0, gw, gh],dtype=np.float32)).unsqueeze(0) #[1,4] 89 | ## Get shape of anchor box 90 | anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((len(scaled_anchors), 2)), np.array(scaled_anchors)), 1)) 91 | ## Calculate iou between gt and anchor shapes 92 | anch_ious = list(bbox_iou(gt_box, anchor_shapes,x1y1x2y2=False)) 93 | ## Find the best matching anchor box 94 | best_n = np.argmax(np.array(anch_ious)) 95 | 96 | tw = torch.log(gw / scaled_anchors[best_n][0] + 1e-16) 97 | th = torch.log(gh / scaled_anchors[best_n][1] + 1e-16) 98 | 99 | bbox[ii, best_n, :, gj, gi] = torch.stack([tx, ty, tw, th, torch.ones(1).cuda().squeeze()]) 100 | best_n_list.append(int(best_n)) 101 | best_gi.append(gi) 102 | best_gj.append(gj) 103 | bbox = Variable(bbox.cuda()) 104 | return bbox,
best_gi, best_gj, best_n_list 105 | 106 | def adjust_learning_rate(args, optimizer, i_iter): 107 | # print(optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr']) 108 | if i_iter in args.steps: 109 | #lr = args.lr * args.power 110 | lr = args.lr * args.power ** (args.steps.index(i_iter) + 1) 111 | optimizer.param_groups[0]['lr'] = lr 112 | if len(optimizer.param_groups) > 1: 113 | optimizer.param_groups[1]['lr'] = lr / 10 114 | if len(optimizer.param_groups) > 2: 115 | optimizer.param_groups[2]['lr'] = lr / 10 116 | 117 | def cem_loss(co_energy): 118 | loss = -1.0 * torch.log(co_energy+1e-6).sum() 119 | return loss 120 | 121 | class FocalLoss(nn.Module): 122 | def __init__(self, alpha=0.25, gamma=2, logits=True, reduce=False): 123 | super(FocalLoss, self).__init__() 124 | self.alpha = alpha 125 | self.gamma = gamma 126 | self.logits = logits 127 | self.reduce = reduce 128 | 129 | def forward(self, inputs, targets): 130 | if self.logits: 131 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) 132 | else: 133 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False) 134 | pt = torch.exp(-BCE_loss) 135 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 136 | if self.reduce: 137 | return torch.mean(F_loss) 138 | else: 139 | return torch.sum(F_loss) -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Misc download and visualization helper functions and class wrappers. 5 | """ 6 | 7 | import sys 8 | import time 9 | import torch 10 | from visdom import Visdom 11 | 12 | 13 | def reporthook(count, block_size, total_size): 14 | global start_time 15 | if count == 0: 16 | start_time = time.time() 17 | return 18 | duration = time.time() - start_time 19 | progress_size = int(count * block_size) 20 | speed = int(progress_size / (1024 * duration)) 21 | percent = min(int(count * block_size * 100 / total_size), 100) 22 | sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % 23 | (percent, progress_size / (1024 * 1024), speed, duration)) 24 | sys.stdout.flush() 25 | 26 | 27 | class VisdomWrapper(Visdom): 28 | def __init__(self, *args, env=None, **kwargs): 29 | Visdom.__init__(self, *args, **kwargs) 30 | self.env = env 31 | self.plots = {} 32 | 33 | def init_line_plot(self, name, 34 | X=torch.zeros((1,)).cpu(), 35 | Y=torch.zeros((1,)).cpu(), **opts): 36 | self.plots[name] = self.line(X=X, Y=Y, env=self.env, opts=opts) 37 | 38 | def plot_line(self, name, **kwargs): 39 | self.line(win=self.plots[name], env=self.env, **kwargs) 40 | -------------------------------------------------------------------------------- /utils/parsing_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def _fast_hist(label_true, label_pred, n_class): 5 | mask = (label_true >= 0) & (label_true < n_class) 6 | hist = np.bincount( 7 | n_class * label_true[mask].astype(int) + 8 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) 9 | return hist 10 | 11 | def label_accuracy_score(label_trues, label_preds, n_class, bg_thre=200): 12 | """Returns accuracy score evaluation result. 
13 | - overall accuracy 14 | - mean accuracy 15 | - mean IU 16 | - fwavacc 17 | """ 18 | hist = np.zeros((n_class, n_class)) 19 | for lt, lp in zip(label_trues, label_preds): 20 | # hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 21 | hist += _fast_hist(lt[lt<bg_thre], lp[lt<bg_thre], n_class) 22 | acc = np.diag(hist).sum() / hist.sum() 23 | acc_cls = np.diag(hist) / hist.sum(axis=1) 24 | acc_cls = np.nanmean(acc_cls) 25 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 26 | mean_iu = np.nanmean(iu) 27 | freq = hist.sum(axis=1) / hist.sum() 28 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 29 | return acc, acc_cls, mean_iu, fwavacc 30 |
31 | def label_confusion_matrix(label_trues, label_preds, n_class, bg_thre=200): 32 | # eps=1e-20 33 | hist=np.zeros((n_class,n_class),dtype=float) 34 | """ (8,256,256), (256,256) """ 35 | for lt,lp in zip(label_trues, label_preds): 36 | # hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 37 | hist += _fast_hist(lt[lt<bg_thre], lp[lt<bg_thre], n_class) 38 | acc = np.diag(hist).sum() / hist.sum() 39 | acc_cls = np.diag(hist) / hist.sum(axis=1) 40 | acc_cls = np.nanmean(acc_cls) 41 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 42 | mean_iu = np.nanmean(iu) 43 | freq = hist.sum(axis=1) / hist.sum() 67 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 68 | return acc, acc_cls, mean_iu, fwavacc, iu 69 |
70 | def cal_seg_iou_loss(gt,pred,trsh=0.5): 71 | t=np.array(pred>trsh) 72 | p=np.array(gt>0.) 73 | intersection = np.logical_and(t, p) 74 | union = np.logical_or(t, p) 75 | iou = (np.sum(intersection > 0 , axis=(2,3)) + 1e-10 )/ (np.sum(union > 0, axis=(2,3)) + 1e-10) 76 | return iou 77 |
78 | def cal_seg_iou(gt,pred,trsh=0.5): 79 | #(gt.shape) [1 428 640] 80 | #(pred.shape) [428 640] 81 | t=np.array(pred>trsh) 82 | p=np.array(gt>0.) 83 | intersection = np.logical_and(t, p) 84 | union = np.logical_or(t, p) 85 | iou = (np.sum(intersection > 0) + 1e-10 )/ (np.sum(union > 0) + 1e-10) 86 | 87 | prec=dict() 88 | thresholds = np.arange(0.5, 1, 0.05) 89 | for thresh in thresholds: 90 | prec[thresh]= float(iou > thresh) 91 | return iou,prec 92 |
93 | def cal_seg_iou2(gt,pred,trsh=0.5): 94 | #(gt.shape) [1 428 640] 95 | #(pred.shape) [428 640] 96 | t=np.array(pred>trsh) 97 | p=np.array(gt>0.) 98 | intersection = np.logical_and(t, p) 99 | union = np.logical_or(t, p) 100 | iou = (np.sum(intersection > 0) + 1e-10 )/ (np.sum(union > 0) + 1e-10) 101 | 102 | prec=dict() 103 | thresholds = np.arange(0.5, 1, 0.05) 104 | for thresh in thresholds: 105 | prec[thresh]= float(iou > thresh) 106 | return iou, prec, np.sum(intersection > 0), np.sum(union > 0) -------------------------------------------------------------------------------- /utils/transforms.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Generic Image Transform utilities. 5 | """ 6 | 7 | import cv2 8 | import random, math 9 | import numpy as np 10 | from collections.abc import Iterable 11 | from torch import rand 12 | 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | 16 | 17 | class ResizePad: 18 | """ 19 | Resize and pad an image to given size. 20 | """ 21 | 22 | def __init__(self, size): 23 | if not isinstance(size, (int, Iterable)): 24 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 25 | 26 | self.h, self.w = size 27 | 28 | def __call__(self, img): 29 | h, w = img.shape[:2] 30 | scale = min(self.h / h, self.w / w) 31 | resized_h = int(np.round(h * scale)) 32 | resized_w = int(np.round(w * scale)) 33 | pad_h = int(np.floor(self.h - resized_h) / 2) 34 | pad_w = int(np.floor(self.w - resized_w) / 2) 35 | 36 | resized_img = cv2.resize(img, (resized_w, resized_h)) 37 | 38 | # if img.ndim > 2: 39 | if img.ndim > 2: 40 | new_img = np.zeros( 41 | (self.h, self.w, img.shape[-1]), dtype=resized_img.dtype) 42 | else: 43 | resized_img = np.expand_dims(resized_img, -1) 44 | new_img = np.zeros((self.h, self.w, 1), dtype=resized_img.dtype) 45 | new_img[pad_h: pad_h + resized_h, 46 | pad_w: pad_w + resized_w, ...]
= resized_img 47 | return new_img 48 | 49 | 50 | class CropResize: 51 | """Remove padding and resize image to its original size.""" 52 | 53 | def __call__(self, img, size): 54 | if not isinstance(size, (int, Iterable)): 55 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 56 | im_h, im_w = img.data.shape[:2] 57 | input_h, input_w = size 58 | scale = max(input_h / im_h, input_w / im_w) 59 | # scale = torch.Tensor([[input_h / im_h, input_w / im_w]]).max() 60 | resized_h = int(np.round(im_h * scale)) 61 | # resized_h = torch.round(im_h * scale) 62 | resized_w = int(np.round(im_w * scale)) 63 | # resized_w = torch.round(im_w * scale) 64 | crop_h = int(np.floor(resized_h - input_h) / 2) 65 | # crop_h = torch.floor(resized_h - input_h) // 2 66 | crop_w = int(np.floor(resized_w - input_w) / 2) 67 | # crop_w = torch.floor(resized_w - input_w) // 2 68 | # resized_img = cv2.resize(img, (resized_w, resized_h)) 69 | resized_img = F.upsample( 70 | img.unsqueeze(0).unsqueeze(0), size=(resized_h, resized_w), 71 | mode='bilinear') 72 | 73 | resized_img = resized_img.squeeze().unsqueeze(0) 74 | 75 | return resized_img[0, crop_h: crop_h + input_h, 76 | crop_w: crop_w + input_w] 77 | 78 | 79 | class ResizeImage: 80 | """Resize the largest of the sides of the image to a given size""" 81 | def __init__(self, size): 82 | if not isinstance(size, (int, Iterable)): 83 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 84 | 85 | self.size = size 86 | 87 | def __call__(self, img): 88 | im_h, im_w = img.shape[-2:] 89 | scale = min(self.size / im_h, self.size / im_w) 90 | resized_h = int(np.round(im_h * scale)) 91 | resized_w = int(np.round(im_w * scale)) 92 | out = F.upsample( 93 | Variable(img).unsqueeze(0), size=(resized_h, resized_w), 94 | mode='bilinear').squeeze().data 95 | return out 96 | 97 | 98 | class ResizeAnnotation: 99 | """Resize the largest of the sides of the annotation to a given size""" 100 | def __init__(self, size): 101 | if not isinstance(size, (int, Iterable)): 102 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 103 | 104 | self.size = size 105 | 106 | def __call__(self, img): 107 | im_h, im_w = img.shape[-2:] 108 | scale = min(self.size / im_h, self.size / im_w) 109 | resized_h = int(np.round(im_h * scale)) 110 | resized_w = int(np.round(im_w * scale)) 111 | out = F.upsample( 112 | Variable(img).unsqueeze(0).unsqueeze(0), 113 | size=(resized_h, resized_w), 114 | mode='bilinear').squeeze().data 115 | return out 116 | 117 | 118 | class ToNumpy: 119 | """Transform an torch.*Tensor to an numpy ndarray.""" 120 | 121 | def __call__(self, x): 122 | return x.numpy() 123 | 124 | def letterbox(img, mask, height, color=(123.7, 116.3, 103.5)): # resize a rectangular image to a padded square 125 | shape = img.shape[:2] # shape = [height, width] 126 | ratio = float(height) / max(shape) # ratio = old / new 127 | new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) 128 | dw = (height - new_shape[0]) / 2 # width padding 129 | dh = (height - new_shape[1]) / 2 # height padding 130 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 131 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 132 | img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border 133 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded square 134 | if mask is not None: 135 | mask = cv2.resize(mask, new_shape, interpolation=cv2.INTER_NEAREST) # resized, no border 136 | mask = cv2.copyMakeBorder(mask, 
top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) # padded square 137 | return img, mask, ratio, dw, dh 138 | 139 | 140 | def random_affine(img, mask, targets, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2), 141 | borderValue=(123.7, 116.3, 103.5), all_bbox=None): 142 | border = 0 # width of added border (optional) 143 | height = max(img.shape[0], img.shape[1]) + border * 2 144 | 145 | # Rotation and Scale 146 | R = np.eye(3) 147 | a = random.random() * (degrees[1] - degrees[0]) + degrees[0] 148 | # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations 149 | s = random.random() * (scale[1] - scale[0]) + scale[0] 150 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) 151 | 152 | # Translation 153 | T = np.eye(3) 154 | T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border # x translation (pixels) 155 | T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border # y translation (pixels) 156 | 157 | # Shear 158 | S = np.eye(3) 159 | S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg) 160 | S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg) 161 | 162 | M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!! 163 | imw = cv2.warpPerspective(img, M, dsize=(height, height), flags=cv2.INTER_LINEAR, 164 | borderValue=borderValue) # BGR order borderValue 165 | if mask is not None: 166 | maskw = cv2.warpPerspective(mask, M, dsize=(height, height), flags=cv2.INTER_NEAREST, 167 | borderValue=0) # BGR order borderValue 168 | else: 169 | maskw = None 170 | 171 | # Return warped points also 172 | if type(targets)==type([1]): 173 | targetlist=[] 174 | for bbox in targets: 175 | targetlist.append(wrap_points(bbox, M, height, a)) 176 | return imw, maskw, targetlist, M 177 | elif all_bbox is not None: 178 | targets = wrap_points(targets, M, height, a) 179 | for ii in range(all_bbox.shape[0]): 180 | all_bbox[ii,:] = wrap_points(all_bbox[ii,:], M, height, a) 181 | return imw, maskw, targets, all_bbox, M 182 | elif targets is not None: ## previous main 183 | targets = wrap_points(targets, M, height, a) 184 | return imw, maskw, targets, M 185 | else: 186 | return imw 187 | 188 | def wrap_points(targets, M, height, a): 189 | # n = targets.shape[0] 190 | # points = targets[:, 1:5].copy() 191 | points = targets.copy() 192 | # area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1]) 193 | area0 = (points[2] - points[0]) * (points[3] - points[1]) 194 | 195 | # warp points 196 | xy = np.ones((4, 3)) 197 | xy[:, :2] = points[[0, 1, 2, 3, 0, 3, 2, 1]].reshape(4, 2) # x1y1, x2y2, x1y2, x2y1 198 | xy = (xy @ M.T)[:, :2].reshape(1, 8) 199 | 200 | # create new boxes 201 | x = xy[:, [0, 2, 4, 6]] 202 | y = xy[:, [1, 3, 5, 7]] 203 | xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, 1).T 204 | 205 | # apply angle-based reduction 206 | radians = a * math.pi / 180 207 | reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5 208 | x = (xy[:, 2] + xy[:, 0]) / 2 209 | y = (xy[:, 3] + xy[:, 1]) / 2 210 | w = (xy[:, 2] - xy[:, 0]) * reduction 211 | h = (xy[:, 3] - xy[:, 1]) * reduction 212 | xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, 1).T 213 | 214 | # reject warped points outside of image 215 | np.clip(xy, 0, height, out=xy) 216 | w = xy[:, 2] - xy[:, 0] 217 | h = xy[:, 3] - xy[:, 1] 218 | area = w 
* h 219 | ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) 220 | i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10) 221 | 222 | ## print(targets, xy) 223 | ## [ 56 36 108 210] [[ 47.80464857 15.6096533 106.30993434 196.71267693]] 224 | # targets = targets[i] 225 | # targets[:, 1:5] = xy[i] 226 | targets = xy[0] 227 | return targets 228 | 229 | 230 | def random_crop(img, seg, pad, h, w): 231 | if random.random() < 0.5: 232 | return img, seg 233 | 234 | img = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=(123.7, 116.3, 103.5)) 235 | seg = cv2.copyMakeBorder(seg, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=(0, 0, 0)) 236 | 237 | Left = random.randint(0, pad * 2) 238 | Top = random.randint(0, pad * 2) 239 | 240 | seg_pixel = seg.sum() 241 | 242 | for _ in range(100): 243 | if seg[Top: Top + h, Left: Left + w].sum() / seg_pixel > 0.95 and seg[Top: Top + h, Left: Left + w].sum() > 0: 244 | img = img[Top: Top + h, Left: Left + w, :] 245 | seg = seg[Top: Top + h, Left: Left + w] 246 | 247 | return img, seg 248 | 249 | Left = random.randint(0, pad * 2) 250 | Top = random.randint(0, pad * 2) 251 | 252 | return img, seg 253 | 254 | 255 | def random_copy(img, seg, phrase, bbox): 256 | if 'left' in phrase or 'right' in phrase or \ 257 | 'center' in phrase or 'middle' in phrase or \ 258 | 'front' in phrase or 'back' in phrase: 259 | return img, seg, phrase, bbox 260 | 261 | if random.random() < 0.75: 262 | return img, seg, phrase, bbox 263 | 264 | h, w = img.shape[0], img.shape[1] 265 | 266 | # x1, y1, x2, y2 = w, h, 0, 0 267 | # for j in range(h): 268 | # for i in range(w): 269 | # if seg[j, i] > 0: 270 | # if i < x1: x1 = i 271 | # if j < y1: y1 = j 272 | # if i > x2: x2 = i 273 | # if j > y2: y2 = j 274 | # x2 = x2 + 1 275 | # y2 = y2 + 1 276 | 277 | # contours, hierarchy = cv2.findContours(seg.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 278 | # c = max(contours, key = cv2.contourArea) 279 | x, y, bboxw, bboxh = cv2.boundingRect(seg.astype(np.uint8)) 280 | x1 = x 281 | y1 = y 282 | x2 = x + bboxw 283 | y2 = y + bboxh 284 | 285 | if x1 - (x2 - x1) < 0 or w - (x2 - x1) < x2: 286 | return img, seg, phrase, bbox 287 | 288 | # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 289 | # color_mask = np.array([0, 255, 0], dtype=np.uint8) 290 | # mask = seg.astype(np.bool) 291 | # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5 292 | # cv2.imwrite('./{}.png'.format(phrase.replace(' ', '_')), tmp) 293 | 294 | if random.random() < 0.5: 295 | new_x1 = random.randint(0, x1 - (x2 - x1)) 296 | phrase += ' on left' 297 | else: 298 | new_x1 = random.randint(x2, w - (x2 - x1)) 299 | phrase += ' on right' 300 | 301 | new_x2 = new_x1 + (x2 - x1) 302 | 303 | delta_y = random.randint((y1 - y2), y2 - y1) 304 | 305 | while y2 + delta_y > h or y1 + delta_y < 0: 306 | delta_y = random.randint((y1 - y2), y2 - y1) 307 | 308 | new_y1 = y1 + delta_y 309 | new_y2 = y2 + delta_y 310 | 311 | new_seg = np.zeros_like(seg) 312 | new_seg[new_y1: new_y2, new_x1: new_x2] = seg[y1: y2, x1: x2] 313 | 314 | # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 315 | # color_mask = np.array([0, 255, 0], dtype=np.uint8) 316 | # mask = new_seg.astype(np.bool) 317 | # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5 318 | # cv2.imwrite('./{}.png'.format(phrase.replace(' ', '_')), tmp) 319 | 320 | img[new_seg.astype(np.bool)] = img[seg.astype(np.bool)] 321 | # bbox = [new_x1, new_y1, new_x2 - 1, new_y2 - 1] 322 | seg = new_seg 323 | 324 | # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 325 | # color_mask = 
np.array([0, 255, 0], dtype=np.uint8) 326 | # mask = seg.astype(np.bool) 327 | # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5 328 | # cv2.imwrite('./{}.png'.format(phrase.replace(' ', '_')), tmp) 329 | 330 | # exit() 331 | 332 | return img, seg, phrase, bbox 333 | 334 | 335 | def random_erase(img, seg): 336 | if random.random() < 0.5: 337 | return img, seg 338 | 339 | x, y, bboxw, bboxh = cv2.boundingRect(seg.astype(np.uint8)) 340 | 341 | area = bboxw * bboxh * 0.5 342 | 343 | for attempt in range(100): 344 | target_area = random.uniform(0.02, 0.4) 345 | aspect_ratio = random.uniform(0.3, 1/0.3) 346 | 347 | h = int(round(math.sqrt(target_area * aspect_ratio))) 348 | w = int(round(math.sqrt(target_area / aspect_ratio))) 349 | 350 | if w < bboxw and h < bboxh: 351 | x1 = random.randint(0, bboxw - w) 352 | y1 = random.randint(0, bboxh - h) 353 | 354 | new_seg = seg.copy() 355 | new_seg[y+y1: y+y1+h, x+x1: x+x1+w] = 0 356 | 357 | if new_seg.sum() / seg.sum() > 0.75: 358 | continue 359 | 360 | seg[y+y1: y+y1+h, x+x1: x+x1+w] = 0 361 | 362 | img[y+y1: y+y1+h, x+x1: x+x1+w, 0] = 123.7 363 | img[y+y1: y+y1+h, x+x1: x+x1+w, 1] = 116.3 364 | img[y+y1: y+y1+h, x+x1: x+x1+w, 2] = 103.5 365 | 366 | # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 367 | # color_mask = np.array([0, 255, 0], dtype=np.uint8) 368 | # mask = seg.astype(np.bool) 369 | # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5 370 | # cv2.imwrite('./erase.png', tmp) 371 | 372 | return img, seg 373 | 374 | return img, seg -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import optim 6 | from torch.optim import Optimizer 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def xyxy2xywh(x): # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h] 26 | y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape) 27 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 28 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 29 | y[:, 2] = x[:, 2] - x[:, 0] 30 | y[:, 3] = x[:, 3] - x[:, 1] 31 | return y 32 | 33 | 34 | def xywh2xyxy(x): # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2] 35 | y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape) 36 | y[:, 0] = (x[:, 0] - x[:, 2] / 2) 37 | y[:, 1] = (x[:, 1] - x[:, 3] / 2) 38 | y[:, 2] = (x[:, 0] + x[:, 2] / 2) 39 | y[:, 3] = (x[:, 1] + x[:, 3] / 2) 40 | return y 41 | 42 | def bbox_iou_numpy(box1, box2): 43 | """Computes IoU between bounding boxes. 
44 | Parameters 45 | ---------- 46 | box1 : ndarray 47 | (N, 4) shaped array with bboxes 48 | box2 : ndarray 49 | (M, 4) shaped array with bboxes 50 | Returns 51 | ------- 52 | : ndarray 53 | (N, M) shaped array with IoUs 54 | """ 55 | area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) 56 | 57 | iw = np.minimum(np.expand_dims(box1[:, 2], axis=1), box2[:, 2]) - np.maximum( 58 | np.expand_dims(box1[:, 0], 1), box2[:, 0] 59 | ) 60 | ih = np.minimum(np.expand_dims(box1[:, 3], axis=1), box2[:, 3]) - np.maximum( 61 | np.expand_dims(box1[:, 1], 1), box2[:, 1] 62 | ) 63 | 64 | iw = np.maximum(iw, 0) 65 | ih = np.maximum(ih, 0) 66 | 67 | ua = np.expand_dims((box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]), axis=1) + area - iw * ih 68 | 69 | ua = np.maximum(ua, np.finfo(float).eps) 70 | 71 | intersection = iw * ih 72 | 73 | return intersection / ua 74 | 75 | 76 | def bbox_iou(box1, box2, x1y1x2y2=True): 77 | """ 78 | Returns the IoU of two bounding boxes 79 | """ 80 | if x1y1x2y2: 81 | # Get the coordinates of bounding boxes 82 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 83 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 84 | else: 85 | # Transform from center and width to exact coordinates 86 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 87 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 88 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 89 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 90 | 91 | # get the coordinates of the intersection rectangle 92 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 93 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 94 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 95 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 96 | # Intersection area 97 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, 0) * torch.clamp(inter_rect_y2 - inter_rect_y1, 0) 98 | # Union Area 99 | b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) 100 | b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) 101 | 102 | # print(box1, box1.shape) 103 | # print(box2, box2.shape) 104 | return inter_area / (b1_area + b2_area - inter_area + 1e-16) 105 | 106 | def multiclass_metrics(pred, gt): 107 | """ 108 | check precision and recall for predictions. 109 | Output: overall = {precision, recall, f1} 110 | """ 111 | eps=1e-6 112 | overall = {'precision': -1, 'recall': -1, 'f1': -1} 113 | NP, NR, NC = 0, 0, 0 # num of pred, num of recall, num of correct 114 | for ii in range(pred.shape[0]): 115 | pred_ind = np.array(pred[ii]>0.5, dtype=int) 116 | gt_ind = np.array(gt[ii]>0.5, dtype=int) 117 | inter = pred_ind * gt_ind 118 | # add to overall 119 | NC += np.sum(inter) 120 | NP += np.sum(pred_ind) 121 | NR += np.sum(gt_ind) 122 | if NP > 0: 123 | overall['precision'] = float(NC)/NP 124 | if NR > 0: 125 | overall['recall'] = float(NC)/NR 126 | if NP > 0 and NR > 0: 127 | overall['f1'] = 2*overall['precision']*overall['recall']/(overall['precision']+overall['recall']+eps) 128 | return overall 129 | 130 | def compute_ap(recall, precision): 131 | """ Compute the average precision, given the recall and precision curves. 132 | Code originally from https://github.com/rbgirshick/py-faster-rcnn. 133 | # Arguments 134 | recall: The recall curve (list). 135 | precision: The precision curve (list). 136 | # Returns 137 | The average precision as computed in py-faster-rcnn. 
138 | """ 139 | # correct AP calculation 140 | # first append sentinel values at the end 141 | mrec = np.concatenate(([0.0], recall, [1.0])) 142 | mpre = np.concatenate(([0.0], precision, [0.0])) 143 | 144 | # compute the precision envelope 145 | for i in range(mpre.size - 1, 0, -1): 146 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 147 | 148 | # to calculate area under PR curve, look for points 149 | # where X axis (recall) changes value 150 | i = np.where(mrec[1:] != mrec[:-1])[0] 151 | 152 | # and sum (\Delta recall) * prec 153 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 154 | return ap 155 | 156 | def concat_coord(x): 157 | ins_feat = x # [bt, c, h, w] [512, 26, 26] 158 | batch_size, c, h, w = x.size() 159 | 160 | float_h = float(h) 161 | float_w = float(w) 162 | 163 | y_range = torch.arange(0., float_h, dtype=torch.float32) # [h, ] 164 | y_range = 2.0 * y_range / (float_h - 1.0) - 1.0 165 | x_range = torch.arange(0., float_w, dtype=torch.float32) # [w, ] 166 | x_range = 2.0 * x_range / (float_w - 1.0) - 1.0 167 | x_range = x_range[None, :] # [1, w] 168 | y_range = y_range[:, None] # [h, 1] 169 | x = x_range.repeat(h, 1) # [h, w] 170 | y = y_range.repeat(1, w) # [h, w] 171 | 172 | x = x[None, None, :, :] # [1, 1, h, w] 173 | y = y[None, None, :, :] # [1, 1, h, w] 174 | x = x.repeat(batch_size, 1, 1, 1) # [N, 1, h, w] 175 | y = y.repeat(batch_size, 1, 1, 1) # [N, 1, h, w] 176 | x = x.cuda() 177 | y = y.cuda() 178 | 179 | ins_feat_out = torch.cat((ins_feat, x, x, x, y, y, y), 1) # [N, c+6, h, w] 180 | 181 | return ins_feat_out 182 | 183 | 184 | def get_cosine_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 185 | num_cycles: float = 0.5, last_epoch: int = -1): 186 | """ 187 | Implementation by Huggingface: 188 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py 189 | 190 | Create a schedule with a learning rate that decreases following the values 191 | of the cosine function between the initial lr set in the optimizer to 0, 192 | after a warmup period during which it increases linearly between 0 and the 193 | initial lr set in the optimizer. 194 | Args: 195 | optimizer ([`~torch.optim.Optimizer`]): 196 | The optimizer for which to schedule the learning rate. 197 | num_warmup_steps (`int`): 198 | The number of steps for the warmup phase. 199 | num_training_steps (`int`): 200 | The total number of training steps. 201 | num_cycles (`float`, *optional*, defaults to 0.5): 202 | The number of waves in the cosine schedule (the defaults is to just 203 | decrease from the max value to 0 following a half-cosine). 204 | last_epoch (`int`, *optional*, defaults to -1): 205 | The index of the last epoch when resuming training. 206 | Return: 207 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 208 | """ 209 | 210 | def lr_lambda(current_step): 211 | if current_step < num_warmup_steps: 212 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps))) 213 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 214 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 215 | 216 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 217 | 218 | def dice_loss(inputs, targets): 219 | """ 220 | Compute the DICE loss, similar to generalized IOU for masks 221 | Args: 222 | inputs: A float tensor of arbitrary shape. 223 | The predictions for each example. 
224 | targets: A float tensor with the same shape as inputs. Stores the binary 225 | classification label for each element in inputs 226 | (0 for the negative class and 1 for the positive class). 227 | """ 228 | 229 | inputs = inputs.sigmoid() 230 | inputs = inputs.flatten(1) 231 | targets = targets.flatten(1) 232 | numerator = 2 * (inputs * targets).sum(1) 233 | denominator = inputs.sum(-1) + targets.sum(-1) 234 | loss = 1 - (numerator + 1) / (denominator + 1) 235 | return loss.mean() 236 | 237 | def sigmoid_focal_loss(inputs, targets, alpha: float = -1, gamma: float = 0): 238 | """ 239 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 240 | Args: 241 | inputs: A float tensor of arbitrary shape. 242 | The predictions for each example. 243 | targets: A float tensor with the same shape as inputs. Stores the binary 244 | classification label for each element in inputs 245 | (0 for the negative class and 1 for the positive class). 246 | alpha: (optional) Weighting factor in range (0,1) to balance 247 | positive vs negative examples. Default = -1 (no weighting). 248 | gamma: Exponent of the modulating factor (1 - p_t) to 249 | balance easy vs hard examples. 250 | Returns: 251 | Loss tensor 252 | """ 253 | 254 | prob = inputs.sigmoid() 255 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 256 | p_t = prob * targets + (1 - prob) * (1 - targets) 257 | loss = ce_loss * ((1 - p_t) ** gamma) 258 | 259 | if alpha >= 0: 260 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 261 | loss = alpha_t * loss 262 | return loss.mean() --------------------------------------------------------------------------------
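The mask-level criteria at the end of `utils/utils.py` (`dice_loss`, `sigmoid_focal_loss`) are self-contained, so they can be exercised on dummy tensors. The snippet below is a minimal sketch, not the training objective actually used by the engine: the tensor shapes, the `alpha`/`gamma` values and the unweighted sum are illustrative assumptions, and it assumes the repository root is on `PYTHONPATH`.

```python
# Minimal sketch of driving the loss helpers in utils/utils.py.
import torch

from utils.utils import dice_loss, sigmoid_focal_loss

# Fake batch: 2 predicted mask logit maps and binary ground-truth masks (64x64).
pred_logits = torch.randn(2, 1, 64, 64)
gt_masks = (torch.rand(2, 1, 64, 64) > 0.5).float()

# dice_loss flattens from dim 1, so any (N, *) layout works.
loss_dice = dice_loss(pred_logits, gt_masks)

# With the defaults (alpha=-1, gamma=0) sigmoid_focal_loss reduces to plain
# BCE-with-logits; alpha=0.25, gamma=2 are the usual RetinaNet settings and
# match the defaults of the FocalLoss class in utils/losses.py.
loss_focal = sigmoid_focal_loss(pred_logits, gt_masks, alpha=0.25, gamma=2.0)

# An unweighted sum is one common way to combine them (weights are arbitrary here).
loss = loss_dice + loss_focal
print(loss_dice.item(), loss_focal.item(), loss.item())
```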