├── store └── __init__.py ├── work ├── __init__.py ├── data │ ├── __init__.py │ ├── get_dataloader.py │ └── unified_dataset.py ├── model │ ├── __init__.py │ ├── vgtr │ │ ├── __init__.py │ │ ├── vgtr.py │ │ ├── position_encoding.py │ │ ├── vg_decoder.py │ │ └── vg_encoder.py │ ├── backbone │ │ ├── __init__.py │ │ ├── rnn.py │ │ ├── visual_backbone.py │ │ └── resnet.py │ ├── grounding_model.py │ └── criterion.py ├── utils │ ├── __init__.py │ ├── losses.py │ ├── misc_utils.py │ ├── word_utils.py │ ├── parsing_metrics.py │ ├── utils.py │ └── transforms.py └── engine.py ├── LICENSE ├── download_data.sh ├── README.md └── main.py /store/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /work/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /work/data/__init__.py: -------------------------------------------------------------------------------- 1 | # 1 -------------------------------------------------------------------------------- /work/model/__init__.py: -------------------------------------------------------------------------------- 1 | # 1 -------------------------------------------------------------------------------- /work/model/vgtr/__init__.py: -------------------------------------------------------------------------------- 1 | # 1 -------------------------------------------------------------------------------- /work/model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # 1 -------------------------------------------------------------------------------- /work/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # ----------------------------------------------------------------------------- 3 | # Copyright (c) Edgar Andrés Margffoy-Tuay, Emilio Botero and Juan Camilo Pérez 4 | # 5 | # Licensed under the terms of the MIT License 6 | # (see LICENSE for details) 7 | # ----------------------------------------------------------------------------- 8 | 9 | """Misc data and other helping utillites.""" 10 | 11 | from .word_utils import Corpus 12 | from .transforms import ResizeImage, ResizeAnnotation 13 | 14 | Corpus 15 | ResizeImage 16 | ResizeAnnotation 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 triple6x 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the 
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /work/utils/losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Custom loss function definitions. 5 | """ 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class IoULoss(nn.Module): 12 | """ 13 | Creates a criterion that computes the Intersection over Union (IoU) 14 | between a segmentation mask and its ground truth. 15 | 16 | Rahman, M.A. and Wang, Y: 17 | Optimizing Intersection-Over-Union in Deep Neural Networks for 18 | Image Segmentation. International Symposium on Visual Computing (2016) 19 | http://www.cs.umanitoba.ca/~ywang/papers/isvc16.pdf 20 | """ 21 | 22 | def __init__(self, size_average=True): 23 | super().__init__() 24 | self.size_average = size_average 25 | 26 | def forward(self, input, target): 27 | input = F.sigmoid(input) 28 | intersection = (input * target).sum() 29 | union = ((input + target) - (input * target)).sum() 30 | iou = intersection / union 31 | iou_dual = input.size(0) - iou 32 | if self.size_average: 33 | iou_dual = iou_dual / input.size(0) 34 | return iou_dual 35 | -------------------------------------------------------------------------------- /work/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Misc download and visualization helper functions and class wrappers. 
5 | """ 6 | 7 | import sys 8 | import time 9 | import torch 10 | from visdom import Visdom 11 | 12 | 13 | def reporthook(count, block_size, total_size): 14 | global start_time 15 | if count == 0: 16 | start_time = time.time() 17 | return 18 | duration = time.time() - start_time 19 | progress_size = int(count * block_size) 20 | speed = int(progress_size / (1024 * duration)) 21 | percent = min(int(count * block_size * 100 / total_size), 100) 22 | sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % 23 | (percent, progress_size / (1024 * 1024), speed, duration)) 24 | sys.stdout.flush() 25 | 26 | 27 | class VisdomWrapper(Visdom): 28 | def __init__(self, *args, env=None, **kwargs): 29 | Visdom.__init__(self, *args, **kwargs) 30 | self.env = env 31 | self.plots = {} 32 | 33 | def init_line_plot(self, name, 34 | X=torch.zeros((1,)).cpu(), 35 | Y=torch.zeros((1,)).cpu(), **opts): 36 | self.plots[name] = self.line(X=X, Y=Y, env=self.env, opts=opts) 37 | 38 | def plot_line(self, name, **kwargs): 39 | self.line(win=self.plots[name], env=self.env, **kwargs) 40 | -------------------------------------------------------------------------------- /work/model/grounding_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | from .backbone.visual_backbone import build_visual_backbone 5 | from .backbone.rnn import build_textual_encoder 6 | from .vgtr.vgtr import build_vgtr 7 | 8 | 9 | class GroundingModel(nn.Module): 10 | 11 | def __init__(self, args): 12 | super().__init__() 13 | 14 | self.visual_encoder = build_visual_backbone(args) 15 | self.textual_encoder = build_textual_encoder(args) 16 | self.vgtr = build_vgtr(args) 17 | self.num_exp_tokens = args.num_exp_tokens 18 | self.prediction_head = nn.Sequential( 19 | nn.Linear(args.hidden_dim * args.num_exp_tokens, args.hidden_dim), 20 | nn.BatchNorm1d(args.hidden_dim), 21 | nn.ReLU(), 22 | nn.Linear(args.hidden_dim, args.hidden_dim), 23 | nn.BatchNorm1d(args.hidden_dim), 24 | nn.ReLU(), 25 | nn.Linear(args.hidden_dim, 4) 26 | ) 27 | 28 | def forward(self, img, expression_word_id): 29 | 30 | img_feature = self.visual_encoder(img) 31 | exp_feature = self.textual_encoder(expression_word_id) 32 | 33 | embed = self.vgtr(img_feature, exp_feature, expression_word_id) 34 | embed2 = torch.cat([embed[:, i] for i in range(self.num_exp_tokens)], dim=-1) 35 | 36 | pred = self.prediction_head(embed2).sigmoid() 37 | 38 | return pred 39 | -------------------------------------------------------------------------------- /work/model/vgtr/vgtr.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .vg_encoder import VGEncoder 3 | from .vg_decoder import VGDecoder 4 | from .position_encoding import PositionEmbeddingSine 5 | 6 | 7 | class VGTR(nn.Module): 8 | 9 | def __init__(self, args): 10 | super().__init__() 11 | 12 | self.input_proj = nn.Conv2d(2048, args.hidden_dim, kernel_size=1) 13 | 14 | self.encoder = VGEncoder( 15 | d_model=args.hidden_dim, 16 | dropout=args.dropout, 17 | nhead=args.nheads, 18 | dim_feedforward=args.dim_feedforward, 19 | num_encoder_layers=args.enc_layers) 20 | 21 | self.decoder = VGDecoder(n_layers=args.dec_layers, 22 | n_heads=args.nheads, 23 | d_model=args.hidden_dim) 24 | 25 | self.pos_encoder = PositionEmbeddingSine(args.hidden_dim // 2, normalize=False) 26 | 27 | def forward(self, img, sent, sent_id): 28 | 29 | pos_feature = self.pos_encoder(img) 30 | 31 | # 
encoder 32 | fused_vis_feature, fused_exp_feature = self.encoder(self.input_proj(img), pos_feature, sent) 33 | 34 | # decoder 35 | out = self.decoder(fused_vis_feature.transpose(0, 1), fused_exp_feature, 36 | pos_feature=pos_feature.flatten(2).permute(2, 0, 1)) 37 | 38 | return out.transpose(0, 1) 39 | 40 | 41 | def build_vgtr(args): 42 | 43 | return VGTR(args) -------------------------------------------------------------------------------- /work/model/vgtr/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 13 | """ 14 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 15 | super().__init__() 16 | 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize # default True 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi # 2pi 24 | self.scale = scale 25 | 26 | def forward(self, tensor_list): 27 | 28 | x = tensor_list # (b, c, h, w) 29 | 30 | not_mask = torch.ones((x.shape[0], x.shape[-2], x.shape[-1])).cuda() # (b, h, w) 31 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 32 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # (b, 128, h, w) 46 | return pos 47 | 48 | 49 | class PositionEncoding1D(nn.Module): 50 | def __init__(self, d_model=256, max_len=20): 51 | super().__init__() 52 | pe = torch.zeros(max_len, d_model) 53 | po = torch.arange(max_len).unsqueeze(1) 54 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000) / d_model)) 55 | pe[:, 0::2] = torch.sin(po * div_term) 56 | pe[:, 1::2] = torch.cos(po * div_term) 57 | self.register_buffer('pe', pe) 58 | def forward(self, x): 59 | l, *_ =x.shape 60 | return self.pe[:l, :].unsqueeze(1) 61 | -------------------------------------------------------------------------------- /work/data/get_dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .unified_dataset import UnifiedDataset 3 | from torchvision.transforms import Compose, ToTensor, Normalize 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | input_transform = Compose([ 8 | ToTensor(), 9 | Normalize( 10 | mean=[0.485, 0.456, 0.406], 11 | std=[0.229, 0.224, 0.225]) 12 | ]) 13 | 14 | 15 | def get_train_loader(args): 16 | 17 | train_dataset 
= UnifiedDataset(data_root=args.data_root, 18 | split_root=args.split_root, 19 | dataset=args.dataset, 20 | split='train', 21 | imsize=args.size, 22 | transform=input_transform, 23 | max_query_len=args.max_query_len, 24 | augment=True) 25 | args.vocab_size = len(train_dataset.corpus) 26 | 27 | return DataLoader(train_dataset, batch_size=args.batch_size, 28 | shuffle=True, pin_memory=True, drop_last=True, 29 | num_workers=args.workers) 30 | 31 | 32 | def get_val_loader(args): 33 | 34 | val_dataset = UnifiedDataset(data_root=args.data_root, 35 | split_root=args.split_root, 36 | dataset=args.dataset, 37 | split='val', 38 | imsize=args.size, 39 | transform=input_transform, 40 | max_query_len=args.max_query_len) 41 | 42 | return DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 43 | pin_memory=True, drop_last=True, num_workers=args.workers) 44 | 45 | 46 | def get_test_loader(args, split): 47 | 48 | if args.dataset == 'refcoco' or args.dataset == 'refcoco+': 49 | assert split == 'testA' or split == 'testB' 50 | elif args.dataset == 'refcocog': 51 | assert split == 'val' 52 | else: 53 | assert split == 'test' 54 | 55 | test_dataset = UnifiedDataset(data_root=args.data_root, 56 | split_root=args.split_root, 57 | dataset=args.dataset, 58 | testmode=True, 59 | split=split, 60 | imsize=args.size, 61 | transform=input_transform, 62 | max_query_len=args.max_query_len) 63 | args.vocab_size = len(test_dataset.corpus) 64 | 65 | return DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, 66 | pin_memory=True, drop_last=False, num_workers=args.workers) 67 | -------------------------------------------------------------------------------- /work/utils/word_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Language-related data loading helper functions and class wrappers. 
5 | """ 6 | 7 | import re 8 | import torch 9 | import codecs 10 | 11 | UNK_TOKEN = '' 12 | PAD_TOKEN = '' 13 | END_TOKEN = '' 14 | SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') 15 | 16 | 17 | class Dictionary(object): 18 | def __init__(self): 19 | self.word2idx = {} 20 | self.idx2word = [] 21 | 22 | def add_word(self, word): 23 | if word not in self.word2idx: 24 | self.idx2word.append(word) 25 | self.word2idx[word] = len(self.idx2word) - 1 26 | return self.word2idx[word] 27 | 28 | def __len__(self): 29 | return len(self.idx2word) 30 | 31 | def __getitem__(self, a): 32 | if isinstance(a, int): 33 | return self.idx2word[a] 34 | elif isinstance(a, list): 35 | return [self.idx2word[x] for x in a] 36 | elif isinstance(a, str): 37 | return self.word2idx[a] 38 | else: 39 | raise TypeError("Query word/index argument must be int or str") 40 | 41 | def __contains__(self, word): 42 | return word in self.word2idx 43 | 44 | 45 | class Corpus(object): 46 | def __init__(self): 47 | self.dictionary = Dictionary() 48 | 49 | def set_max_len(self, value): 50 | self.max_len = value 51 | 52 | def load_file(self, filename): 53 | with codecs.open(filename, 'r', 'utf-8') as f: 54 | for line in f: 55 | line = line.strip() 56 | self.add_to_corpus(line) 57 | self.dictionary.add_word(UNK_TOKEN) 58 | self.dictionary.add_word(PAD_TOKEN) 59 | 60 | def add_to_corpus(self, line): 61 | """Tokenizes a text line.""" 62 | # Add words to the dictionary 63 | words = line.split() 64 | # tokens = len(words) 65 | for word in words: 66 | word = word.lower() 67 | self.dictionary.add_word(word) 68 | 69 | def tokenize(self, line, max_len=20): 70 | # Tokenize line contents 71 | words = SENTENCE_SPLIT_REGEX.split(line.strip()) 72 | # words = [w.lower() for w in words if len(w) > 0] 73 | words = [w.lower() for w in words if (len(w) > 0 and w!=' ')] 74 | ## do not include space as a token 75 | 76 | if words[-1] == '.': 77 | words = words[:-1] 78 | 79 | if max_len > 0: 80 | if len(words) > max_len: 81 | words = words[:max_len] 82 | elif len(words) < max_len: 83 | # words = [PAD_TOKEN] * (max_len - len(words)) + words 84 | words = words + [END_TOKEN] + [PAD_TOKEN] * (max_len - len(words) - 1) 85 | 86 | tokens = len(words) ## for end token 87 | ids = torch.LongTensor(tokens) 88 | token = 0 89 | for word in words: 90 | if word not in self.dictionary: 91 | word = UNK_TOKEN 92 | # print(word, type(word), word.encode('ascii','ignore').decode('ascii'), type(word.encode('ascii','ignore').decode('ascii'))) 93 | if type(word)!=type('a'): 94 | print(word, type(word), word.encode('ascii','ignore').decode('ascii'), type(word.encode('ascii','ignore').decode('ascii'))) 95 | word = word.encode('ascii','ignore').decode('ascii') 96 | ids[token] = self.dictionary[word] 97 | token += 1 98 | # ids[token] = self.dictionary[END_TOKEN] 99 | return ids 100 | 101 | def __len__(self): 102 | return len(self.dictionary) 103 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | version="0.1" 4 | 5 | # This is an optional arguments-only example of Argbash potential 6 | # 7 | # ARG_OPTIONAL_SINGLE([path],[p],[path onto which files are to be downloaded],[data]) 8 | # ARG_VERSION([echo test v$version]) 9 | # ARG_HELP([The general script's help msg]) 10 | # ARGBASH_GO() 11 | # needed because of Argbash --> m4_ignore([ 12 | ### START OF CODE GENERATED BY Argbash v2.5.0 one line above ### 13 | # Argbash is a bash code generator 
used to get arguments parsing right. 14 | # Argbash is FREE SOFTWARE, see https://argbash.io for more info 15 | # Generated online by https://argbash.io/generate 16 | 17 | die() 18 | { 19 | local _ret=$2 20 | test -n "$_ret" || _ret=1 21 | test "$_PRINT_HELP" = yes && print_help >&2 22 | echo "$1" >&2 23 | exit ${_ret} 24 | } 25 | 26 | begins_with_short_option() 27 | { 28 | local first_option all_short_options 29 | all_short_options='pvh' 30 | first_option="${1:0:1}" 31 | test "$all_short_options" = "${all_short_options/$first_option/}" && return 1 || return 0 32 | } 33 | 34 | 35 | 36 | # THE DEFAULTS INITIALIZATION - OPTIONALS 37 | _arg_path="referit_data" 38 | 39 | print_help () 40 | { 41 | printf "%s\n" "download ReferIt data script" 42 | printf 'Usage: %s [-p|--path ] [-v|--version] [-h|--help]\n' "$0" 43 | printf "\t%s\n" "-p,--path: path onto which files are to be downloaded (default: '"referit_data"')" 44 | printf "\t%s\n" "-v,--version: Prints version" 45 | printf "\t%s\n" "-h,--help: Prints help" 46 | } 47 | 48 | parse_commandline () 49 | { 50 | while test $# -gt 0 51 | do 52 | _key="$1" 53 | case "$_key" in 54 | -p|--path) 55 | test $# -lt 2 && die "Missing value for the optional argument '$_key'." 1 56 | _arg_path="$2" 57 | shift 58 | ;; 59 | --path=*) 60 | _arg_path="${_key##--path=}" 61 | ;; 62 | -p*) 63 | _arg_path="${_key##-p}" 64 | ;; 65 | -v|--version) 66 | echo test v$version 67 | exit 0 68 | ;; 69 | -v*) 70 | echo test v$version 71 | exit 0 72 | ;; 73 | -h|--help) 74 | print_help 75 | exit 0 76 | ;; 77 | -h*) 78 | print_help 79 | exit 0 80 | ;; 81 | *) 82 | _PRINT_HELP=yes die "FATAL ERROR: Got an unexpected argument '$1'" 1 83 | ;; 84 | esac 85 | shift 86 | done 87 | } 88 | 89 | parse_commandline "$@" 90 | 91 | # OTHER STUFF GENERATED BY Argbash 92 | 93 | ### END OF CODE GENERATED BY Argbash (sortof) ### ]) 94 | # [ <-- needed because of Argbash 95 | 96 | 97 | echo "Save data to: $_arg_path" 98 | 99 | 100 | REFERIT_SPLITS_URL="https://s3-sa-east-1.amazonaws.com/query-objseg/referit_splits.tar.bz2" 101 | REFERIT_DATA_URL="http://www.eecs.berkeley.edu/~ronghang/projects/cvpr16_text_obj_retrieval/referitdata.tar.gz" 102 | COCO_DATA_URL="http://images.cocodataset.org/zips/train2014.zip" 103 | 104 | REFCOCO_URL="http://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip" 105 | REFCOCO_PLUS_URL="http://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip" 106 | REFCOCOG_URL="http://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip" 107 | 108 | REFERIT_FILE=${REFERIT_DATA_URL#*cvpr16_text_obj_retrieval/} 109 | SPLIT_FILE=${REFERIT_SPLITS_URL#*query-objseg/} 110 | COCO_FILE=${COCO_DATA_URL#*zips/} 111 | 112 | 113 | if [ ! -d $_arg_path ]; then 114 | mkdir $_arg_path 115 | cd $_arg_path 116 | 117 | mkdir referit 118 | cd referit 119 | 120 | printf "Downloading ReferIt dataset (This may take a while...)" 121 | aria2c -x 8 $REFERIT_DATA_URL 122 | 123 | 124 | printf "Uncompressing data..." 125 | tar -xzvf $REFERIT_FILE 126 | rm $REFERIT_FILE 127 | 128 | mkdir splits 129 | cd splits 130 | 131 | printf "Downloading ReferIt Splits..." 132 | aria2c -x 8 $REFERIT_SPLITS_URL 133 | 134 | tar -xjvf $SPLIT_FILE 135 | rm $SPLIT_FILE 136 | 137 | cd ../.. 138 | 139 | mkdir -p other/images/mscoco/images 140 | cd other/images/mscoco/images 141 | 142 | printf "Downloading MS COCO 2014 train images (This may take a while...)" 143 | aria2c -x 8 $COCO_DATA_URL 144 | 145 | unzip $COCO_FILE 146 | rm $COCO_FILE 147 | 148 | cd ../../.. 
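# Finally, fetch the RefCOCO, RefCOCO+ and RefCOCOg referring-expression annotation archives and unpack them in place.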
149 | printf "Downloading refcoco, refcocog and refcoco+ splits..." 150 | aria2c -x 8 $REFCOCO_URL 151 | aria2c -x 8 $REFCOCO_PLUS_URL 152 | aria2c -x 8 $REFCOCOG_URL 153 | 154 | unzip "*.zip" 155 | rm *.zip 156 | fi 157 | -------------------------------------------------------------------------------- /work/model/vgtr/vg_decoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | 8 | 9 | class VGDecoder(nn.Module): 10 | def __init__(self, d_model=256, n_heads=8, n_layers=2, dim_feedforward=2048): 11 | super().__init__() 12 | 13 | decoder_layer = DecoderLayer(d_model=d_model, nhead=n_heads, 14 | dim_feedforward=dim_feedforward, 15 | dropout=0.1, 16 | activation="relu") 17 | 18 | norm = nn.LayerNorm(d_model) 19 | self.decoder = Decoder(decoder_layer, num_layers=n_layers, norm=norm) 20 | 21 | self._reset_parameters() 22 | 23 | def _reset_parameters(self): 24 | for p in self.parameters(): 25 | if p.dim() > 1: 26 | nn.init.xavier_uniform_(p) 27 | 28 | def forward(self, fused_vis_f, fused_exp_f, pos_feature=None): 29 | 30 | expression_feature_token = fused_exp_f.transpose(0, 1) 31 | 32 | return self.decoder(expression_feature_token, fused_vis_f, pos=pos_feature) 33 | 34 | def _get_attn_pad_mask(self, seq_q, len_k): 35 | 36 | batch_size, len_q = seq_q.size() 37 | pad_attn_mask = seq_q.data.eq(0).unsqueeze(1) 38 | return pad_attn_mask.expand(batch_size, len_k, len_q) # [batch_size, len_q, len_k] 39 | 40 | 41 | class Decoder(nn.Module): 42 | 43 | def __init__(self, decoder_layer, num_layers, norm=None, 44 | return_intermediate=False): 45 | super().__init__() 46 | self.layers = _get_clones(decoder_layer, num_layers) 47 | self.num_layers = num_layers 48 | self.norm = norm 49 | self.return_intermediate = return_intermediate 50 | 51 | def forward(self, tgt, memory, 52 | pos: Optional[Tensor] = None, 53 | query_pos: Optional[Tensor] = None): 54 | output = tgt 55 | 56 | intermediate = [] 57 | 58 | for layer in self.layers: 59 | output = layer(output, memory, pos=pos, query_pos=None) 60 | if self.return_intermediate: 61 | intermediate.append(self.norm(output)) 62 | if self.norm is not None: 63 | output = self.norm(output) 64 | if self.return_intermediate: 65 | intermediate.pop() 66 | intermediate.append(output) 67 | 68 | if self.return_intermediate: 69 | return torch.stack(intermediate) 70 | 71 | return output 72 | 73 | 74 | class DecoderLayer(nn.Module): 75 | 76 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 77 | activation="relu"): 78 | super().__init__() 79 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 80 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 81 | self.linear1 = nn.Linear(d_model, dim_feedforward) 82 | self.dropout = nn.Dropout(dropout) 83 | self.linear2 = nn.Linear(dim_feedforward, d_model) 84 | 85 | self.norm1 = nn.LayerNorm(d_model) 86 | self.norm2 = nn.LayerNorm(d_model) 87 | self.norm3 = nn.LayerNorm(d_model) 88 | self.dropout1 = nn.Dropout(dropout) 89 | self.dropout2 = nn.Dropout(dropout) 90 | self.dropout3 = nn.Dropout(dropout) 91 | 92 | self.activation = _get_activation_fn(activation) 93 | 94 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 95 | return tensor if pos is None else tensor + pos 96 | 97 | def forward(self, tgt, memory, 98 | pos: Optional[Tensor] = None, 99 | query_pos: Optional[Tensor] = None): 100 | 
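# Pre-norm decoder layer: each sub-block normalizes its input first, then adds the result back as a residual.
# Expected shapes (d = d_model, B = batch size, HW = number of visual tokens, N = number of expression tokens):
#   tgt: (N, B, d) expression tokens, memory: (HW, B, d) fused visual features,
#   pos: (HW, B, d) sine positional encoding (added to the cross-attention keys only; query_pos is passed as None here).
# Flow: self-attention over tgt -> cross-attention where tgt queries memory -> position-wise feed-forward.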
tgt2 = self.norm1(tgt) 101 | q = k = self.with_pos_embed(tgt2, query_pos) 102 | tgt2 = self.self_attn(q, k, value=tgt2)[0] 103 | tgt = tgt + self.dropout1(tgt2) 104 | tgt2 = self.norm2(tgt) 105 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 106 | key=self.with_pos_embed(memory, pos), 107 | value=memory)[0] 108 | tgt = tgt + self.dropout2(tgt2) 109 | tgt2 = self.norm3(tgt) 110 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 111 | tgt = tgt + self.dropout3(tgt2) 112 | return tgt 113 | 114 | 115 | def _get_clones(module, N): 116 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 117 | 118 | 119 | def _get_activation_fn(activation): 120 | if activation == "relu": 121 | return F.relu 122 | if activation == "gelu": 123 | return F.gelu 124 | if activation == "glu": 125 | return F.glu 126 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 127 | -------------------------------------------------------------------------------- /work/model/backbone/rnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class RNNEncoder(nn.Module): 13 | 14 | def __init__(self, vocab_size, word_embedding_size, word_vec_size, hidden_size, bidirectional=False, 15 | input_dropout_p=0., dropout_p=0., n_layers=2, rnn_type='lstm', variable_lengths=True): 16 | super(RNNEncoder, self).__init__() 17 | self.variable_lengths = variable_lengths 18 | self.embedding = nn.Embedding(vocab_size, word_embedding_size) 19 | self.input_dropout = nn.Dropout(input_dropout_p) 20 | self.mlp = nn.Sequential(nn.Linear(word_embedding_size, word_vec_size), 21 | nn.ReLU()) 22 | self.rnn_type = rnn_type 23 | self.rnn = getattr(nn, rnn_type.upper())(word_vec_size, hidden_size, n_layers, 24 | batch_first=True, 25 | bidirectional=bidirectional, 26 | dropout=dropout_p) 27 | self.num_dirs = 2 if bidirectional else 1 28 | # self._init_param() 29 | 30 | def _init_param(self): 31 | for k, v in self.rnn.named_parameters(): 32 | if 'bias' in k: 33 | v.data.zero_().add_(1.0) # init LSTM bias = 1.0 34 | 35 | def forward(self, input_labels): 36 | 37 | if self.variable_lengths: 38 | input_lengths = (input_labels != 0).sum(1) 39 | # make ixs 40 | input_lengths_list = input_lengths.data.cpu().numpy().tolist() 41 | sorted_input_lengths_list = np.sort(input_lengths_list)[::-1].tolist() 42 | sort_ixs = np.argsort(input_lengths_list)[::-1].tolist() 43 | s2r = {s: r for r, s in enumerate(sort_ixs)} # O(n) 44 | recover_ixs = [s2r[s] for s in range(len(input_lengths_list))] 45 | assert max(input_lengths_list) == input_labels.size(1) 46 | 47 | sort_ixs = input_labels.data.new(sort_ixs).long() # Variable long 48 | recover_ixs = input_labels.data.new(recover_ixs).long() # Variable long 49 | 50 | input_labels = input_labels[sort_ixs] 51 | 52 | embedded = self.embedding(input_labels) 53 | embedded = self.input_dropout(embedded) 54 | embedded = self.mlp(embedded) 55 | if self.variable_lengths: 56 | embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_input_lengths_list, batch_first=True) 57 | 58 | output, hidden = self.rnn(embedded) 59 | 60 | if self.variable_lengths: 61 | 62 | embedded, _ = nn.utils.rnn.pad_packed_sequence(embedded, batch_first=True) 63 | embedded = 
embedded[recover_ixs] 64 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) # (batch, max_len, hidden) 65 | output = output[recover_ixs] 66 | 67 | if self.rnn_type == 'lstm': 68 | hidden = hidden[0] 69 | hidden = hidden[:, recover_ixs, :] 70 | hidden = hidden.transpose(0, 1).contiguous() 71 | hidden = hidden.view(hidden.size(0), -1) 72 | 73 | return output, hidden, embedded 74 | 75 | 76 | 77 | class PhraseAttention(nn.Module): 78 | def __init__(self, input_dim): 79 | super(PhraseAttention, self).__init__() 80 | self.fc = nn.Linear(input_dim, 1) 81 | 82 | def forward(self, context, embedded, input_labels): 83 | 84 | cxt_scores = self.fc(context).squeeze(2) 85 | attn = F.softmax(cxt_scores) 86 | 87 | is_not_zero = (input_labels != 0).float() 88 | attn = attn * is_not_zero 89 | attn = attn / attn.sum(1).view(attn.size(0), 1).expand(attn.size(0), attn.size(1)) 90 | 91 | # compute weighted embedding 92 | attn3 = attn.unsqueeze(1) 93 | weighted_emb = torch.bmm(attn3, embedded) 94 | weighted_emb = weighted_emb.squeeze(1) 95 | 96 | return attn, weighted_emb 97 | 98 | 99 | class TextualEncoder(nn.Module): 100 | def __init__(self, args): 101 | super().__init__() 102 | self.rnn = RNNEncoder(args.vocab_size, args.embedding_dim, 103 | args.hidden_dim, args.rnn_hidden_dim, 104 | bidirectional=True, 105 | input_dropout_p=0.1, 106 | dropout_p=0.1, 107 | n_layers=args.rnn_layers, 108 | variable_lengths=True, 109 | rnn_type='lstm') 110 | self.parser = nn.ModuleList([PhraseAttention(input_dim=args.rnn_hidden_dim * 2) 111 | for _ in range(args.num_exp_tokens)]) 112 | 113 | def forward(self, sent): 114 | max_len = (sent != 0).sum(1).max().item() 115 | sent = sent[:, :max_len] 116 | context, hidden, embedded = self.rnn(sent) # [bs, maxL, d] 117 | sent_feature = [module(context, embedded, sent)[-1] for module in self.parser] 118 | return torch.stack(sent_feature, dim=1) 119 | 120 | 121 | def build_textual_encoder(args): 122 | return TextualEncoder(args) 123 | 124 | -------------------------------------------------------------------------------- /work/data/unified_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import cv2 5 | import torch 6 | import numpy as np 7 | import os.path as osp 8 | import torch.utils.data as data 9 | from .. 
import utils 10 | from ..utils import Corpus 11 | from ..utils.transforms import trans, trans_simple 12 | sys.path.append('.') 13 | sys.modules['utils'] = utils 14 | cv2.setNumThreads(0) 15 | 16 | class UnifiedDataset(data.Dataset): 17 | 18 | SUPPORTED_DATASETS = { 19 | 'refcoco': { 20 | 'splits': ('train', 'val', 'trainval', 'testA', 'testB'), 21 | 'params': {'dataset': 'refcoco', 'split_by': 'unc'} 22 | }, 23 | 24 | 'refcoco+': { 25 | 'splits': ('train', 'val', 'trainval', 'testA', 'testB'), 26 | 'params': {'dataset': 'refcoco+', 'split_by': 'unc'} 27 | }, 28 | 29 | 'refcocog': { 30 | 'splits': ('train', 'val'), 31 | 'params': {'dataset': 'refcocog', 'split_by': 'google'} 32 | }, 33 | 34 | 'refcocog_umd': { 35 | 'splits': ('train', 'val', 'test'), 36 | 'params': {'dataset': 'refcocog', 'split_by': 'umd'} 37 | }, 38 | 39 | 'flickr': { 40 | 'splits': ('train', 'val', 'test') 41 | }, 42 | 43 | 'copsref': { 44 | 'splits': ('train', 'val', 'test') 45 | } 46 | } 47 | 48 | # map the dataset name to data folder 49 | MAPPING = { 50 | 'refcoco': 'unc', 51 | 'refcoco+': 'unc+', 52 | 'refcocog': 'gref', 53 | 'refcocog_umd': 'gref_umd', 54 | 'flickr': 'flickr', 55 | 'copsref': 'copsref' 56 | } 57 | 58 | def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=512, transform=None, testmode=False, split='train', max_query_len=20, augment=False): 59 | self.images = [] 60 | self.data_root = data_root 61 | self.split_root = split_root 62 | self.dataset = dataset 63 | self.imsize = imsize 64 | self.query_len = max_query_len 65 | self.transform = transform 66 | self.testmode = testmode 67 | self.split = split 68 | self.trans = trans if augment else trans_simple 69 | 70 | if self.dataset == 'flickr': 71 | self.dataset_root = osp.join(self.data_root, 'Flickr30k') 72 | self.im_dir = osp.join(self.dataset_root, 'flickr30k-images') 73 | elif self.dataset == 'copsref': 74 | self.dataset_root = osp.join(self.data_root, 'copsref') 75 | self.im_dir = osp.join(self.dataset_root, 'images') 76 | else: 77 | self.dataset_root = osp.join(self.data_root, 'other') 78 | self.im_dir = osp.join( 79 | self.dataset_root, 'images', 'mscoco', 'images', 'train2014') 80 | self.split_dir = osp.join(self.dataset_root, 'splits') 81 | 82 | 83 | 84 | self.sup_set = self.dataset 85 | self.dataset = self.MAPPING[self.dataset] 86 | 87 | if not self.exists_dataset(): 88 | print('Please download index cache to data folder: \n \ 89 | https://drive.google.com/open?id=1cZI562MABLtAzM6YU4WmKPFFguuVr0lZ') 90 | exit(0) 91 | 92 | dataset_path = osp.join(self.split_root, self.dataset) 93 | valid_splits = self.SUPPORTED_DATASETS[self.sup_set]['splits'] 94 | 95 | self.corpus = Corpus() 96 | corpus_path = osp.join(dataset_path, 'corpus.pth') 97 | self.corpus = torch.load(corpus_path) 98 | 99 | if split not in valid_splits: 100 | raise ValueError( 101 | 'Dataset {0} does not have split {1}'.format( 102 | self.dataset, split)) 103 | 104 | # splits = [split] 105 | splits = ['train', 'val'] if split == 'trainval' else [split] 106 | for split in splits: 107 | imgset_file = '{0}_{1}.pth'.format(self.dataset, split) 108 | imgset_path = osp.join(dataset_path, imgset_file) 109 | self.images += torch.load(imgset_path) 110 | 111 | def exists_dataset(self): 112 | 113 | return osp.exists(osp.join(self.split_root, self.dataset)) 114 | 115 | def __len__(self): 116 | return len(self.images) 117 | 118 | def __getitem__(self, index): 119 | """ 120 | 121 | :return: (img, phrase word id, phrase word mask, bounding bbox) 122 | """ 123 | if self.dataset == 
'flickr' or self.dataset == 'copsref': 124 | img_file, bbox, phrase = self.images[index] 125 | else: 126 | img_file, _, bbox, phrase, _ = self.images[index] 127 | 128 | if not self.dataset == 'flickr': 129 | bbox = np.array(bbox, dtype=int) 130 | bbox[2], bbox[3] = bbox[0]+bbox[2], bbox[1]+bbox[3] 131 | else: 132 | bbox = np.array(bbox, dtype=int) 133 | 134 | img_path = osp.join(self.im_dir, img_file) 135 | img = cv2.imread(img_path) 136 | 137 | if img.shape[-1] > 1: 138 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 139 | else: 140 | img = np.stack([img] * 3) 141 | 142 | phrase = phrase.lower() 143 | 144 | img, phrase, bbox = self.trans(img, phrase, bbox, self.imsize) 145 | 146 | if self.transform is not None: 147 | img = self.transform(img) 148 | 149 | # tokenize phrase 150 | word_id = self.corpus.tokenize(phrase, self.query_len) 151 | word_mask = np.array(word_id > 0, dtype=int) 152 | 153 | return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), np.array(bbox, dtype=np.float32) 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Grounding with Transformers 2 | 3 | 4 | ## Overview 5 | 6 | This repository includes the PyTorch implementation and trained models of VGTR (**V**isual **G**rounding with **TR**ansformers). 7 | 8 | [[arXiv](https://arxiv.org/abs/2105.04281)] 9 | 10 | 11 | >In this paper, we propose a transformer based approach for visual grounding. Unlike existing proposal-and-rank frameworks that rely heavily on pretrained object detectors or proposal-free frameworks that upgrade an off-the-shelf one-stage detector by fusing textual embeddings, our approach is built on top of a transformer encoder-decoder and is independent of any pretrained detectors or word embedding models. Termed as VGTR – Visual Grounding with TRansformers, our approach is designed to learn semantic-discriminative visual features under the guidance of the textual description without harming their location ability. This information flow enables our VGTR to have a strong capability in capturing context-level semantics of both vision and language modalities, rendering us to aggregate accurate visual clues implied by the description to locate the interested object instance. Experiments show that our method outperforms state-of-the-art proposal-free approaches by a considerable margin on four benchmarks. 12 | 13 | (figure) 14 | 15 | 16 | ## Prerequisites 17 | 18 | - python 3.6 19 | - pytorch>=1.6.0 20 | - torchvision 21 | - CUDA>=9.0 22 | - others (opencv-python etc.) 23 | 24 | 25 | ## Preparation 26 | 27 | 1. Clone this repository. 28 | 29 | 2. Data preparation. 30 | 31 | Download Flickr30K Entities from [Flickr30k Entities (bryanplummer.com)](http://bryanplummer.com/Flickr30kEntities/) and [Flickr30K](http://shannon.cs.illinois.edu/DenotationGraph/) 32 | 33 | Download MSCOCO images from [MSCOCO](http://images.cocodataset.org/zips/train2014.zip) 34 | 35 | Download processed indexes from [Gdrive](https://drive.google.com/drive/folders/1cZI562MABLtAzM6YU4WmKPFFguuVr0lZ?usp=drive_open), processed by [zyang-ur 36 | ](https://github.com/zyang-ur/onestage_grounding). 37 | 38 | 3. Download backbone weights. We use resnet-50/101 as the basic visual encoder. 
The weights are pretrained on MSCOCO, and can be downloaded here (BaiduDrive): 39 | 40 | [ResNet-50](https://pan.baidu.com/s/1ZHR_Ew8tUZH7gZo1prJThQ)(code:ru8v); [ResNet-101](https://pan.baidu.com/s/1zsQ67cUZQ88n43-nmEjgvA)(code:0hgu). 41 | 42 | 4. Organize all files like this: 43 | 44 | ```bash 45 | . 46 | ├── main.py 47 | ├── store 48 | │   ├── data 49 | │ │ ├── flickr 50 | │ │ │   ├── corpus.pth 51 | │ │ │   └── flickr_train.pth 52 | │ │ ├── gref 53 | │ │ └── gref_umd 54 | │   ├── ln_data 55 | │ │ ├── Flickr30k 56 | │ │ │   └── flickr30k-images 57 | │ │ └── other 58 | │ │    └── images 59 | │   ├── pretrained 60 | │   │   └── flickr_R50.pth.tar 61 | │   └── pth 62 | │   └── resnet50_detr.pth 63 | └── work 64 | ``` 65 | 66 | 67 | ## Model Zoo 68 | 69 | | Dataset | Backbone | Accuracy | Pretrained Model (BaiduDrive) | 70 | | ----------------- | --------- | ------------------- | ------------------------------------------------------------ | 71 | | Flickr30K Entites | Resnet50 | 74.17 | [flickr_R50.pth.tar](https://pan.baidu.com/s/1VUnxD-5pXnM7iFwIl8q9kA) code: rpdr | 72 | | Flickr30K Entites | Resnet101 | 75.32 | [flickr_R101.pth.tar](https://pan.baidu.com/s/10GcUFLSTei9Lwvu4e5GjrQ) code: 1igb | 73 | | RefCOCO | Resnet50 | 78.70 82.09 73.31 | [refcoco_R50.pth.tar](https://pan.baidu.com/s/1GIe5OoOQOADYc1vVGcSXbw) code: xjs8 | 74 | | RefCOCO | Resnet101 | 79.30 82.16 74.38 | [refcoco_R101.pth.tar](https://pan.baidu.com/s/1GL-itH93G_e3VVNUPtocSA) code: bv0z | 75 | | RefCOCO+ | Resnet50 | 63.57 69.65 55.33 | [refcoco+_R50.pth.tar](https://pan.baidu.com/s/1PUF8WoTrOLmYU24kgAMXKQ) code: 521n | 76 | | RefCOCO+ | Resnet101 | 64.40 70.85 55.84 | [refcoco+_R101.pth.tar](https://pan.baidu.com/s/1mJiA7i7-Mp5ZL5D6dEDy0g) code: vzld | 77 | | RefCOCOg | Resnet50 | 62.88 | [refcocog_R50.pth.tar](https://pan.baidu.com/s/1KvDPisgSLzy8u5bIVCBiOg) code: wb3x | 78 | | RefCOCOg | Resnet101 | 64.05 | [refcocog_R101.pth.tar](https://pan.baidu.com/s/13ubLIbIUA3XlhzSOjaK7dg) code: 5ok2 | 79 | | RefCOCOg-umd | Resnet50 | 65.62 65.30 | [umd_R50.pth.tar](https://pan.baidu.com/s/1-PgzbA98rUOl7VJHAO-Exw) code: 9lzr | 80 | | RefCOCOg-umd | Resnet101 | 66.83 67.28 | [umd_R101.pth.tar](https://pan.baidu.com/s/1JkGbYL8Of3WOVWI9QcVwhQ) code: zen0 | 81 | 82 | 83 | ## Train 84 | 85 | ``` 86 | python main.py \ 87 | --gpu $gpu_id \ 88 | --dataset $[refcoco | refcoco+ | others] \ 89 | --batch_size $bs \ 90 | --savename $exp_name \ 91 | --backbone $[resnet50 | resnet101] \ 92 | --cnn_path $resnet_coco_weight_path 93 | ``` 94 | 95 | 96 | ## Inference 97 | 98 | Download the pretrained models and put it into the folder ```./store/pretrained/```. 99 | 100 | ``` 101 | python main.py \ 102 | --test \ 103 | --gpu $gpu_id \ 104 | --dataset $[refcoco | refcoco+ | others] \ 105 | --batch_size $bs \ 106 | --pretrain $pretrained_weight_path 107 | ``` 108 | 109 | ## Acknowledgements 110 | 111 | Part of codes are from: 112 | 113 | 1. [facebookresearch/detr](https://github.com/facebookresearch/detr); 114 | 2. [zyang-ur/onestage_grounding](https://github.com/zyang-ur/onestage_grounding); 115 | 3. [andfoy/refer](https://github.com/andfoy/refer); 116 | 4. [jadore801120/attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch). 
117 | 118 | 119 | 120 | ## Citation 121 | ``` 122 | @article{du2021visual, 123 | title={Visual grounding with transformers}, 124 | author={Du, Ye and Fu, Zehua and Liu, Qingjie and Wang, Yunhong}, 125 | journal={arXiv preprint arXiv:2105.04281}, 126 | year={2021} 127 | } 128 | 129 | @inproceedings{du2022visual, 130 | title={Visual grounding with transformers}, 131 | author={Du, Ye and Fu, Zehua and Liu, Qingjie and Wang, Yunhong}, 132 | booktitle={Proceedings of the International Conference on Multimedia and Expo}, 133 | year={2022} 134 | } 135 | ``` 136 | 137 | -------------------------------------------------------------------------------- /work/utils/parsing_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import torch 5 | import numpy as np 6 | 7 | # from plot_util import plot_confusion_matrix 8 | # from makemask import * 9 | 10 | def _fast_hist(label_true, label_pred, n_class): 11 | mask = (label_true >= 0) & (label_true < n_class) 12 | hist = np.bincount( 13 | n_class * label_true[mask].astype(int) + 14 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) 15 | return hist 16 | 17 | def label_accuracy_score(label_trues, label_preds, n_class, bg_thre=200): 18 | """Returns accuracy score evaluation result. 19 | - overall accuracy 20 | - mean accuracy 21 | - mean IU 22 | - fwavacc 23 | """ 24 | hist = np.zeros((n_class, n_class)) 25 | for lt, lp in zip(label_trues, label_preds): 26 | # hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 27 | hist += _fast_hist(lt[lt 0] * iu[freq > 0]).sum() 35 | return acc, acc_cls, mean_iu, fwavacc 36 | 37 | def label_confusion_matrix(label_trues, label_preds, n_class, bg_thre=200): 38 | # eps=1e-20 39 | hist=np.zeros((n_class,n_class),dtype=float) 40 | """ (8,256,256), (256,256) """ 41 | for lt,lp in zip(label_trues, label_preds): 42 | # hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 43 | hist += _fast_hist(lt[lt 0] * iu[freq > 0]).sum() 74 | return acc, acc_cls, mean_iu, fwavacc, iu 75 | 76 | # if __name__ == '__main__': 77 | # """ Evaluating from saved png segmentation maps 78 | # 0.862723060822 0.608076070823 0.503493670787 0.76556929118 79 | # """ 80 | # import csv 81 | # from PIL import Image 82 | # import matplotlib as mpl 83 | # mpl.use('Agg') 84 | # from matplotlib import pyplot as plt 85 | # eps=1e-20 86 | 87 | # class AverageMeter(object): 88 | # """Computes and stores the average and current value""" 89 | # def __init__(self): 90 | # self.reset() 91 | 92 | # def reset(self): 93 | # self.val = 0 94 | # self.avg = 0 95 | # self.sum = 0 96 | # self.count = 0 97 | 98 | # def update(self, val, n=1): 99 | # self.val = val 100 | # self.sum += val * n 101 | # self.count += n 102 | # self.avg = self.sum / self.count 103 | # def load_csv(csv_file): 104 | # img_list, kpt_list, conf_list=[],[],[] 105 | # with open(csv_file, 'rb') as f: 106 | # reader = csv.reader(f) 107 | # for row in reader: 108 | # img_list.append(row[0]) 109 | # kpt_list.append([row[i] for i in range(1,len(row)) if i%3!=0]) 110 | # conf_list.append([row[i] for i in range(1,len(row)) if i%3==0]) 111 | # # print len(img_list),len(kpt_list[0]),len(conf_list[0]) 112 | # return img_list,kpt_list,conf_list 113 | 114 | # n_class = 7 115 | # superpixel_smooth = False 116 | # # valfile = '../../ln_data/LIP/TrainVal_pose_annotations/lip_val_set.csv' 117 | # # pred_folder = '../../../git_code/LIP_JPPNet/output/parsing/val/' 118 | # # pred_folder = 
'../visulizations/refinenet_baseline/test_out/' 119 | # pred_folder = '../visulizations/refinenet_splittask/test_out/' 120 | # gt_folder = '../../ln_data/pascal_data/SegmentationPart/' 121 | # img_path = '../../ln_data/pascal_data/JPEGImages/' 122 | 123 | # file = '../../ln_data/pascal_data/val_id.txt' 124 | # missjoints = '../../ln_data/pascal_data/no_joint_list.txt' 125 | # img_list = [x.strip().split(' ')[0] for x in open(file)] 126 | # miss_list = [x.strip().split(' ')[0] for x in open(missjoints)] 127 | 128 | # conf_matrices = AverageMeter() 129 | # for index in range(len(img_list)): 130 | # img_name = img_list[index] 131 | # if img_name in miss_list: 132 | # continue 133 | # if not os.path.isfile(pred_folder + img_name + '.png'): 134 | # continue 135 | # pred_file = pred_folder + img_name + '.png' 136 | # pred = Image.open(pred_file) 137 | # gt_file = gt_folder + img_name + '.png' 138 | # gt = Image.open(gt_file) 139 | # pred, gt = np.array(pred, dtype=np.int32), np.array(gt, dtype=np.int32) 140 | # if superpixel_smooth: 141 | # img_file = img_path+img_name+'.jpg' 142 | # img = Image.open(img_file) 143 | # pred = superpixel_expand(np.array(img),pred) 144 | # confusion, _ = label_confusion_matrix(gt, pred, n_class) 145 | # conf_matrices.update(confusion,1) 146 | # acc, acc_cls, mean_iu, fwavacc, iu = hist_based_accu_cal(conf_matrices.avg) 147 | # print(acc, acc_cls, mean_iu, fwavacc) 148 | # print(iu) 149 | 150 | # ## SAVE CONFUSION MATRIX 151 | # figure=plt.figure() 152 | # class_name=['bg', 'head', 'torso', 'upper arm', 'lower arm', 'upper leg', 'lower leg'] 153 | # conf_matrices = conf_matrices.avg 154 | # for i in range(n_class): 155 | # conf_matrices[i,:]=(conf_matrices[i,:]+eps)/sum(conf_matrices[i,:]+eps) 156 | # plot_confusion_matrix(conf_matrices, classes=class_name, 157 | # rotation=0, include_text=True, 158 | # title='Confusion matrix, without normalization') 159 | # plt.show() 160 | # plt.savefig('../saved_models/Baseline_refinenet_test.jpg') 161 | # plt.close('all') 162 | -------------------------------------------------------------------------------- /work/model/backbone/visual_backbone.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torchvision 4 | from torch import nn 5 | from torchvision.models._utils import IntermediateLayerGetter 6 | 7 | 8 | class FrozenBatchNorm2d(torch.nn.Module): 9 | """ 10 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 11 | 12 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 13 | without which any other models than torchvision.models.resnet[18,34,50,101] 14 | produce nans. 
15 | """ 16 | 17 | def __init__(self, n): 18 | super(FrozenBatchNorm2d, self).__init__() 19 | self.register_buffer("weight", torch.ones(n)) 20 | self.register_buffer("bias", torch.zeros(n)) 21 | self.register_buffer("running_mean", torch.zeros(n)) 22 | self.register_buffer("running_var", torch.ones(n)) 23 | 24 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 25 | missing_keys, unexpected_keys, error_msgs): 26 | num_batches_tracked_key = prefix + 'num_batches_tracked' 27 | if num_batches_tracked_key in state_dict: 28 | del state_dict[num_batches_tracked_key] 29 | 30 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 31 | state_dict, prefix, local_metadata, strict, 32 | missing_keys, unexpected_keys, error_msgs) 33 | 34 | def forward(self, x): 35 | # move reshapes to the beginning 36 | # to make it fuser-friendly 37 | w = self.weight.reshape(1, -1, 1, 1) 38 | b = self.bias.reshape(1, -1, 1, 1) 39 | rv = self.running_var.reshape(1, -1, 1, 1) 40 | rm = self.running_mean.reshape(1, -1, 1, 1) 41 | eps = 1e-5 42 | scale = w * (rv + eps).rsqrt() 43 | bias = b - rm * scale 44 | return x * scale + bias 45 | 46 | 47 | class BackboneBase(nn.Module): 48 | 49 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 50 | super().__init__() 51 | for name, parameter in backbone.named_parameters(): 52 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 53 | parameter.requires_grad_(False) 54 | if return_interm_layers: 55 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 56 | else: 57 | return_layers = {'layer4': "0"} 58 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 59 | self.num_channels = num_channels 60 | 61 | def forward(self, x): 62 | 63 | out = self.body(x) 64 | return_f = [] 65 | return_f.append(out['3']) 66 | return_f.append(out['2']) 67 | return_f.append(out['1']) 68 | return_f.append(out['0']) 69 | 70 | return return_f 71 | 72 | 73 | class Backbone(BackboneBase): 74 | 75 | def __init__(self, name: str, 76 | train_backbone: bool, 77 | return_interm_layers: bool, 78 | dilation: bool, 79 | pretrain_path: str): 80 | 81 | backbone = getattr(torchvision.models, name)( 82 | replace_stride_with_dilation=[False, False, dilation], 83 | pretrained=False, norm_layer=FrozenBatchNorm2d) 84 | backbone.load_state_dict(torch.load(pretrain_path)) 85 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 86 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 87 | 88 | 89 | class Neck(nn.Module): 90 | 91 | def __init__(self, n_levels=4, channels=[2048, 1024, 512, 256], fusion_size=32, lat_channels=128, args=None): 92 | super().__init__() 93 | self.n_levels = n_levels 94 | self.lat_conv = nn.ModuleList([nn.Conv2d(i, lat_channels, 95 | kernel_size=(1, 1)) for i in channels]) 96 | self.updown_conv = nn.ModuleList([nn.Conv2d(lat_channels, lat_channels, 97 | kernel_size=(3, 3), stride=1, padding=1) 98 | for _ in range(n_levels-1)]) 99 | self.fusion_size = fusion_size 100 | n = lat_channels * n_levels 101 | stride = 2 if args.stride else 1 102 | self.post_conv = nn.Sequential( 103 | nn.Conv2d(n, 1024, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2)), # -> (64->32) 104 | nn.BatchNorm2d(1024), 105 | nn.ReLU(), 106 | nn.Conv2d(1024, 2048, kernel_size=(3, 3), padding=(1, 1), stride=(stride, stride)), # -> (32->16) 107 | nn.BatchNorm2d(2048), 108 | nn.ReLU() 109 | ) 110 | 
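# Xavier-initialize the newly added lateral/fusion convolutions of the neck;
# the ResNet backbone itself keeps the weights loaded from args.cnn_path in Backbone above.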
self._reset_parameters() 111 | 112 | def _reset_parameters(self): 113 | for p in self.parameters(): 114 | if p.dim() > 1: 115 | nn.init.xavier_uniform_(p) 116 | 117 | def upsample_add(self, feat1, feat2): 118 | _, _, H, W = feat2.size() 119 | return torch.nn.functional.interpolate(feat1, size=(H, W), mode='bilinear', 120 | align_corners=True) + feat2 121 | 122 | def forward(self, feats): 123 | assert len(feats) == self.n_levels 124 | for i in range(self.n_levels): 125 | feats[i] = self.lat_conv[i](feats[i]) 126 | Out = [] 127 | out = feats[0] 128 | out_append = torch.nn.functional.interpolate(out, 129 | size=(self.fusion_size, self.fusion_size), 130 | mode='bilinear', 131 | align_corners=True) 132 | Out.append(out_append) 133 | for i in range(1, self.n_levels): 134 | out = self.updown_conv[i-1](self.upsample_add(out, feats[i])) 135 | out_append = torch.nn.functional.interpolate(out, size=(self.fusion_size, self.fusion_size), 136 | mode='bilinear', 137 | align_corners=True) 138 | Out.append(out_append) 139 | out = torch.cat(Out, dim=1) 140 | out = self.post_conv(out) 141 | return out 142 | 143 | 144 | class VisualBackbone(nn.Module): 145 | 146 | def __init__(self, args): 147 | super().__init__() 148 | self.cnn = Backbone(args.backbone, train_backbone=True, 149 | return_interm_layers=True, 150 | dilation=args.dilation, 151 | pretrain_path=args.cnn_path) 152 | self.neck = Neck(4, [2048, 1024, 512, 256], args=args) 153 | 154 | def forward(self, img): 155 | return self.neck(self.cnn(img)) 156 | 157 | 158 | def build_visual_backbone(args): 159 | return VisualBackbone(args) -------------------------------------------------------------------------------- /work/model/vgtr/vg_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | import numpy as np 8 | 9 | 10 | class TextGuidedQ(nn.Module): 11 | def __init__(self, d_model=256, l_norm=True): 12 | super(TextGuidedQ, self).__init__() 13 | self.l_norm = l_norm 14 | if l_norm: 15 | self.norm = nn.LayerNorm(d_model) 16 | 17 | def forward(self, exp_f, Q, attn_mask=None): 18 | d_k = Q.shape[-1] 19 | scores = torch.matmul(exp_f, Q.transpose(-1, -2)) / np.sqrt(d_k) 20 | 21 | if attn_mask is not None: 22 | scores.masked_fill_(attn_mask, -1e9) 23 | 24 | attn = torch.nn.functional.softmax(scores.transpose(-1, -2), dim=-1) 25 | context = torch.matmul(attn, exp_f) 26 | 27 | if self.l_norm: 28 | return self.norm(Q + context) 29 | else: 30 | return Q + context 31 | 32 | 33 | class VGEncoder(nn.Module): 34 | 35 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 36 | dim_feedforward=2048, dropout=0.1, 37 | activation="relu"): 38 | super().__init__() 39 | 40 | self.hidden_dim = d_model 41 | self.d_model = d_model 42 | self.nhead = nhead 43 | 44 | encoder_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 45 | encoder_norm = nn.LayerNorm(d_model) 46 | encoder_norm2 = nn.LayerNorm(d_model) 47 | self.encoder = Encoder(encoder_layer, num_encoder_layers, encoder_norm, encoder_norm2) 48 | 49 | self._reset_parameters() 50 | 51 | def _reset_parameters(self): 52 | for p in self.parameters(): 53 | if p.dim() > 1: 54 | nn.init.xavier_uniform_(p) 55 | 56 | def forward(self, img_feature, pos_feature, expression_feature, word_id=None, exp_pos_feature=None): 57 | 58 | src = img_feature.flatten(2).permute(2, 0, 1) # (hw, bs, d) 59 | pos_embed = 
pos_feature.flatten(2).permute(2, 0, 1) 60 | 61 | out, expf = self.encoder(src, expression_feature, pos=pos_embed, exp_pos_feature=exp_pos_feature) 62 | out = out.transpose(0, 1) 63 | 64 | return out, expf 65 | 66 | 67 | class Encoder(nn.Module): 68 | 69 | def __init__(self, encoder_layer, num_layers, norm=None, norm2=None): 70 | super().__init__() 71 | self.layers = _get_clones(encoder_layer, num_layers) 72 | self.num_layers = num_layers 73 | self.norm = norm 74 | self.norm2 = norm2 75 | 76 | def forward(self, src, 77 | expression_feature, 78 | pos: Optional[Tensor] = None, 79 | exp_pos_feature=None): 80 | 81 | output = src 82 | exp_f = expression_feature 83 | 84 | for layer in self.layers: 85 | output, exp_f = layer(output, exp_f, pos=pos, exp_pos_feature=exp_pos_feature) 86 | if self.norm is not None: 87 | output = self.norm(output) 88 | exp_f = self.norm2(exp_f) 89 | 90 | return output, exp_f 91 | 92 | 93 | class EncoderLayer(nn.Module): 94 | 95 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 96 | activation="relu"): 97 | super().__init__() 98 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 99 | self.exp_self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 100 | self.linear1 = nn.Linear(d_model, dim_feedforward) 101 | self.dropout = nn.Dropout(dropout) 102 | self.linear2 = nn.Linear(dim_feedforward, d_model) 103 | self.norm1 = nn.LayerNorm(d_model) 104 | self.norm2 = nn.LayerNorm(d_model) 105 | self.dropout1 = nn.Dropout(dropout) 106 | self.dropout2 = nn.Dropout(dropout) 107 | self.activation = _get_activation_fn(activation) 108 | self.exp_self_norm1 = nn.LayerNorm(d_model) 109 | self.exp_self_norm2 = nn.LayerNorm(d_model) 110 | self.exp_self_dropout1 = nn.Dropout(dropout) 111 | self.expression_ffn_linear1 = nn.Linear(in_features=d_model, out_features=dim_feedforward) 112 | self.expression_ffn_dropout = nn.Dropout(dropout) 113 | self.expression_ffn_linear2 = nn.Linear(dim_feedforward, d_model) 114 | self.expression_ffn_dropout2 = nn.Dropout(dropout) 115 | self.expression_ffn_activation = _get_activation_fn(activation) 116 | self.text_guided = TextGuidedQ(d_model=d_model) 117 | 118 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 119 | return tensor if pos is None else tensor + pos 120 | 121 | def forward(self, 122 | src, 123 | expression_feature, 124 | pos: Optional[Tensor] = None, 125 | exp_pos_feature=None): 126 | 127 | # self-attn for exp feature 128 | expression_feature = expression_feature.permute(1, 0, 2) 129 | expression_feature2 = self.exp_self_norm1(expression_feature) 130 | exp_q = exp_k = self.with_pos_embed(expression_feature2, exp_pos_feature) 131 | expression_feature2 = self.exp_self_attn(exp_q, exp_k, value=expression_feature2)[0] # (maxL, bs, d) 132 | expression_feature = expression_feature + self.exp_self_dropout1(expression_feature2) 133 | expression_feature = self.exp_self_norm2(expression_feature) 134 | expression_feature = expression_feature.permute(1, 0, 2) 135 | 136 | expression_feature = expression_feature # (bs, maxL, d) 137 | 138 | # self-attn for img feature 139 | src2 = self.norm1(src) 140 | q = k = self.with_pos_embed(src2, pos) # q: (hw, bs, d) 141 | # text guided 142 | q = q.transpose(0, 1) 143 | q = self.text_guided(expression_feature, q).transpose(0, 1) 144 | 145 | src2 = self.self_attn(q, k, value=src2)[0] 146 | src = src + self.dropout1(src2) 147 | src2 = self.norm2(src) 148 | fused_vis_feature = src2 149 | fused_expression_feature = expression_feature # (bs, maxL, d) 150 | # 
FFN 151 | src2 = self.linear2(self.dropout(self.activation(self.linear1(fused_vis_feature)))) 152 | src = fused_vis_feature + self.dropout2(src2) 153 | 154 | # FFN 155 | expression_feature2 = self.expression_ffn_linear2(self.expression_ffn_dropout( 156 | self.expression_ffn_activation(self.expression_ffn_linear1(fused_expression_feature)))) 157 | expression_feature = fused_expression_feature + self.expression_ffn_dropout2(expression_feature2) 158 | 159 | return src, expression_feature 160 | 161 | # self.forward(src, expression_feature, origin_h, origin_w, 162 | # word_mask, src_mask, src_key_padding_mask, pos, exp_pos_feature) 163 | 164 | 165 | def _get_clones(module, N): 166 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 167 | 168 | 169 | def _get_activation_fn(activation): 170 | if activation == "relu": 171 | return F.relu 172 | if activation == "gelu": 173 | return F.gelu 174 | if activation == "glu": 175 | return F.glu 176 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /work/engine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import time 4 | import logging 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | from .utils.utils import AverageMeter, xywh2xyxy, bbox_iou 9 | 10 | 11 | def train_epoch(args, train_loader, model, optimizer, epoch, criterion=None, img_size=512): 12 | 13 | batch_time = AverageMeter() 14 | losses = AverageMeter() 15 | 16 | losses_bbox = AverageMeter() 17 | losses_giou = AverageMeter() 18 | 19 | acc = AverageMeter() 20 | miou = AverageMeter() 21 | 22 | model.train() 23 | end = time.time() 24 | 25 | for batch_idx, (imgs, word_id, word_mask, bbox) in enumerate(train_loader): 26 | imgs = imgs.cuda() 27 | word_id = word_id.cuda() 28 | bbox = bbox.cuda() 29 | bbox = torch.clamp(bbox, min=0, max=args.size - 1) 30 | image = Variable(imgs) 31 | word_id = Variable(word_id) 32 | bbox = Variable(bbox) 33 | 34 | norm_bbox = torch.zeros_like(bbox).cuda() 35 | 36 | norm_bbox[:, 0] = (bbox[:, 0] + bbox[:, 2]) / 2.0 # x_center 37 | norm_bbox[:, 1] = (bbox[:, 1] + bbox[:, 3]) / 2.0 # y_center 38 | norm_bbox[:, 2] = bbox[:, 2] - bbox[:, 0] # w 39 | norm_bbox[:, 3] = bbox[:, 3] - bbox[:, 1] # h 40 | 41 | # forward 42 | pred_box = model(image, word_id) # [bs, C, H, W] 43 | loss, loss_box, loss_giou = criterion(pred_box, norm_bbox, img_size=img_size) 44 | 45 | optimizer.zero_grad() 46 | loss.backward() 47 | optimizer.step() 48 | 49 | # pred-box 50 | pred_bbox = pred_box.detach().cpu() 51 | pred_bbox = pred_bbox * img_size 52 | pred_box = xywh2xyxy(pred_bbox) 53 | 54 | losses.update(loss.item(), imgs.size(0)) 55 | losses_bbox.update(loss_box.item(), imgs.size(0)) 56 | losses_giou.update(loss_giou.item(), imgs.size(0)) 57 | 58 | target_bbox = bbox 59 | iou = bbox_iou(pred_box, target_bbox.data.cpu(), x1y1x2y2=True) 60 | accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / args.batch_size 61 | 62 | # metrics 63 | miou.update(torch.mean(iou).item(), imgs.size(0)) 64 | acc.update(accu, imgs.size(0)) 65 | 66 | batch_time.update(time.time() - end) 67 | end = time.time() 68 | 69 | if (batch_idx+1) % args.print_freq == 0: 70 | print_str = 'Epoch: [{0}][{1}/{2}]\t' \ 71 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \ 72 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \ 73 | 'Loss_bbox {loss_box.val:.4f} ({loss_box.avg:.4f})\t' \ 74 | 
'Loss_giou {loss_giou.val:.4f} ({loss_giou.avg:.4f})\t' \ 75 | 'Accu {acc.val:.4f} ({acc.avg:.4f})\t' \ 76 | 'Mean_iu {miou.val:.4f} ({miou.avg:.4f})\t' \ 77 | .format(epoch+1, batch_idx+1, len(train_loader), 78 | batch_time=batch_time, 79 | loss=losses, 80 | loss_box=losses_bbox, 81 | loss_giou=losses_giou, 82 | acc=acc, 83 | miou=miou) 84 | 85 | print(print_str) 86 | logging.info(print_str) 87 | 88 | 89 | def validate_epoch(args, val_loader, model, train_epoch, img_size=512): 90 | 91 | batch_time = AverageMeter() 92 | acc = AverageMeter() 93 | miou = AverageMeter() 94 | 95 | model.eval() 96 | end = time.time() 97 | 98 | for batch_idx, (imgs, word_id, word_mask, bbox) in enumerate(val_loader): 99 | imgs = imgs.cuda() 100 | word_id = word_id.cuda() 101 | bbox = bbox.cuda() 102 | image = Variable(imgs) 103 | word_id = Variable(word_id) 104 | bbox = Variable(bbox) 105 | bbox = torch.clamp(bbox, min=0, max=args.size-1) 106 | 107 | norm_bbox = torch.zeros_like(bbox).cuda() 108 | 109 | norm_bbox[:, 0] = (bbox[:, 0] + bbox[:, 2]) / 2.0 # x_center 110 | norm_bbox[:, 1] = (bbox[:, 1] + bbox[:, 3]) / 2.0 # y_center 111 | norm_bbox[:, 2] = bbox[:, 2] - bbox[:, 0] # w 112 | norm_bbox[:, 3] = bbox[:, 3] - bbox[:, 1] # h 113 | 114 | with torch.no_grad(): 115 | pred_box = model(image, word_id) # [bs, C, H, W] 116 | 117 | pred_bbox = pred_box.detach().cpu() 118 | pred_bbox = pred_bbox * img_size 119 | pred_bbox = xywh2xyxy(pred_bbox) 120 | 121 | # constrain 122 | pred_bbox[pred_bbox < 0.0] = 0.0 123 | pred_bbox[pred_bbox > img_size-1] = img_size-1 124 | 125 | target_bbox = bbox 126 | # metrics 127 | iou = bbox_iou(pred_bbox, target_bbox.data.cpu(), x1y1x2y2=True) 128 | # accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / args.batch_size 129 | accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / imgs.size(0) 130 | 131 | acc.update(accu, imgs.size(0)) 132 | miou.update(torch.mean(iou).item(), imgs.size(0)) 133 | 134 | batch_time.update(time.time() - end) 135 | end = time.time() 136 | 137 | if (batch_idx+1) % (args.print_freq//10) == 0: 138 | print_str = 'Validate: [{0}/{1}]\t' \ 139 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' \ 140 | 'Acc {acc.val:.4f} ({acc.avg:.4f}) ' \ 141 | 'Mean_iu {miou.val:.4f} ({miou.avg:.4f}) ' \ 142 | .format(batch_idx+1, len(val_loader), batch_time=batch_time, acc=acc, miou=miou) 143 | 144 | print(print_str) 145 | logging.info(print_str) 146 | 147 | print(f"Train_epoch {train_epoch+1} Validate Result: Acc {acc.avg}, MIoU {miou.avg}.") 148 | 149 | logging.info("Validate: %f, %f" % (acc.avg, float(miou.avg))) 150 | 151 | return acc.avg, miou.avg 152 | 153 | def test_epoch(test_loader, model, img_size=512): 154 | 155 | acc = AverageMeter() 156 | miou = AverageMeter() 157 | model.eval() 158 | 159 | for batch_idx, (imgs, word_id, word_mask, bbox) in enumerate(test_loader): 160 | imgs = imgs.cuda() 161 | word_id = word_id.cuda() 162 | bbox = bbox.cuda() 163 | image = Variable(imgs) 164 | word_id = Variable(word_id) 165 | bbox = Variable(bbox) 166 | bbox = torch.clamp(bbox, min=0, max=img_size-1) 167 | 168 | norm_bbox = torch.zeros_like(bbox).cuda() 169 | 170 | norm_bbox[:, 0] = (bbox[:, 0] + bbox[:, 2]) / 2.0 # x_center 171 | norm_bbox[:, 1] = (bbox[:, 1] + bbox[:, 3]) / 2.0 # y_center 172 | norm_bbox[:, 2] = bbox[:, 2] - bbox[:, 0] # w 173 | norm_bbox[:, 3] = bbox[:, 3] - bbox[:, 1] # h 174 | 175 | with torch.no_grad(): 176 | pred_box = model(image, word_id) # [bs, C, H, W] 177 | 178 | pred_bbox = pred_box.detach().cpu() 179 | pred_bbox 
= pred_bbox * img_size 180 | pred_bbox = xywh2xyxy(pred_bbox) 181 | 182 | # constrain 183 | pred_bbox[pred_bbox < 0.0] = 0.0 184 | pred_bbox[pred_bbox > img_size-1] = img_size-1 185 | 186 | target_bbox = bbox 187 | # metrics 188 | iou = bbox_iou(pred_bbox, target_bbox.data.cpu(), x1y1x2y2=True) 189 | accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / imgs.size(0) 190 | 191 | acc.update(accu, imgs.size(0)) 192 | miou.update(torch.mean(iou).item(), imgs.size(0)) 193 | 194 | print(f"Test Result: Acc {acc.avg}, MIoU {miou.avg}.") 195 | -------------------------------------------------------------------------------- /work/model/criterion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import math 8 | 9 | class Criterion(nn.Module): 10 | def __init__(self, args): 11 | super(Criterion, self).__init__() 12 | self.loss_weight = [3, 1] 13 | self.MSELoss = torch.nn.MSELoss(reduction='none') 14 | def forward(self, pred, gt, img_size=256): 15 | """` 16 | :param pred: (bs, 4) 17 | :param gt: (bs, 4) 18 | :return: 19 | """ 20 | bs = pred.shape[0] 21 | gt = gt / img_size 22 | 23 | loss_bbox = F.l1_loss(pred, gt, reduction='none') 24 | loss_bbox = loss_bbox.sum() / bs 25 | 26 | loss_giou = 1 - torch.diag(self.generalized_box_iou( 27 | self.box_cxcywh_to_xyxy(pred), 28 | self.box_cxcywh_to_xyxy(gt))) 29 | 30 | loss_giou = loss_giou.sum() / bs 31 | loss = 5 * loss_bbox + loss_giou * 2 32 | return loss, 5 * loss_bbox, loss_giou * 2 33 | 34 | 35 | def box_loss(self, pred_box, gt_box, type='L2'): 36 | """ 37 | :param pred_box: (bs, 4) (center_x, center_y, h, w) not normalized for L2 loss 38 | :param gt_box: (center_x, center_y, h, w) normalized for L1 loss 39 | :return: 40 | """ 41 | # loss_box = torch.tensor(0.).cuda() 42 | if type == 'L1': 43 | loss_bbox = F.l1_loss(pred_box, gt_box, reduction='none') # element-wise L1 loss 44 | elif type == 'L2': 45 | loss_bbox = self.MSELoss(pred_box, gt_box) 46 | else: 47 | raise NotImplementedError('Not Implemented Loss type') 48 | loss = loss_bbox.sum() / pred_box.shape[0] 49 | return loss 50 | 51 | def diou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): 52 | ''' 53 | https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/loss/multibox_loss.py 54 | :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] 55 | :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] 56 | :param eps: eps to avoid divide 0 57 | :param reduction: mean or sum 58 | :return: diou-loss 59 | ''' 60 | ix1 = torch.max(preds[:, 0], bbox[:, 0]) 61 | iy1 = torch.max(preds[:, 1], bbox[:, 1]) 62 | ix2 = torch.min(preds[:, 2], bbox[:, 2]) 63 | iy2 = torch.min(preds[:, 3], bbox[:, 3]) 64 | 65 | iw = (ix2 - ix1 + 1.0).clamp(min=0.) 66 | ih = (iy2 - iy1 + 1.0).clamp(min=0.) 
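        # Note: the "+ 1.0" terms above (and in the union below) treat box coordinates as
        # inclusive pixel indices when measuring widths and heights. The remaining steps
        # implement the DIoU penalty: IoU minus the squared center distance (inter_diag)
        # divided by the squared diagonal of the smallest enclosing box (outer_diag).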
67 | 68 | # overlaps 69 | inters = iw * ih 70 | 71 | # union 72 | uni = (preds[:, 2] - preds[:, 0] + 1.0) * (preds[:, 3] - preds[:, 1] + 1.0) + ( 73 | bbox[:, 2] - bbox[:, 0] + 1.0) * ( 74 | bbox[:, 3] - bbox[:, 1] + 1.0) - inters 75 | 76 | # iou 77 | iou = inters / (uni + eps) 78 | 79 | # inter_diag 80 | cxpreds = (preds[:, 2] + preds[:, 0]) / 2 81 | cypreds = (preds[:, 3] + preds[:, 1]) / 2 82 | 83 | cxbbox = (bbox[:, 2] + bbox[:, 0]) / 2 84 | cybbox = (bbox[:, 3] + bbox[:, 1]) / 2 85 | 86 | inter_diag = (cxbbox - cxpreds) ** 2 + (cybbox - cypreds) ** 2 87 | 88 | # outer_diag 89 | ox1 = torch.min(preds[:, 0], bbox[:, 0]) 90 | oy1 = torch.min(preds[:, 1], bbox[:, 1]) 91 | ox2 = torch.max(preds[:, 2], bbox[:, 2]) 92 | oy2 = torch.max(preds[:, 3], bbox[:, 3]) 93 | 94 | outer_diag = (ox1 - ox2) ** 2 + (oy1 - oy2) ** 2 95 | 96 | diou = iou - inter_diag / outer_diag 97 | diou = torch.clamp(diou, min=-1.0, max=1.0) 98 | 99 | diou_loss = 1 - diou 100 | 101 | if reduction == 'mean': 102 | loss = torch.mean(diou_loss) 103 | elif reduction == 'sum': 104 | loss = torch.sum(diou_loss) 105 | else: 106 | raise NotImplementedError 107 | return loss 108 | 109 | def ciou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): 110 | ''' 111 | https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/loss/multibox_loss.py 112 | :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] 113 | :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] 114 | :param eps: eps to avoid divide 0 115 | :param reduction: mean or sum 116 | :return: diou-loss 117 | ''' 118 | ix1 = torch.max(preds[:, 0], bbox[:, 0]) 119 | iy1 = torch.max(preds[:, 1], bbox[:, 1]) 120 | ix2 = torch.min(preds[:, 2], bbox[:, 2]) 121 | iy2 = torch.min(preds[:, 3], bbox[:, 3]) 122 | 123 | iw = (ix2 - ix1 + 1.0).clamp(min=0.) 124 | ih = (iy2 - iy1 + 1.0).clamp(min=0.) 
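        # Same setup as diou_loss; after the DIoU term, ciou_loss additionally adds an
        # aspect-ratio consistency penalty v = (4 / pi^2) * (atan(w_gt / h_gt) - atan(w / h))^2,
        # weighted by alpha = v / (1 - IoU + v), giving CIoU = IoU - d^2 / c^2 - alpha * v.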
125 | 126 | # overlaps 127 | inters = iw * ih 128 | 129 | # union 130 | uni = (preds[:, 2] - preds[:, 0] + 1.0) * (preds[:, 3] - preds[:, 1] + 1.0) + ( 131 | bbox[:, 2] - bbox[:, 0] + 1.0) * ( 132 | bbox[:, 3] - bbox[:, 1] + 1.0) - inters 133 | 134 | # iou 135 | iou = inters / (uni + eps) 136 | 137 | # inter_diag 138 | cxpreds = (preds[:, 2] + preds[:, 0]) / 2 139 | cypreds = (preds[:, 3] + preds[:, 1]) / 2 140 | 141 | cxbbox = (bbox[:, 2] + bbox[:, 0]) / 2 142 | cybbox = (bbox[:, 3] + bbox[:, 1]) / 2 143 | 144 | inter_diag = (cxbbox - cxpreds) ** 2 + (cybbox - cypreds) ** 2 145 | 146 | # outer_diag 147 | ox1 = torch.min(preds[:, 0], bbox[:, 0]) 148 | oy1 = torch.min(preds[:, 1], bbox[:, 1]) 149 | ox2 = torch.max(preds[:, 2], bbox[:, 2]) 150 | oy2 = torch.max(preds[:, 3], bbox[:, 3]) 151 | 152 | outer_diag = (ox1 - ox2) ** 2 + (oy1 - oy2) ** 2 153 | 154 | diou = iou - inter_diag / outer_diag 155 | 156 | # calculate v,alpha 157 | wbbox = bbox[:, 2] - bbox[:, 0] + 1.0 158 | hbbox = bbox[:, 3] - bbox[:, 1] + 1.0 159 | wpreds = preds[:, 2] - preds[:, 0] + 1.0 160 | hpreds = preds[:, 3] - preds[:, 1] + 1.0 161 | v = torch.pow((torch.atan(wbbox / hbbox) - torch.atan(wpreds / hpreds)), 2) * (4 / (math.pi ** 2)) 162 | alpha = v / (1 - iou + v) 163 | ciou = diou - alpha * v 164 | ciou = torch.clamp(ciou, min=-1.0, max=1.0) 165 | 166 | ciou_loss = 1 - ciou 167 | if reduction == 'mean': 168 | loss = torch.mean(ciou_loss) 169 | elif reduction == 'sum': 170 | loss = torch.sum(ciou_loss) 171 | else: 172 | raise NotImplementedError 173 | return loss 174 | 175 | def generalized_box_iou(self, boxes1, boxes2): 176 | """ 177 | Generalized IoU from https://giou.stanford.edu/ 178 | 179 | The boxes should be in [x0, y0, x1, y1] format 180 | 181 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 182 | and M = len(boxes2) 183 | """ 184 | # degenerate boxes gives inf / nan results 185 | # so do an early check 186 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 187 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 188 | iou, union = self.box_iou(boxes1, boxes2) 189 | 190 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 191 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 192 | 193 | wh = (rb - lt).clamp(min=0) # [N,M,2] 194 | area = wh[:, :, 0] * wh[:, :, 1] 195 | 196 | return iou - (area - union) / area 197 | 198 | def box_iou(self, boxes1, boxes2): 199 | area1 = self.box_area(boxes1) 200 | area2 = self.box_area(boxes2) 201 | 202 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 203 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 204 | 205 | wh = (rb - lt).clamp(min=0) # [N,M,2] 206 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 207 | 208 | union = area1[:, None] + area2 - inter 209 | 210 | iou = inter / union 211 | return iou, union 212 | 213 | def box_area(self, boxes): 214 | """ 215 | Computes the area of a set of bounding boxes, which are specified by its 216 | (x1, y1, x2, y2) coordinates. 217 | 218 | Arguments: 219 | boxes (Tensor[N, 4]): boxes for which the area will be computed. 
They 220 | are expected to be in (x1, y1, x2, y2) format 221 | 222 | Returns: 223 | area (Tensor[N]): area for each box 224 | """ 225 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 226 | 227 | def box_cxcywh_to_xyxy(self, x): 228 | x_c, y_c, w, h = x.unbind(-1) 229 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 230 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 231 | return torch.stack(b, dim=-1) 232 | 233 | def giou_loss(self, pred_box, gt_box): 234 | 235 | loss_giou = 1 - torch.diag(self.generalized_box_iou( 236 | self.box_cxcywh_to_xyxy(pred_box), self.box_cxcywh_to_xyxy(gt_box))) 237 | 238 | return loss_giou.sum() / pred_box.shape[0] 239 | 240 | def focal_loss(self, pred, gt, down_sample=32): 241 | ''' Modified focal loss. Exactly the same as CornerNet. 242 | Runs faster and costs a little bit more memory 243 | Arguments: 244 | pred (batch x c x h x w) [batch_size, c, h, w] 245 | gt: [batch_size, ] 246 | ''' 247 | pred = pred[:, 0, :, :].unsqueeze(1) 248 | gt, down_sample_center = self.gaussian_smooth(pred, gt, down_sample=down_sample) 249 | gt = gt.unsqueeze(1) 250 | pos_inds = gt.eq(1).float() 251 | neg_inds = gt.lt(1).float() 252 | 253 | neg_weights = torch.pow(1 - gt, 4) 254 | 255 | loss = torch.tensor(0.).cuda() 256 | 257 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 258 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds 259 | 260 | num_pos = pos_inds.float().sum() 261 | pos_loss = pos_loss.sum() 262 | neg_loss = neg_loss.sum() 263 | 264 | if num_pos == 0: 265 | loss = loss - neg_loss 266 | else: 267 | loss = loss - (pos_loss + neg_loss) / num_pos 268 | return loss, down_sample_center 269 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | import argparse 5 | import matplotlib as mpl 6 | 7 | import torch.nn.parallel 8 | import torch.optim 9 | import torch.utils.data.distributed 10 | # import torch.backends.cudnn as cudnn 11 | 12 | from work.utils.utils import * 13 | from work.model.criterion import Criterion 14 | from work.model.grounding_model import GroundingModel 15 | from work.engine import train_epoch, validate_epoch, test_epoch 16 | from work.data.get_dataloader import get_train_loader, get_val_loader, get_test_loader 17 | import warnings 18 | mpl.use('Agg') 19 | warnings.filterwarnings('ignore') 20 | 21 | 22 | def getargs(): 23 | 24 | parser = argparse.ArgumentParser( 25 | description='Dataloader test') 26 | parser.add_argument('--num_exp_tokens', default=4, type=int, 27 | help='num of expression tokens of exp feature') 28 | parser.add_argument('--rnn_layers', default=2, type=int, help='num of lstm layers') 29 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 30 | parser.add_argument('--lr_backbone', default=1e-5, type=float, help='learning rate') 31 | parser.add_argument('--hidden_dim', default=256, type=int, 32 | help="Size of the embeddings (dimension of the transformer)") # d_model 33 | parser.add_argument('--size', default=512, type=int, help='image size') 34 | parser.add_argument('--gpu', help='gpu id, split by , ') 35 | parser.add_argument('--workers', default=4, type=int, 36 | help='num workers for data loading') 37 | parser.add_argument('--nb_epoch', default=120, type=int, 38 | help='training epoch') 39 | parser.add_argument('--backbone', default='resnet50', type=str, 40 | help="Name of the convolutional backbone 
to use") 41 | parser.add_argument('--dilation', action='store_true', 42 | help="If true, we replace stride with dilation in the last CNN convolutional block") 43 | parser.add_argument('--stride', action='store_true', 44 | help="If true, we replace stride with dilation in the last CNN convolutional block") 45 | parser.add_argument('--dataset', default='refcoco', type=str, 46 | help='refcoco/refcoco+/refcocog/refcocog_umd/flickr/copsref') 47 | parser.add_argument('--enc_layers', default=2, type=int, 48 | help="Number of encoding layers in the transformer") 49 | parser.add_argument('--dec_layers', default=2, type=int, 50 | help="Number of decoding layers in the transformer") 51 | parser.add_argument('--dim_feedforward', default=2048, type=int, 52 | help="size of the feedforward layers") 53 | parser.add_argument('--embedding_dim', default=1024, type=int) 54 | parser.add_argument('--rnn_hidden_dim', default=128, type=int) 55 | parser.add_argument('--max_query_len', default=20, type=int, 56 | help="max query len") 57 | parser.add_argument('--dropout', default=0.1, type=float, 58 | help="Dropout applied in the transformer") 59 | parser.add_argument('--nheads', default=8, type=int, 60 | help="Number of attention heads inside the transformer's attentions") 61 | parser.add_argument('--batch_size', default=96, type=int) 62 | parser.add_argument('--weight_decay', default=1e-5, type=float) 63 | parser.add_argument('--clip_max_norm', default=40, type=float, 64 | help='gradient clipping max norm') 65 | parser.add_argument('--data_root', type=str, default='./store/ln_data/', 66 | help='path to ReferIt splits data folder') 67 | parser.add_argument('--split_root', type=str, default='./store/data/', 68 | help='location of pre-parsed dataset info') 69 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 70 | help='checkpoint') 71 | parser.add_argument('--pretrain', default='', type=str, metavar='PATH', 72 | help='pretrain weight') 73 | parser.add_argument('--optimizer', default='adamW', 74 | help='optimizer: awamW, sgd, adam, RMSprop') 75 | parser.add_argument('--savepath', default='store', type=str, 76 | help='save dir for model/logs') 77 | parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', 78 | help='print frequency (default: 1e3)') 79 | parser.add_argument('--savename', default='default', type=str, 80 | help='Name head for saved model') 81 | parser.add_argument('--test', dest='test', default=False, action='store_true', 82 | help='test mode') 83 | parser.add_argument('--split', default='test', type=str, 84 | help='split subset for test') 85 | parser.add_argument('--cnn_path', default='store/pth/resnet50_detr.pth', type=str, 86 | help='pretrained cnn weights') 87 | args = parser.parse_args() 88 | 89 | # refcoco/refcoco+ 90 | args.split = 'testA' if args.dataset == 'refcoco' or args.dataset == 'refcoco+' else 'test' 91 | # refcocog 92 | args.split = 'val' if args.dataset == 'refcocog' else 'test' 93 | 94 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 95 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 96 | cudnn.benchmark = False 97 | cudnn.deterministic = True 98 | 99 | return args 100 | 101 | 102 | def train(args): 103 | 104 | # log 105 | if args.savename == 'default': 106 | args.savename = f'model_{args.dataset}_batch_{args.batch_size}' 107 | 108 | log_path = f'{args.savepath}/logs/{args.savename}' 109 | if not os.path.exists(log_path): 110 | os.makedirs(log_path) 111 | 112 | logging.basicConfig(level=logging.DEBUG, 113 | 
filename=f"{log_path}/log_{args.dataset}", 114 | filemode="a+", 115 | format="%(asctime)-15s %(levelname)-8s %(message)s") 116 | logging.info(args) 117 | 118 | # Dataset 119 | train_loader = get_train_loader(args) 120 | val_loader = get_val_loader(args) 121 | 122 | # model 123 | model = GroundingModel(args) 124 | model = torch.nn.DataParallel(model).cuda() 125 | logging.info(model) 126 | 127 | if args.pretrain: 128 | if os.path.isfile(args.pretrain): 129 | pretrained_dict = torch.load(args.pretrain)['state_dict'] 130 | model_dict = model.state_dict() 131 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 132 | assert (len([k for k, v in pretrained_dict.items()]) != 0) 133 | model_dict.update(pretrained_dict) 134 | model.load_state_dict(model_dict) 135 | print("=> loaded pretrain model at {}".format(args.pretrain)) 136 | logging.info("=> loaded pretrain model at {}".format(args.pretrain)) 137 | else: 138 | print(("=> no pretrained file found at '{}'".format(args.pretrain))) 139 | logging.info("=> no pretrained file found at '{}'".format(args.pretrain)) 140 | elif args.resume: 141 | if os.path.isfile(args.resume): 142 | print(("=> loading checkpoint '{}'".format(args.resume))) 143 | logging.info("=> loading checkpoint '{}'".format(args.resume)) 144 | checkpoint = torch.load(args.resume) 145 | args.start_epoch = checkpoint['epoch'] 146 | model.load_state_dict(checkpoint['state_dict']) 147 | print(("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))) 148 | logging.info("=> loaded checkpoint (epoch {})".format(checkpoint['epoch'])) 149 | else: 150 | print(("=> no checkpoint found at '{}'".format(args.resume))) 151 | logging.info(("=> no checkpoint found at '{}'".format(args.resume))) 152 | 153 | # optimizer 154 | optimizer = get_optimizer(args, model) 155 | 156 | # get criterion 157 | criterion = Criterion(args) 158 | best_accu = -float('Inf') 159 | 160 | # train 161 | for epoch in range(args.nb_epoch): 162 | adjust_learning_rate(optimizer, epoch, optimizer.param_groups[0]['lr']) 163 | model.train() 164 | train_epoch(args, train_loader, model, optimizer, epoch, criterion, args.size) 165 | model.eval() 166 | accu_new, miou_new = validate_epoch(args, val_loader, model, epoch, args.size) 167 | 168 | is_best = accu_new > best_accu 169 | best_accu = max(accu_new, best_accu) 170 | # save the pth 171 | save_checkpoint(args, 172 | {'epoch': epoch + 1, 173 | 'state_dict': model.state_dict(), 174 | 'acc': accu_new, 175 | 'optimizer': optimizer.state_dict()}, 176 | is_best, 177 | epoch + 1, 178 | filename=args.savename) 179 | 180 | print(f'Best Acc: {best_accu}.') 181 | 182 | 183 | def test(args): 184 | 185 | # Dataset 186 | if args.batch_size != 1: 187 | warnings.warn('metrics may not correct!', Warning) 188 | 189 | test_loader = get_test_loader(args, split=args.split) 190 | 191 | # model 192 | model = GroundingModel(args) 193 | model = torch.nn.DataParallel(model).cuda() 194 | 195 | assert args.pretrain is not None 196 | if os.path.isfile(args.pretrain): 197 | pretrained_dict = torch.load(args.pretrain)['state_dict'] 198 | model_dict = model.state_dict() 199 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 200 | assert (len([k for k, v in pretrained_dict.items()]) != 0) 201 | model_dict.update(pretrained_dict) 202 | model.load_state_dict(model_dict) 203 | print("=> loaded pretrain model at {}".format(args.pretrain)) 204 | logging.info("=> loaded pretrain model at {}".format(args.pretrain)) 205 | else: 206 | print(("=> no 
pretrained file found at '{}'".format(args.pretrain))) 207 | logging.info("=> no pretrained file found at '{}'".format(args.pretrain)) 208 | 209 | model.eval() 210 | test_epoch(test_loader, model, args.size) 211 | 212 | 213 | if __name__ == "__main__": 214 | 215 | args = getargs() 216 | if args.test: 217 | print('Starting Test....') 218 | test(args) 219 | else: 220 | print('Starting Training....') 221 | train(args) 222 | 223 | -------------------------------------------------------------------------------- /work/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import random 6 | import shutil 7 | import numpy as np 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | 11 | 12 | class AverageMeter(object): 13 | """Computes and stores the average and current value""" 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0. 19 | self.avg = 0. 20 | self.sum = 0. 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def xyxy2xywh(x): # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h] 31 | y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape) 32 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 33 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 34 | y[:, 2] = x[:, 2] - x[:, 0] 35 | y[:, 3] = x[:, 3] - x[:, 1] 36 | return y 37 | 38 | 39 | def xywh2xyxy(x): # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2] 40 | y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape) 41 | y[:, 0] = (x[:, 0] - x[:, 2] / 2) 42 | y[:, 1] = (x[:, 1] - x[:, 3] / 2) 43 | y[:, 2] = (x[:, 0] + x[:, 2] / 2) 44 | y[:, 3] = (x[:, 1] + x[:, 3] / 2) 45 | return y 46 | 47 | 48 | def bbox_iou_numpy(box1, box2): 49 | """Computes IoU between bounding boxes. 
50 | Parameters 51 | ---------- 52 | box1 : ndarray 53 | (N, 4) shaped array with bboxes 54 | box2 : ndarray 55 | (M, 4) shaped array with bboxes 56 | Returns 57 | ------- 58 | : ndarray 59 | (N, M) shaped array with IoUs 60 | """ 61 | area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) 62 | 63 | iw = np.minimum(np.expand_dims(box1[:, 2], axis=1), box2[:, 2]) - np.maximum( 64 | np.expand_dims(box1[:, 0], 1), box2[:, 0] 65 | ) 66 | ih = np.minimum(np.expand_dims(box1[:, 3], axis=1), box2[:, 3]) - np.maximum( 67 | np.expand_dims(box1[:, 1], 1), box2[:, 1] 68 | ) 69 | 70 | iw = np.maximum(iw, 0) 71 | ih = np.maximum(ih, 0) 72 | 73 | ua = np.expand_dims((box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]), axis=1) + area - iw * ih 74 | 75 | ua = np.maximum(ua, np.finfo(float).eps) 76 | 77 | intersection = iw * ih 78 | 79 | return intersection / ua 80 | 81 | 82 | def bbox_iou(box1, box2, x1y1x2y2=True): 83 | """ 84 | Returns the IoU of two bounding boxes 85 | """ 86 | if x1y1x2y2: 87 | # Get the coordinates of bounding boxes 88 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 89 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 90 | else: 91 | # Transform from center and width to exact coordinates 92 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 93 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 94 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 95 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 96 | 97 | # get the coordinates of the intersection rectangle 98 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 99 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 100 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 101 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 102 | # Intersection area 103 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, 0) * torch.clamp(inter_rect_y2 - inter_rect_y1, 0) 104 | # Union Area 105 | b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) 106 | b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) 107 | 108 | # print(box1, box1.shape) 109 | # print(box2, box2.shape) 110 | return inter_area / (b1_area + b2_area - inter_area + 1e-16) 111 | 112 | 113 | def multiclass_metrics(pred, gt): 114 | """ 115 | check precision and recall for predictions. 116 | Output: overall = {precision, recall, f1} 117 | """ 118 | eps=1e-6 119 | overall = {'precision': -1, 'recall': -1, 'f1': -1} 120 | NP, NR, NC = 0, 0, 0 # num of pred, num of recall, num of correct 121 | for ii in range(pred.shape[0]): 122 | pred_ind = np.array(pred[ii]>0.5, dtype=int) 123 | gt_ind = np.array(gt[ii]>0.5, dtype=int) 124 | inter = pred_ind * gt_ind 125 | # add to overall 126 | NC += np.sum(inter) 127 | NP += np.sum(pred_ind) 128 | NR += np.sum(gt_ind) 129 | if NP > 0: 130 | overall['precision'] = float(NC)/NP 131 | if NR > 0: 132 | overall['recall'] = float(NC)/NR 133 | if NP > 0 and NR > 0: 134 | overall['f1'] = 2*overall['precision']*overall['recall']/(overall['precision']+overall['recall']+eps) 135 | return overall 136 | 137 | 138 | def compute_ap(recall, precision): 139 | """ Compute the average precision, given the recall and precision curves. 140 | Code originally from https://github.com/rbgirshick/py-faster-rcnn. 141 | # Arguments 142 | recall: The recall curve (list). 143 | precision: The precision curve (list). 144 | # Returns 145 | The average precision as computed in py-faster-rcnn. 
146 | """ 147 | # correct AP calculation 148 | # first append sentinel values at the end 149 | mrec = np.concatenate(([0.0], recall, [1.0])) 150 | mpre = np.concatenate(([0.0], precision, [0.0])) 151 | 152 | # compute the precision envelope 153 | for i in range(mpre.size - 1, 0, -1): 154 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 155 | 156 | # to calculate area under PR curve, look for points 157 | # where X axis (recall) changes value 158 | i = np.where(mrec[1:] != mrec[:-1])[0] 159 | 160 | # and sum (\Delta recall) * prec 161 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 162 | return ap 163 | 164 | 165 | def save_segmentation_map(iou, phrase, bbox, target_bbox, input, mode, batch_start_index, \ 166 | merge_pred=None, pred_conf_visu=None, save_path='./visulizations_refcoco/'): 167 | n = input.shape[0] 168 | save_path=save_path+mode 169 | 170 | input=input.data.cpu().numpy() 171 | input=input.transpose(0,2,3,1) 172 | save_txt_path = save_path + 'phrase' 173 | for ii in range(n): 174 | os.system('mkdir -p %s/'%(save_path)) 175 | os.system('mkdir -p %s/' % (save_txt_path)) 176 | imgs = input[ii,:,:,:].copy() 177 | org_imgs = input[ii,:,:,:].copy() 178 | imgs = (imgs*np.array([0.299, 0.224, 0.225])+np.array([0.485, 0.456, 0.406]))*255. 179 | # imgs = imgs.transpose(2,0,1) 180 | imgs = np.array(imgs, dtype=np.float32) 181 | imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR) 182 | org_imgs = (org_imgs*np.array([0.299, 0.224, 0.225])+np.array([0.485, 0.456, 0.406]))*255. 183 | org_imgs = np.array(org_imgs, dtype=np.float32) 184 | org_imgs = cv2.cvtColor(org_imgs, cv2.COLOR_RGB2BGR) 185 | 186 | cv2.rectangle(imgs, (bbox[ii,0], bbox[ii,1]), (bbox[ii,2], bbox[ii,3]), (255,0,0), 4) 187 | cv2.rectangle(imgs, (target_bbox[ii,0], target_bbox[ii,1]), 188 | (target_bbox[ii,2], target_bbox[ii,3]), (0,255,0), 4) 189 | 190 | cv2.imwrite('%s/pred_gt_%s.png'%(save_path,batch_start_index+ii),imgs) 191 | cv2.imwrite('%s/pred_gt_org_%s.png' % (save_path, batch_start_index + ii), org_imgs) 192 | 193 | with open(os.path.join(save_txt_path, 'phrase_' + str(batch_start_index+ii) + '.txt'), 'w') as f: 194 | f.write(phrase[ii]) 195 | f.write('\n') 196 | f.write(str(iou[ii])) 197 | f.write('\n') 198 | pred1 = str(bbox[ii,0]) + ',' + str(bbox[ii,1]) + ',' + str(bbox[ii,2]) + ',' + str(bbox[ii,3]) 199 | f.write(pred1) 200 | f.write('\n') 201 | gt = str(target_bbox[ii, 0]) + ',' + \ 202 | str(target_bbox[ii, 1]) + ',' + \ 203 | str(target_bbox[ii, 2]) + ',' + str(target_bbox[ii, 3]) 204 | f.write(gt) 205 | 206 | 207 | def lr_poly(base_lr, iter, max_iter, power): 208 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 209 | 210 | 211 | def adjust_learning_rate(optimizer, epoch, lr, drop_list=(70, 100)): 212 | if epoch in drop_list: 213 | lr = lr * 0.1 214 | optimizer.param_groups[0]['lr'] = lr 215 | if len(optimizer.param_groups) > 1: 216 | optimizer.param_groups[1]['lr'] = lr * 0.1 217 | else: 218 | return 219 | 220 | 221 | def save_checkpoint(args, state, is_best, epoch, filename='default'): 222 | if filename == 'default': 223 | filename = f'model_{args.dataset}_batch_{args.batch_size}' 224 | 225 | model_path = f'{args.savepath}/model/{filename}' 226 | if not os.path.exists(model_path): 227 | os.makedirs(model_path) 228 | 229 | # checkpoint_name = f'{model_path}/model_{args.dataset}_Epoch_{epoch}_checkpoint.pth.tar' 230 | checkpoint_name = f'{model_path}/model_{args.dataset}_checkpoint.pth.tar' 231 | best_name = f'{model_path}/model_{args.dataset}_best.pth.tar' 232 | torch.save(state, checkpoint_name) 233 
| # pass 234 | if is_best: 235 | shutil.copyfile(checkpoint_name, best_name) 236 | 237 | 238 | 239 | def get_optimizer(args, model): 240 | # optimizer 241 | if args.optimizer == 'adam': 242 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 243 | elif args.optimizer == 'sgd': 244 | visu_param = model.module.visual_encoder.cnn.parameters() # set lr=1e-5 for CNN backbone 245 | rest_param = [param for param in model.parameters() if param not in visu_param] 246 | visu_param = list(model.module.visual_encoder.cnn.parameters()) 247 | optimizer = torch.optim.SGD([{'params': rest_param}, 248 | {'params': visu_param, 'lr': args.lr / 10.}], 249 | lr=args.lr, 250 | momentum=0.9, 251 | weight_decay=args.weight_decay) 252 | 253 | elif args.optimizer == 'adamW': 254 | visu_param = model.module.visual_encoder.cnn.parameters() # set lr=1e-5 for CNN backbone 255 | rest_param = [param for param in model.parameters() if param not in visu_param] 256 | visu_param = list(model.module.visual_encoder.cnn.parameters()) 257 | optimizer = torch.optim.AdamW([{'params': rest_param}, 258 | {'params': visu_param, 'lr': args.lr / 10.}], 259 | lr=args.lr, 260 | weight_decay=args.weight_decay) 261 | 262 | elif args.optimizer == 'RMSprop': 263 | visu_param = model.module.visual_encoder.cnn.parameters() # set lr=1e-5 for CNN backbone 264 | rest_param = [param for param in model.parameters() if param not in visu_param] 265 | visu_param = list(model.module.visual_encoder.cnn.parameters()) 266 | optimizer = torch.optim.RMSprop([{'params': rest_param}, 267 | {'params': visu_param, 'lr': args.lr/10.}], 268 | lr=args.lr, 269 | weight_decay=args.weight_decay) 270 | 271 | else: 272 | raise NotImplementedError('Not Implemented Optimizer') 273 | 274 | return optimizer -------------------------------------------------------------------------------- /work/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import cv2 4 | import math 5 | import random 6 | import numpy as np 7 | from PIL import Image 8 | from collections import Iterable 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torchvision.transforms as transforms 12 | 13 | 14 | def reshape(img, bbox, height): 15 | shape = img.shape[:2] 16 | color = (123.7, 116.3, 103.5) 17 | ratio = float(height) / max(shape) 18 | new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) 19 | dw = (height - new_shape[0]) / 2 # width padding 20 | dh = (height - new_shape[1]) / 2 # height padding 21 | top, bottom = round(dh - 0.1), round(dh + 0.1) 22 | left, right = round(dw - 0.1), round(dw + 0.1) 23 | img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border 24 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded square 25 | bbox[0], bbox[2] = bbox[0] * ratio + dw, bbox[2] * ratio + dw 26 | bbox[1], bbox[3] = bbox[1] * ratio + dh, bbox[3] * ratio + dh 27 | 28 | return img, bbox 29 | 30 | 31 | def horizontal_flip(img, phrase, bbox): 32 | w = img.shape[1] 33 | img = cv2.flip(img, 1) 34 | bbox[0], bbox[2] = w - bbox[2] - 1, w - bbox[0] - 1 35 | phrase = phrase.replace('right', '*&^special^&*').replace('left', 'right').replace('*&^special^&*', 'left') 36 | 37 | return img, phrase, bbox 38 | 39 | 40 | def random_affine(img, mask, targets, degrees=(-10, 10), translate=(.1, .1), 41 | scale=(.8, 1.2), shear=(-2, 2), 42 | borderValue=(123.7, 116.3, 103.5), 
all_bbox=None): 43 | border = 0 # width of added border (optional) 44 | height = max(img.shape[0], img.shape[1]) + border * 2 45 | 46 | # Rotation and Scale 47 | R = np.eye(3) 48 | a = random.random() * (degrees[1] - degrees[0]) + degrees[0] 49 | # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations 50 | s = random.random() * (scale[1] - scale[0]) + scale[0] 51 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) 52 | 53 | # Translation 54 | T = np.eye(3) 55 | T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border # x translation (pixels) 56 | T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border # y translation (pixels) 57 | 58 | # Shear 59 | S = np.eye(3) 60 | S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg) 61 | S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg) 62 | 63 | M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!! 64 | imw = cv2.warpPerspective(img, M, dsize=(height, height), flags=cv2.INTER_LINEAR, 65 | borderValue=borderValue) # BGR order borderValue 66 | if mask is not None: 67 | maskw = cv2.warpPerspective(mask, M, dsize=(height, height), flags=cv2.INTER_NEAREST, 68 | borderValue=255) # BGR order borderValue 69 | else: 70 | maskw = None 71 | 72 | # Return warped points also 73 | if type(targets)==type([1]): 74 | targetlist=[] 75 | for bbox in targets: 76 | targetlist.append(wrap_points(bbox, M, height, a)) 77 | return imw, maskw, targetlist, M 78 | elif all_bbox is not None: 79 | targets = wrap_points(targets, M, height, a) 80 | for ii in range(all_bbox.shape[0]): 81 | all_bbox[ii,:] = wrap_points(all_bbox[ii,:], M, height, a) 82 | return imw, maskw, targets, all_bbox, M 83 | elif targets is not None: ## previous main 84 | targets = wrap_points(targets, M, height, a) 85 | return imw, maskw, targets, M 86 | else: 87 | return imw 88 | 89 | 90 | def affine(img, bbox, degrees=(-15, 15), translate=(.15, .15), 91 | scale=(.75, 1.25), shear=(-2, 2), 92 | borderValue=(123.7, 116.3, 103.5)): 93 | border = 0 # width of added border (optional) 94 | height = max(img.shape[0], img.shape[1]) + border * 2 95 | 96 | # Rotation and Scale 97 | R = np.eye(3) 98 | a = random.random() * (degrees[1] - degrees[0]) + degrees[0] 99 | # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations 100 | s = random.random() * (scale[1] - scale[0]) + scale[0] 101 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) 102 | 103 | # Translation 104 | T = np.eye(3) 105 | T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border # x translation (pixels) 106 | T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border # y translation (pixels) 107 | 108 | # Shear 109 | S = np.eye(3) 110 | S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg) 111 | S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg) 112 | 113 | M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!! 
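    # Points are mapped as p' = M @ p (see wrap_points, which computes xy @ M.T), so the
    # rotation/scale R is applied first, then the translation T, then the shear S.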
114 | imw = cv2.warpPerspective(img, M, dsize=(height, height), flags=cv2.INTER_LINEAR, 115 | borderValue=borderValue) # BGR order borderValue 116 | 117 | # Return warped points also 118 | if type(bbox)==type([1]): 119 | targetlist=[] 120 | for box in bbox: 121 | targetlist.append(wrap_points(box, M, height, a)) 122 | return imw, targetlist 123 | elif bbox is not None: 124 | targets = wrap_points(bbox, M, height, a) 125 | return imw, targets 126 | else: 127 | return imw 128 | 129 | 130 | def generate_transM(img, degrees=(-15, 15), translate=(.15, .15), 131 | scale=(.75, 1.25), shear=(-2, 2)): 132 | # Rotation and Scale 133 | R = np.eye(3) 134 | a = random.random() * (degrees[1] - degrees[0]) + degrees[0] 135 | # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations 136 | s = random.random() * (scale[1] - scale[0]) + scale[0] 137 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) 138 | 139 | # Translation 140 | T = np.eye(3) 141 | T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] # x translation (pixels) 142 | T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] # y translation (pixels) 143 | 144 | # Shear 145 | S = np.eye(3) 146 | S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg) 147 | S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg) 148 | 149 | M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!! 150 | 151 | return M 152 | 153 | def colorjitter(img): 154 | color_aug = transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.08) 155 | img = color_aug(img) 156 | return img 157 | 158 | 159 | def gauss(img): 160 | scale = 3 161 | sigma = 0.3 * ((scale - 1) * 0.5 - 1) + 0.8 162 | # follow cv2's default routine 163 | if random.random() > 0.5: 164 | cv2.GaussianBlur(img, ksize=(scale, scale), sigmaX=sigma, dst=img) 165 | 166 | return img 167 | 168 | 169 | def wrap_points(targets, M, height, a): 170 | # n = targets.shape[0] 171 | # points = targets[:, 1:5].copy() 172 | points = targets.copy() 173 | area0 = (points[2] - points[0]) * (points[3] - points[1]) 174 | 175 | # warp points 176 | xy = np.ones((4, 3)) 177 | xy[:, :2] = points[[0, 1, 2, 3, 0, 3, 2, 1]].reshape(4, 2) # x1y1, x2y2, x1y2, x2y1 178 | xy = (xy @ M.T)[:, :2].reshape(1, 8) 179 | 180 | # create new boxes 181 | x = xy[:, [0, 2, 4, 6]] 182 | y = xy[:, [1, 3, 5, 7]] 183 | xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, 1).T 184 | 185 | # apply angle-based reduction 186 | radians = a * math.pi / 180 187 | reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5 188 | x = (xy[:, 2] + xy[:, 0]) / 2 189 | y = (xy[:, 3] + xy[:, 1]) / 2 190 | w = (xy[:, 2] - xy[:, 0]) * reduction 191 | h = (xy[:, 3] - xy[:, 1]) * reduction 192 | xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, 1).T 193 | 194 | # reject warped points outside of image 195 | np.clip(xy, 0, height, out=xy) 196 | w = xy[:, 2] - xy[:, 0] 197 | h = xy[:, 3] - xy[:, 1] 198 | area = w * h 199 | ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) 200 | i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10) 201 | 202 | ## print(targets, xy) 203 | ## [ 56 36 108 210] [[ 47.80464857 15.6096533 106.30993434 196.71267693]] 204 | # targets = targets[i] 205 | # targets[:, 1:5] = xy[i] 206 | targets = xy[0] 207 | return targets 208 | 209 | 210 | def trans(img, phrase, 
bbox, imsize): 211 | 212 | img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV) 213 | S = img_hsv[:, :, 1].astype(np.float32) 214 | V = img_hsv[:, :, 2].astype(np.float32) 215 | a = (random.random() * 2 - 1) * 0.5 + 1 216 | # S = S * a 217 | if a >= 1: 218 | np.clip(S, a_min=0, a_max=255, out=S) 219 | a = (random.random() * 2 - 1) * 0.5 + 1 220 | V = V * a 221 | if a >= 1: 222 | np.clip(V, a_min=0, a_max=255, out=V) 223 | img_hsv[:, :, 1] = S.astype(np.uint8) 224 | img_hsv[:, :, 2] = V.astype(np.uint8) 225 | img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB) 226 | img = Image.fromarray(img) 227 | img = colorjitter(img) 228 | img = gauss(np.array(img)) 229 | img, bbox = reshape(img, bbox, imsize) 230 | img, _, bbox, M = random_affine(img, None, bbox, degrees=(-15, 15), translate=(0.15, 0.15), scale=(0.75, 1.25)) 231 | if random.random() > 0.5: 232 | img, phrase, bbox = horizontal_flip(img, phrase, bbox) 233 | 234 | return img, phrase, bbox 235 | 236 | 237 | def trans_simple(img, phrase, bbox, imsize): 238 | 239 | img, bbox = reshape(img, bbox, imsize) 240 | return img, phrase, bbox 241 | 242 | 243 | class ResizePad: 244 | 245 | def __init__(self, size): 246 | if not isinstance(size, (int, Iterable)): 247 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 248 | 249 | self.h, self.w = size 250 | 251 | def __call__(self, img): 252 | h, w = img.shape[:2] 253 | scale = min(self.h / h, self.w / w) 254 | resized_h = int(np.round(h * scale)) 255 | resized_w = int(np.round(w * scale)) 256 | pad_h = int(np.floor(self.h - resized_h) / 2) 257 | pad_w = int(np.floor(self.w - resized_w) / 2) 258 | 259 | resized_img = cv2.resize(img, (resized_w, resized_h)) 260 | 261 | # if img.ndim > 2: 262 | if img.ndim > 2: 263 | new_img = np.zeros( 264 | (self.h, self.w, img.shape[-1]), dtype=resized_img.dtype) 265 | else: 266 | resized_img = np.expand_dims(resized_img, -1) 267 | new_img = np.zeros((self.h, self.w, 1), dtype=resized_img.dtype) 268 | new_img[pad_h: pad_h + resized_h, 269 | pad_w: pad_w + resized_w, ...] 
= resized_img 270 | return new_img 271 | 272 | 273 | class CropResize: 274 | 275 | def __call__(self, img, size): 276 | if not isinstance(size, (int, Iterable)): 277 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 278 | im_h, im_w = img.data.shape[:2] 279 | input_h, input_w = size 280 | scale = max(input_h / im_h, input_w / im_w) 281 | # scale = torch.Tensor([[input_h / im_h, input_w / im_w]]).max() 282 | resized_h = int(np.round(im_h * scale)) 283 | # resized_h = torch.round(im_h * scale) 284 | resized_w = int(np.round(im_w * scale)) 285 | # resized_w = torch.round(im_w * scale) 286 | crop_h = int(np.floor(resized_h - input_h) / 2) 287 | # crop_h = torch.floor(resized_h - input_h) // 2 288 | crop_w = int(np.floor(resized_w - input_w) / 2) 289 | # crop_w = torch.floor(resized_w - input_w) // 2 290 | # resized_img = cv2.resize(img, (resized_w, resized_h)) 291 | resized_img = F.upsample( 292 | img.unsqueeze(0).unsqueeze(0), size=(resized_h, resized_w), 293 | mode='bilinear') 294 | 295 | resized_img = resized_img.squeeze().unsqueeze(0) 296 | 297 | return resized_img[0, crop_h: crop_h + input_h, 298 | crop_w: crop_w + input_w] 299 | 300 | 301 | class ToNumpy: 302 | 303 | def __call__(self, x): 304 | return x.numpy() 305 | 306 | 307 | class ResizeImage: 308 | """Resize the largest of the sides of the image to a given size""" 309 | def __init__(self, size): 310 | if not isinstance(size, (int, Iterable)): 311 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 312 | 313 | self.size = size 314 | 315 | def __call__(self, img): 316 | im_h, im_w = img.shape[-2:] 317 | scale = min(self.size / im_h, self.size / im_w) 318 | resized_h = int(np.round(im_h * scale)) 319 | resized_w = int(np.round(im_w * scale)) 320 | out = F.upsample( 321 | Variable(img).unsqueeze(0), size=(resized_h, resized_w), 322 | mode='bilinear').squeeze().data 323 | return out 324 | 325 | 326 | class ResizeAnnotation: 327 | """Resize the largest of the sides of the annotation to a given size""" 328 | def __init__(self, size): 329 | if not isinstance(size, (int, Iterable)): 330 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 331 | 332 | self.size = size 333 | 334 | def __call__(self, img): 335 | im_h, im_w = img.shape[-2:] 336 | scale = min(self.size / im_h, self.size / im_w) 337 | resized_h = int(np.round(im_h * scale)) 338 | resized_w = int(np.round(im_w * scale)) 339 | out = F.upsample( 340 | Variable(img).unsqueeze(0).unsqueeze(0), 341 | size=(resized_h, resized_w), 342 | mode='bilinear').squeeze().data 343 | return out 344 | 345 | -------------------------------------------------------------------------------- /work/model/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 9 | 'wide_resnet50_2', 'wide_resnet101_2'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 
'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=dilation, groups=groups, bias=False, dilation=dilation) 29 | 30 | 31 | def conv1x1(in_planes, out_planes, stride=1): 32 | """1x1 convolution""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | 79 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 80 | base_width=64, dilation=1, norm_layer=None): 81 | super(Bottleneck, self).__init__() 82 | if norm_layer is None: 83 | norm_layer = nn.BatchNorm2d 84 | width = int(planes * (base_width / 64.)) * groups 85 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 86 | self.conv1 = conv1x1(inplanes, width) 87 | self.bn1 = norm_layer(width) 88 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 89 | self.bn2 = norm_layer(width) 90 | self.conv3 = conv1x1(width, planes * self.expansion) 91 | self.bn3 = norm_layer(planes * self.expansion) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | identity = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | identity = self.downsample(x) 112 | 113 | out += identity 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class ResNet(nn.Module): 120 | 121 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 122 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 123 | norm_layer=None): 124 | super(ResNet, self).__init__() 125 | if norm_layer is 
None: 126 | norm_layer = nn.BatchNorm2d 127 | self._norm_layer = norm_layer 128 | 129 | self.inplanes = 64 130 | self.dilation = 1 131 | if replace_stride_with_dilation is None: 132 | # each element in the tuple indicates if we should replace 133 | # the 2x2 stride with a dilated convolution instead 134 | replace_stride_with_dilation = [False, False, False] 135 | if len(replace_stride_with_dilation) != 3: 136 | raise ValueError("replace_stride_with_dilation should be None " 137 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 138 | self.groups = groups 139 | self.base_width = width_per_group 140 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 141 | bias=False) 142 | self.bn1 = norm_layer(self.inplanes) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 145 | self.layer1 = self._make_layer(block, 64, layers[0]) 146 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 147 | dilate=replace_stride_with_dilation[0]) 148 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 149 | dilate=replace_stride_with_dilation[1]) 150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 151 | dilate=replace_stride_with_dilation[2]) 152 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | # Zero-initialize the last BN in each residual branch, 163 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
164 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 165 | if zero_init_residual: 166 | for m in self.modules(): 167 | if isinstance(m, Bottleneck): 168 | nn.init.constant_(m.bn3.weight, 0) 169 | elif isinstance(m, BasicBlock): 170 | nn.init.constant_(m.bn2.weight, 0) 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 173 | norm_layer = self._norm_layer 174 | downsample = None 175 | previous_dilation = self.dilation 176 | if dilate: 177 | self.dilation *= stride 178 | stride = 1 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | conv1x1(self.inplanes, planes * block.expansion, stride), 182 | norm_layer(planes * block.expansion), 183 | ) 184 | 185 | layers = [] 186 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 187 | self.base_width, previous_dilation, norm_layer)) 188 | self.inplanes = planes * block.expansion 189 | for _ in range(1, blocks): 190 | layers.append(block(self.inplanes, planes, groups=self.groups, 191 | base_width=self.base_width, dilation=self.dilation, 192 | norm_layer=norm_layer)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | x = self.conv1(x) 198 | x = self.bn1(x) 199 | x = self.relu(x) 200 | x = self.maxpool(x) 201 | 202 | x1 = self.layer1(x) 203 | x2 = self.layer2(x1) 204 | x3 = self.layer3(x2) 205 | x4 = self.layer4(x3) 206 | 207 | # x = self.avgpool(x) 208 | # x = torch.flatten(x, 1) 209 | # x = self.fc(x) 210 | # 返回从上到下 211 | return x4, x3, x2, x1 212 | 213 | 214 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 215 | model = ResNet(block, layers, **kwargs) 216 | if pretrained: 217 | state_dict = load_state_dict_from_url(model_urls[arch], 218 | progress=progress) 219 | model.load_state_dict(state_dict) 220 | return model 221 | 222 | 223 | def resnet18(pretrained=False, progress=True, **kwargs): 224 | r"""ResNet-18 model from 225 | `"Deep Residual Learning for Image Recognition" `_ 226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | progress (bool): If True, displays a progress bar of the download to stderr 230 | """ 231 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 232 | **kwargs) 233 | 234 | 235 | def resnet34(pretrained=False, progress=True, **kwargs): 236 | r"""ResNet-34 model from 237 | `"Deep Residual Learning for Image Recognition" `_ 238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet50(pretrained=False, progress=True, **kwargs): 248 | r"""ResNet-50 model from 249 | `"Deep Residual Learning for Image Recognition" `_ 250 | 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet101(pretrained=False, progress=True, **kwargs): 260 | r"""ResNet-101 model from 261 | `"Deep Residual Learning for Image Recognition" `_ 262 | 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 
| """ 267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet152(pretrained=False, progress=True, **kwargs): 272 | r"""ResNet-152 model from 273 | `"Deep Residual Learning for Image Recognition" `_ 274 | 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | progress (bool): If True, displays a progress bar of the download to stderr 278 | """ 279 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 280 | **kwargs) 281 | 282 | 283 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 284 | r"""ResNeXt-50 32x4d model from 285 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 286 | 287 | Args: 288 | pretrained (bool): If True, returns a model pre-trained on ImageNet 289 | progress (bool): If True, displays a progress bar of the download to stderr 290 | """ 291 | kwargs['groups'] = 32 292 | kwargs['width_per_group'] = 4 293 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 294 | pretrained, progress, **kwargs) 295 | 296 | 297 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 298 | r"""ResNeXt-101 32x8d model from 299 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 300 | 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | progress (bool): If True, displays a progress bar of the download to stderr 304 | """ 305 | kwargs['groups'] = 32 306 | kwargs['width_per_group'] = 8 307 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 308 | pretrained, progress, **kwargs) 309 | 310 | 311 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 312 | r"""Wide ResNet-50-2 model from 313 | `"Wide Residual Networks" `_ 314 | 315 | The model is the same as ResNet except for the bottleneck number of channels 316 | which is twice larger in every block. The number of channels in outer 1x1 317 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 318 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | kwargs['width_per_group'] = 64 * 2 325 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 326 | pretrained, progress, **kwargs) 327 | 328 | 329 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 330 | r"""Wide ResNet-101-2 model from 331 | `"Wide Residual Networks" `_ 332 | 333 | The model is the same as ResNet except for the bottleneck number of channels 334 | which is twice larger in every block. The number of channels in outer 1x1 335 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 336 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 337 | 338 | Args: 339 | pretrained (bool): If True, returns a model pre-trained on ImageNet 340 | progress (bool): If True, displays a progress bar of the download to stderr 341 | """ 342 | kwargs['width_per_group'] = 64 * 2 343 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 344 | pretrained, progress, **kwargs) 345 | --------------------------------------------------------------------------------
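The snippet below is a minimal smoke-test sketch, not a file from the repository: it assumes the work package above is importable, with torch and cv2 installed and a Python version old enough for the collections.Iterable import in work/utils/transforms.py, and it uses random tensors, so it only exercises the wiring. It shows the box conventions used end to end: the pixel (x1, y1, x2, y2) ground truth is converted to center format exactly as in work/engine.py, Criterion normalizes it by img_size and combines a weighted L1 and GIoU term, and predictions are mapped back through xywh2xyxy and bbox_iou for the accuracy metric.

import torch

from work.model.criterion import Criterion
from work.utils.utils import xywh2xyxy, bbox_iou

bs, img_size = 4, 512
criterion = Criterion(args=None)  # args is not used inside Criterion.__init__

# fake model output: normalized (cx, cy, w, h) in [0, 1]
pred = torch.rand(bs, 4)

# ground truth in pixel (x1, y1, x2, y2), converted to center format as in engine.train_epoch
gt_xyxy = torch.tensor([[50., 60., 200., 220.]]).repeat(bs, 1)
gt_cxcywh = torch.zeros_like(gt_xyxy)
gt_cxcywh[:, 0] = (gt_xyxy[:, 0] + gt_xyxy[:, 2]) / 2.0
gt_cxcywh[:, 1] = (gt_xyxy[:, 1] + gt_xyxy[:, 3]) / 2.0
gt_cxcywh[:, 2] = gt_xyxy[:, 2] - gt_xyxy[:, 0]
gt_cxcywh[:, 3] = gt_xyxy[:, 3] - gt_xyxy[:, 1]

# Criterion divides gt by img_size and returns (total, 5 * L1, 2 * GIoU)
loss, loss_l1, loss_giou = criterion(pred, gt_cxcywh, img_size=img_size)

# evaluation path: scale back to pixels, convert to corners, clamp, then IoU
pred_xyxy = xywh2xyxy(pred.detach() * img_size).clamp(0., img_size - 1.)
iou = bbox_iou(pred_xyxy, gt_xyxy, x1y1x2y2=True)
print(loss.item(), loss_l1.item(), loss_giou.item(), iou.mean().item())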