├── README.md ├── config.py ├── data ├── __init__.py ├── dataset.py ├── mincTestImages.txt └── mincTrainImages.txt ├── finetune.py ├── last.py └── models ├── BasicModule.py ├── FBC.py ├── SC.py ├── __init__.py └── biDDSCnips.py /README.md: -------------------------------------------------------------------------------- 1 | This repository is the implementation of AAAI 2020 paper: "Revisiting Bilinear Pooling: A Coding Perspective". 2 | 3 | Prerequisites 4 | ------- 5 | Our code requires PyTorch v1.0 and Python 2. 6 | 7 | Download the MINC dataset, and put it into the 'data' folder. 8 | 9 | Download the pretrained vgg-16 model, put it into the 'data' folder, and we name it as 'vgg16-397923af.pth' 10 | 11 | Training our model includes two steps. 12 | ------- 13 | 14 | Step 1: 15 | 16 | We train the new added layers. 17 | ``` 18 | python last.py -MINC True -data_path 'data/minc-2500/' -train_txt_path 'data/mincTrainImages.txt' -test_txt_path 'data/mincTestImages.txt' -rank 1 -k 2048 -beta 0.001 -pre_model_path 'data/vgg16-397923af.pth' -save_low_bound 99 19 | ``` 20 | 21 | 22 | Step 2: 23 | 24 | We train the whole network. 25 | ``` 26 | python finetune.py -MINC True -data_path 'data/minc-2500/' -train_txt_path 'data/mincTrainImages.txt' -test_txt_path 'data/mincTestImages.txt' -rank 1 -k 2048 -beta 0.001 -model_path 'data/vgg16-397923af.pth' 27 | ``` 28 | 29 | You can modify 'config.py' to set more detailed hyper-parameters. 
If you find this code helpful, please consider citing our paper:
-------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import torch 4 | import numpy as np 5 | from torchvision import transforms,utils 6 | from torch.utils.data import Dataset,DataLoader 7 | from PIL import Image 8 | from config import opt 9 | 10 | def default_loader(path): 11 | #print(path) 12 | return Image.open(path).convert('RGB') 13 | 14 | class MyDataset(Dataset): 15 | def __init__(self, txt, transform, loader=default_loader): 16 | fh = open(txt, 'r') 17 | imgs = [] 18 | for line in fh: 19 | line = line.strip('\n') 20 | line = line.rstrip() 21 | words = line.split() 22 | if len(words) == 2: 23 | imgs.append((words[0], int(words[1]))) 24 | elif len(words) == 3: 25 | tmp = words[0] + ' ' + words[1] 26 | imgs.append((tmp, int(words[-1]))) 27 | 28 | self.imgs = imgs 29 | self.transform = transform 30 | self.loader = loader 31 | 32 | def __getitem__(self, index): 33 | fn, label = self.imgs[index] 34 | data_pth = opt.data_path 35 | #img = self.loader('/data/guijun/caffe-20160312/examples/compact_bilinear/cub/images/' + fn) 36 | img = self.loader(data_pth + fn) 37 | img = self.transform(img) 38 | 39 | return img, label 40 | 41 | def __len__(self): 42 | return len(self.imgs) 43 | 44 | 45 | 46 | 47 | 48 | # import os, sys 49 | # sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) 50 | # from PIL import Image 51 | # from torch.utils import data 52 | # import numpy as np 53 | # import torch 54 | # #from torchvision import transforms as T 55 | # import sys, os 56 | # from config import opt 57 | # import matplotlib.pyplot as plt 58 | # import six.moves.cPickle as pickle 59 | # import random 60 | # import warnings 61 | # warnings.filterwarnings('ignore') 62 | 63 | # class RGBD(data.Dataset): 64 | 65 | # def __init__(self, root, train=False, val=False, test=False, split=opt.split): 66 | # ''' 67 | # 
get the data and split them into train, val and test subset; 68 | # ''' 69 | # print('split:', split) 70 | # self.train = train 71 | # self.val = val 72 | # self.test = test 73 | 74 | # imgs = [] 75 | # #stage = '1' if train else '2' if val else '3' if test else '4' 76 | # #val -> train ,discard the val 77 | # stage = '1' if train else '2' if val else '3' if test else '4' 78 | # print('stage: ',('train' if stage=='1' else 'val' 79 | # if stage=='2' else 'test' if stage=='3' else 'UNKNOWNSTAGE')) 80 | # #print('csvroot:',root) 81 | # f = open(root, 'r') 82 | 83 | # f.readline() 84 | # lines = f.readlines() 85 | # for line in lines: 86 | # contents = line.strip('\n\r').split(',') 87 | # #print('contents',contents) 88 | # #img_name cat_label ins_label ins_set split 0-9 89 | 90 | # if stage == '1':#train 91 | # if contents[4+split] == '1': 92 | # if int(contents[0].split('_')[-2])%opt.train_n_th != 1: 93 | # continue 94 | # #print(int(contents[0].split('_')[-2])) 95 | # item = {'rgb_path' : contents[0], 96 | # 'depth_path' : contents[0][:-8] + 'depthcrop.png', 97 | # 'mask_path' : contents[0][:-8] + 'maskcrop.png', 98 | # 'cat_label' : int(contents[1]), 99 | # } 100 | # imgs.append(item) 101 | # elif stage == '3':#test 102 | # if contents[4+split] == '3' : 103 | # if int(contents[0].split('_')[-2])%opt.test_n_th != 1: 104 | # continue 105 | # item = {'rgb_path' : contents[0], 106 | # 'depth_path' : contents[0][:-8] + 'depthcrop.png', 107 | # 'mask_path' : contents[0][:-8] + 'maskcrop.png', 108 | # 'cat_label' : int(contents[1]), 109 | # } 110 | # imgs.append(item) 111 | 112 | 113 | 114 | # self.imgs = imgs 115 | 116 | 117 | # def __getitem__(self, index): 118 | # ''' 119 | # return a picture according to the given index once time; 120 | # ''' 121 | # rgb_path = '/'.join(opt.csv_rgb256_path.split('/')[:-1]) + '/' + self.imgs[index]['rgb_path'] 122 | # rgb_data = Image.open(rgb_path) 123 | # if self.train: 124 | # rgb_flip_random = random.random() 125 | # if 
rgb_flip_random > 0.5: 126 | # rgb_data = rgb_data.transpose(Image.FLIP_LEFT_RIGHT) 127 | # with open(opt.rgb_mean_path, 'rb') as f: 128 | # rgb_mean = pickle.load(f) 129 | # rgb_mean = torch.from_numpy(rgb_mean) 130 | # rgb_data = np.asarray(rgb_data) 131 | # rgb_data = np.transpose(rgb_data ,(2, 0, 1)) 132 | # rgb_data = torch.from_numpy(rgb_data).double() 133 | # rgb_data = (rgb_data-rgb_mean) 134 | 135 | # rgb_data = rgb_data/255.0 136 | # left = random.randint(0, opt.resolutionPlus-opt.resolution) 137 | # top = random.randint(0, opt.resolutionPlus-opt.resolution) 138 | # rgb_data = rgb_data[:,left:left+opt.resolution,top:top+opt.resolution].float() 139 | 140 | 141 | # depth_path = '/'.join(opt.csv_depth256_path.split('/')[:-1]) + '/' + self.imgs[index]['depth_path'] 142 | # depth_data = Image.open(depth_path) 143 | # if self.train: 144 | # depth_flip_random = random.random() 145 | # if depth_flip_random > 0.5: 146 | # depth_data = depth_data.transpose(Image.FLIP_LEFT_RIGHT) 147 | # with open(opt.depth_mean_path, 'rb') as f: 148 | # depth_mean = pickle.load(f) 149 | # depth_mean = torch.from_numpy(depth_mean) 150 | # depth_data = np.asarray(depth_data) 151 | # depth_data = np.transpose(depth_data ,(2, 0, 1)) 152 | # depth_data = torch.from_numpy(depth_data).double() 153 | # depth_data = (depth_data - depth_mean) 154 | # depth_data = depth_data/255.0 155 | # left = random.randint(0, opt.resolutionPlus-opt.resolution) 156 | # top = random.randint(0, opt.resolutionPlus-opt.resolution) 157 | 158 | # #depth_data = torch.FloatTensor(depth_data[:,left:left+opt.resolution,top:top+opt.resolution]) 159 | # depth_data = depth_data[:,left:left+opt.resolution,top:top+opt.resolution].float() 160 | 161 | # #print(depth_data) 162 | # label = self.imgs[index]['cat_label'] 163 | # return (rgb_data, depth_data), label 164 | 165 | 166 | 167 | # def __len__(self): 168 | # return len(self.imgs) 169 | 170 | # if __name__ == '__main__': 171 | # root = opt.csv_path 172 | # # for i in 
range(10): 173 | # # rgbd = RGBD(root,train=True,split=i) 174 | # # print('train len:',len(rgbd.imgs)) 175 | # # rgbd = RGBD(root,test=True,split=i) 176 | # # print('test len:',len(rgbd.imgs)) 177 | # rgbd = RGBD(root,train=True,split=0) 178 | # print('train len:',len(rgbd.imgs)) 179 | # a = rgbd.imgs[:10] 180 | # # for i in a: 181 | # # print(i) 182 | # #print(label) 183 | # #print(len(rgbd)) -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import torch 3 | import torchvision 4 | import torch.optim as optim 5 | import torchvision.transforms as transforms 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import data 9 | from collections import OrderedDict 10 | import os 11 | import torch.backends.cudnn as cudnn 12 | import math 13 | from config import opt 14 | import time 15 | from data.dataset import MyDataset 16 | from models import FBC 17 | import numpy as np 18 | import scipy.io as sio 19 | from PIL import Image 20 | import random 21 | import argparse 22 | import sys 23 | 24 | acc_list = [0.0] 25 | loss_list = [0.0] 26 | criterion = nn.CrossEntropyLoss() 27 | 28 | def train(epoch, lr): 29 | print('model_name_pre:',args.model_name_pre) 30 | print('bs',args.train_bs) 31 | print('SC beta:', args.BETA) 32 | print('rank:',args.RANK_ATOMS) 33 | print('num cluster:',args.NUM_CLUSTER) 34 | print('save_low_bound:',args.save_low_bound) 35 | print('weight_decay:',args.weight_decay) 36 | if args.DTD: 37 | print('dataset:','DTD') 38 | elif args.Aircraft: 39 | print('dataset:','Aircraft') 40 | elif args.CUB: 41 | print('dataset:','CUB') 42 | elif args.INDOOR: 43 | print('dataset:','INDOOR') 44 | elif args.MINC2500: 45 | print('dataset:','MINC2500') 46 | epoch_start = time.time() 47 | 48 | features_lr = lr * 0.1 49 | if features_lr <= 0.0001: 50 | features_lr = 0.0001 51 | optimizer = optim.SGD( 52 | [ 53 | 
{'params': model.features.parameters(), 'lr':features_lr}, 54 | {'params': model.Linear_dataproj_k.parameters(), 'lr': lr}, 55 | {'params': model.Linear_dataproj2_k.parameters(), 'lr': lr}, 56 | {'params': model.Linear_predict.parameters(),'lr':lr}, 57 | ], 58 | lr=lr, momentum=0.9, weight_decay=args.weight_decay) 59 | 60 | model.train() 61 | start = time.time() 62 | running_loss = 0.0 63 | 64 | train_bs = args.train_bs 65 | train_len = len(trainset) 66 | for batch_idx, (data, target) in enumerate(trainloader): 67 | if (batch_idx+1) * train_bs > train_len: 68 | break 69 | data = Variable(data) 70 | target = Variable(target) 71 | data, target = data.cuda(), target.cuda() 72 | optimizer.zero_grad() 73 | output = model(data) 74 | loss = criterion(output, target) 75 | 76 | loss.backward() 77 | running_loss += loss.data.item() 78 | optimizer.step() 79 | if batch_idx % (args.train_print_freq/args.train_bs) == 0 and batch_idx != 0: 80 | loss_tmp = running_loss / (args.train_print_freq/args.train_bs) #div the n of batch 81 | interval = time.time() - start 82 | start = time.time() 83 | print('Epoch:{}[{}/{} ]\tLoss:{:.6f}\tLR:{}\tbeta:{}\ttime:{:.2f}'.format( 84 | epoch, batch_idx * len(data), train_len, loss_tmp, lr, model.sc.beta, interval/60)) 85 | running_loss = 0.0 86 | epoch_end = time.time() 87 | tmp = (epoch_end - epoch_start) / 60 88 | print('train time:{:.4f} min'.format(tmp)) 89 | 90 | def test(): 91 | model.eval() 92 | test_loss = 0 93 | correct = 0 94 | start = time.time() 95 | test_bs = args.test_bs 96 | test_len = len(testset) 97 | for batch_idx, (data, target) in enumerate(testloader): 98 | if (batch_idx+1) * test_bs > test_len: 99 | break 100 | data = Variable(data) 101 | target = Variable(target) 102 | data, target = data.cuda(), target.cuda() 103 | output = model(data) 104 | test_loss += criterion(output, target).data.item() 105 | pred = output.data.max(1, keepdim=True)[1] 106 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 107 | test_loss = 
test_loss / (test_len / args.test_bs) 108 | loss_list.append(round(test_loss, 4)) 109 | acc = 100.0 * float(correct) / test_len 110 | acc = round(acc, 4) 111 | interval = time.time() - start 112 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\ttime:{:.2f}\n'.format( 113 | test_loss , correct, test_len, acc, interval/60)) 114 | 115 | model_name = './tmp/' + args.model_name_pre + str(acc) + 'lr_' + str(lr) + '.pth' 116 | acc_max = max(acc_list) 117 | if acc > acc_max and acc > args.save_low_bound: 118 | torch.save(model.state_dict(), model_name) 119 | print('i have saved the model') 120 | 121 | acc_list.append(acc) 122 | acc_max = max(acc_list) 123 | print('max acc:', acc_max) 124 | print('acc list:', acc_list) 125 | print('loss list:', loss_list) 126 | 127 | def parse_args(*args): 128 | parser = argparse.ArgumentParser() 129 | 130 | parser.add_argument('-DTD','--DTD', default=opt.DTD) 131 | parser.add_argument('-CUB','--CUB', default=opt.CUB) 132 | parser.add_argument('-INDOOR','--INDOOR', default=opt.INDOOR) 133 | parser.add_argument('-MINC2500','--MINC2500', default=opt.MINC2500) 134 | parser.add_argument('-data_path','--data_path', default=opt.data_path) 135 | parser.add_argument('-train_txt_path','--train_txt_path', default=opt.train_txt_path) 136 | parser.add_argument('-test_txt_path','--test_txt_path', default=opt.test_txt_path) 137 | parser.add_argument('-class_num','--class_num', default=opt.class_num) 138 | parser.add_argument('-res_plus','--res_plus', type=int, default=opt.res_plus) 139 | parser.add_argument('-res','--res', type=int, default=opt.res) 140 | parser.add_argument('-lr','--lr', type=float, default=0.01) 141 | parser.add_argument('-lr_scale','--lr_scale', type=float, default=opt.lr_scale) 142 | parser.add_argument('-train_bs','--train_bs', type=int, default=opt.train_bs) 143 | 144 | parser.add_argument('-device','--gpu_device', default=opt.gpu_device) 145 | parser.add_argument('-rank','--RANK_ATOMS', type=int, 
default=opt.RANK_ATOMS) 146 | parser.add_argument('-k','--NUM_CLUSTER', type=int, default=opt.NUM_CLUSTER) 147 | parser.add_argument('-beta','--BETA', type=float, default=opt.BETA) 148 | parser.add_argument('-model_name_pre','--model_name_pre', default=opt.model_name_pre) 149 | parser.add_argument('-model_path','--model_path', default=opt.model_path) 150 | parser.add_argument('-save_low_bound','--save_low_bound', type=float, default=opt.save_low_bound) 151 | parser.add_argument('-weight_decay','--weight_decay', type=float, default=5e-4) 152 | parser.add_argument('-train_print_freq','--train_print_freq', type=int, default=opt.train_print_freq) 153 | parser.add_argument('-test_bs','--test_bs', type=int, default=opt.test_bs) 154 | parser.add_argument('-test_epoch','--test_epoch', type=int, default=opt.test_epoch) 155 | parser.add_argument('-pretrained','--pretrained', default=opt.pretrained) 156 | parser.add_argument('-pre_model_path','--pre_path', default=opt.pre_path) 157 | parser.add_argument('-model_name','--model_name', default=opt.model_name) 158 | parser.add_argument('-use_gpu','--use_gpu', default=opt.use_gpu) 159 | parser.add_argument('-max_epoches','--max_epoches', type=int, default=opt.max_epoches) 160 | 161 | args = parser.parse_args() 162 | return args 163 | 164 | 165 | def main(argv): 166 | global args 167 | global model 168 | args = parse_args(argv) 169 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 170 | if args.model_name == 'FBC': 171 | print('model:','FBC') 172 | model = FBC() 173 | else: 174 | print('model error') 175 | model.cuda() 176 | 177 | if args.model_path: 178 | print('i am load model', args.model_path) 179 | pre_model = torch.load(args.model_path) 180 | model_dict = model.state_dict() 181 | pre_dict = {k:v for k, v in pre_model.items() if k in model_dict} 182 | print('pre dict len:',len(pre_dict)) 183 | model_dict.update(pre_dict) 184 | model.load_state_dict(model_dict) 185 | elif args.pretrained: 186 | print('l am loading pre 
model', args.pre_path) 187 | pre_model = torch.load(args.pre_path) 188 | model_dict = model.state_dict() 189 | pre_dict = {k:v for k, v in pre_model.items() if k in model_dict} 190 | print('pre dict len:',len(pre_dict)) 191 | model_dict.update(pre_dict) 192 | model.load_state_dict(model_dict) 193 | else: 194 | for m in model.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 197 | m.weight.data.normal_(0, math.sqrt(2. / n)) 198 | elif isinstance(m, nn.BatchNorm2d): 199 | m.weight.data.fill_(1) 200 | m.bias.data.zero_() 201 | 202 | if True: 203 | train_txt_path = args.train_txt_path 204 | global trainset 205 | trainset = MyDataset(train_txt_path, transform=transforms.Compose([ 206 | #transforms.Scale((args.res_plus,args.res_plus)), 207 | transforms.Scale(args.res_plus), 208 | transforms.RandomHorizontalFlip(), 209 | transforms.RandomCrop(args.res), 210 | transforms.ToTensor(), 211 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 212 | ])) 213 | train_len = len(trainset) 214 | print('train_len:',train_len) 215 | train_bs = args.train_bs 216 | global trainloader 217 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True) 218 | test_txt_path = args.test_txt_path 219 | global testset 220 | testset = MyDataset(test_txt_path, transform=transforms.Compose([ 221 | #transforms.Scale((args.res_plus,args.res_plus)), 222 | transforms.Scale(args.res_plus), 223 | transforms.CenterCrop(args.res), 224 | transforms.ToTensor(), 225 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 226 | ])) 227 | test_len = len(testset) 228 | print('test_len:',test_len) 229 | test_bs = args.test_bs 230 | global testloader 231 | testloader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False) 232 | 233 | lr = args.lr 234 | mode = True #1 : train 0: test 235 | if mode: 236 | for epoch in range(1, args.max_epoches): 237 | if epoch in 
opt.lr_freq_list: 238 | lr = lr * args.lr_scale 239 | lr = max(lr, 0.0001) 240 | train(epoch, lr) 241 | if epoch % args.test_epoch == 0: 242 | test() 243 | else: 244 | test() 245 | 246 | if __name__ == '__main__': 247 | main(sys.argv[1:]) 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /last.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import torch 3 | import torchvision 4 | import torch.optim as optim 5 | import torchvision.transforms as transforms 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import data 9 | from collections import OrderedDict 10 | import os 11 | import torch.backends.cudnn as cudnn 12 | import math 13 | from config import opt 14 | import time 15 | from data.dataset import MyDataset 16 | from models import FBC 17 | import numpy as np 18 | import scipy.io as sio 19 | from PIL import Image 20 | import random 21 | import argparse 22 | import sys 23 | 24 | acc_list = [0.0] 25 | loss_list = [0.0] 26 | criterion = nn.CrossEntropyLoss() 27 | 28 | def train(epoch, lr): 29 | print('model_name_pre:',args.model_name_pre) 30 | print('bs',args.train_bs) 31 | print('SC beta:', args.BETA) 32 | print('rank:',args.RANK_ATOMS) 33 | print('num cluster:',args.NUM_CLUSTER) 34 | print('save_low_bound:',args.save_low_bound) 35 | print('weight_decay:',args.weight_decay) 36 | if args.DTD: 37 | print('dataset:','DTD') 38 | elif args.Aircraft: 39 | print('dataset:','Aircraft') 40 | elif args.CUB: 41 | print('dataset:','CUB') 42 | elif args.INDOOR: 43 | print('dataset:','INDOOR') 44 | elif args.MINC2500: 45 | print('dataset:','MINC2500') 46 | epoch_start = time.time() 47 | 48 | features_lr = lr * 0.1 49 | if features_lr <= 0.0001: 50 | features_lr = 0.0001 51 | optimizer = optim.SGD( 52 | [ 53 | #{'params': model.features.parameters(), 'lr':features_lr}, 54 | {'params': model.Linear_dataproj_k.parameters(), 'lr': lr}, 55 | 
{'params': model.Linear_dataproj2_k.parameters(), 'lr': lr}, 56 | {'params': model.Linear_predict.parameters(),'lr':lr}, 57 | ], 58 | lr=lr, momentum=0.9, weight_decay=args.weight_decay) 59 | 60 | model.train() 61 | start = time.time() 62 | running_loss = 0.0 63 | 64 | train_bs = args.train_bs 65 | train_len = len(trainset) 66 | for batch_idx, (data, target) in enumerate(trainloader): 67 | if (batch_idx+1) * train_bs > train_len: 68 | break 69 | data = Variable(data) 70 | target = Variable(target) 71 | data, target = data.cuda(), target.cuda() 72 | optimizer.zero_grad() 73 | output = model(data) 74 | loss = criterion(output, target) 75 | 76 | loss.backward() 77 | running_loss += loss.data.item() 78 | optimizer.step() 79 | if batch_idx % (args.train_print_freq/args.train_bs) == 0 and batch_idx != 0: 80 | loss_tmp = running_loss / (args.train_print_freq/args.train_bs) #div the n of batch 81 | interval = time.time() - start 82 | start = time.time() 83 | print('Epoch:{}[{}/{} ]\tLoss:{:.6f}\tLR:{}\tbeta:{}\ttime:{:.2f}'.format( 84 | epoch, batch_idx * len(data), train_len, loss_tmp, lr, model.sc.beta, interval/60)) 85 | running_loss = 0.0 86 | epoch_end = time.time() 87 | tmp = (epoch_end - epoch_start) / 60 88 | print('train time:{:.4f} min'.format(tmp)) 89 | 90 | def test(): 91 | model.eval() 92 | test_loss = 0 93 | correct = 0 94 | start = time.time() 95 | test_bs = args.test_bs 96 | test_len = len(testset) 97 | for batch_idx, (data, target) in enumerate(testloader): 98 | if (batch_idx+1) * test_bs > test_len: 99 | break 100 | data = Variable(data) 101 | target = Variable(target) 102 | data, target = data.cuda(), target.cuda() 103 | output = model(data) 104 | test_loss += criterion(output, target).data.item() 105 | pred = output.data.max(1, keepdim=True)[1] 106 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 107 | test_loss = test_loss / (test_len / args.test_bs) 108 | loss_list.append(round(test_loss, 4)) 109 | acc = 100.0 * float(correct) / test_len 
110 | acc = round(acc, 4) 111 | interval = time.time() - start 112 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\ttime:{:.2f}\n'.format( 113 | test_loss , correct, test_len, acc, interval/60)) 114 | 115 | model_name = './tmp/' + args.model_name_pre + str(acc) + 'lr_' + str(lr) + '.pth' 116 | acc_max = max(acc_list) 117 | if acc > acc_max and acc > args.save_low_bound: 118 | torch.save(model.state_dict(), model_name) 119 | print('i have saved the model') 120 | 121 | acc_list.append(acc) 122 | acc_max = max(acc_list) 123 | print('max acc:', acc_max) 124 | print('acc list:', acc_list) 125 | print('loss list:', loss_list) 126 | 127 | def parse_args(*args): 128 | parser = argparse.ArgumentParser() 129 | 130 | parser.add_argument('-DTD','--DTD', default=opt.DTD) 131 | parser.add_argument('-CUB','--CUB', default=opt.CUB) 132 | parser.add_argument('-INDOOR','--INDOOR', default=opt.INDOOR) 133 | parser.add_argument('-MINC2500','--MINC2500', default=opt.MINC2500) 134 | parser.add_argument('-data_path','--data_path', default=opt.data_path) 135 | parser.add_argument('-train_txt_path','--train_txt_path', default=opt.train_txt_path) 136 | parser.add_argument('-test_txt_path','--test_txt_path', default=opt.test_txt_path) 137 | parser.add_argument('-class_num','--class_num', default=opt.class_num) 138 | parser.add_argument('-res_plus','--res_plus', type=int, default=opt.res_plus) 139 | parser.add_argument('-res','--res', type=int, default=opt.res) 140 | parser.add_argument('-lr','--lr', type=float, default=1.0) 141 | parser.add_argument('-lr_scale','--lr_scale', type=float, default=opt.lr_scale) 142 | parser.add_argument('-train_bs','--train_bs', type=int, default=opt.train_bs) 143 | 144 | parser.add_argument('-device','--gpu_device', default=opt.gpu_device) 145 | parser.add_argument('-rank','--RANK_ATOMS', type=int, default=opt.RANK_ATOMS) 146 | parser.add_argument('-k','--NUM_CLUSTER', type=int, default=opt.NUM_CLUSTER) 147 | 
parser.add_argument('-beta','--BETA', type=float, default=opt.BETA) 148 | parser.add_argument('-model_name_pre','--model_name_pre', default=opt.model_name_pre) 149 | parser.add_argument('-model_path','--model_path', default=opt.model_path) 150 | parser.add_argument('-save_low_bound','--save_low_bound', type=float, default=opt.save_low_bound) 151 | parser.add_argument('-weight_decay','--weight_decay', type=float, default=5e-6) 152 | parser.add_argument('-train_print_freq','--train_print_freq', type=int, default=opt.train_print_freq) 153 | parser.add_argument('-test_bs','--test_bs', type=int, default=opt.test_bs) 154 | parser.add_argument('-test_epoch','--test_epoch', type=int, default=opt.test_epoch) 155 | parser.add_argument('-pretrained','--pretrained', default=opt.pretrained) 156 | parser.add_argument('-pre_model_path','--pre_path', default=opt.pre_path) 157 | parser.add_argument('-model_name','--model_name', default=opt.model_name) 158 | parser.add_argument('-use_gpu','--use_gpu', default=opt.use_gpu) 159 | parser.add_argument('-max_epoches','--max_epoches', type=int, default=opt.max_epoches) 160 | 161 | args = parser.parse_args() 162 | return args 163 | 164 | def main(argv): 165 | global args 166 | global model 167 | args = parse_args(argv) 168 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 169 | 170 | if args.model_name == 'FBC': 171 | print('model:','FBC') 172 | model = FBC() 173 | else: 174 | print('model error') 175 | model.cuda() 176 | 177 | if args.model_path: 178 | print('i am load model', args.model_path) 179 | pre_model = torch.load(args.model_path) 180 | model_dict = model.state_dict() 181 | pre_dict = {k:v for k, v in pre_model.items() if k in model_dict} 182 | print('pre dict len:',len(pre_dict)) 183 | model_dict.update(pre_dict) 184 | model.load_state_dict(model_dict) 185 | elif args.pretrained: 186 | print('l am loading pre model', args.pre_path) 187 | pre_model = torch.load(args.pre_path) 188 | model_dict = model.state_dict() 189 | 
pre_dict = {k:v for k, v in pre_model.items() if k in model_dict} 190 | print('pre dict len:',len(pre_dict)) 191 | model_dict.update(pre_dict) 192 | model.load_state_dict(model_dict) 193 | else: 194 | for m in model.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 197 | m.weight.data.normal_(0, math.sqrt(2. / n)) 198 | elif isinstance(m, nn.BatchNorm2d): 199 | m.weight.data.fill_(1) 200 | m.bias.data.zero_() 201 | 202 | if True: 203 | train_txt_path = args.train_txt_path 204 | global trainset 205 | trainset = MyDataset(train_txt_path, transform=transforms.Compose([ 206 | #transforms.Scale((args.res_plus,args.res_plus)), 207 | transforms.Scale(args.res_plus), 208 | transforms.RandomHorizontalFlip(), 209 | transforms.RandomCrop(args.res), 210 | transforms.ToTensor(), 211 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 212 | ])) 213 | train_len = len(trainset) 214 | print('train_len:',train_len) 215 | train_bs = args.train_bs 216 | global trainloader 217 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True) 218 | test_txt_path = args.test_txt_path 219 | global testset 220 | testset = MyDataset(test_txt_path, transform=transforms.Compose([ 221 | #transforms.Scale((args.res_plus,args.res_plus)), 222 | transforms.Scale(args.res_plus), 223 | transforms.CenterCrop(args.res), 224 | transforms.ToTensor(), 225 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 226 | ])) 227 | test_len = len(testset) 228 | print('test_len:',test_len) 229 | test_bs = args.test_bs 230 | global testloader 231 | testloader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False) 232 | 233 | lr = args.lr 234 | mode = True #1 : train 0: test 235 | if mode: 236 | for epoch in range(1, args.max_epoches): 237 | if epoch in opt.lr_freq_list: 238 | lr = lr * args.lr_scale 239 | lr = max(lr, 0.0001) 240 | train(epoch, lr) 241 | if epoch % 
args.test_epoch == 0: 242 | test() 243 | else: 244 | test() 245 | 246 | 247 | if __name__ == '__main__': 248 | main(sys.argv[1:]) 249 | 250 | 251 | 252 | 253 | -------------------------------------------------------------------------------- /models/BasicModule.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import time 4 | import os 5 | import sys, os 6 | sys.path.append(os.path.abspath(os.path.dirname(__file__)+os.path.sep+"..")) 7 | from config import opt 8 | 9 | class BasicModule(torch.nn.Module): 10 | ''' 11 | 封装了nn.Module,主要是提供了save和load两个方法 12 | ''' 13 | 14 | def __init__(self): 15 | super(BasicModule,self).__init__() 16 | #self.model_name=str(type(self))# 默认名字 17 | self.model_name=str(type(self)).strip('<>').split('.')[-1][:-1] 18 | # print(self.model_name) 19 | # print(os.getcwd()[:-6]) 20 | 21 | def load(self, path): 22 | ''' 23 | 可加载指定路径的模型 24 | ''' 25 | #print('i am loading the ' + path) 26 | self.load_state_dict(torch.load(path)) 27 | 28 | def save(self, acc=None, lr=None): 29 | ''' 30 | 保存模型,默认使用“模型名字+时间”作为文件名 31 | ''' 32 | prefix = opt.model_save_path + self.model_name + '_' 33 | #print(prefix) 34 | #print(os.getcwd()) 35 | name = prefix + time.strftime('%m%d_%H%M_') + '_acc_' \ 36 | + str(acc).replace('.','_') + 'lr' + lr + '.pth' 37 | #print(name) 38 | torch.save(self.state_dict(), name) 39 | return name 40 | 41 | 42 | 43 | if __name__ == '__main__': 44 | a = BasicModule() 45 | 46 | name = a.save() 47 | print(name) -------------------------------------------------------------------------------- /models/FBC.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from torch import nn 3 | import torch 4 | #from .BasicModule import BasicModule 5 | from .BasicModule import BasicModule 6 | import torch.nn.functional as F 7 | import math 8 | import sys, os 9 | 
sys.path.append(os.path.abspath(os.path.dirname(__file__)+os.path.sep+"..")) 10 | from config import opt 11 | import scipy.io as sio 12 | from torch.autograd import Variable 13 | from torch.nn import Parameter as Parameter 14 | from .SC import SC 15 | import torchvision 16 | 17 | class FBC(BasicModule): 18 | def __init__(self): 19 | super(FBC, self).__init__() 20 | self.features = torchvision.models.vgg16(pretrained=False).features 21 | self.features = torch.nn.Sequential(*list(self.features.children())[:-1]) # Remove pool5. 22 | 23 | self.device=torch.device("cuda") 24 | self.output_dim = self.JOINT_EMB_SIZE = opt.RANK_ATOMS * opt.NUM_CLUSTER #20*2048 25 | self.input_dim = opt.down_chennel 26 | 27 | self.Linear_dataproj_k = nn.Linear(opt.down_chennel, self.JOINT_EMB_SIZE) 28 | self.Linear_dataproj2_k = nn.Linear(opt.down_chennel, self.JOINT_EMB_SIZE) 29 | 30 | self.Linear_predict = nn.Linear(opt.NUM_CLUSTER, opt.class_num) 31 | 32 | self.sc = SC(beta=opt.BETA) 33 | if opt.res == 224: 34 | self.Avgpool = nn.AvgPool1d(kernel_size=196) 35 | elif opt.res == 448: 36 | self.Avgpool = nn.AvgPool1d(kernel_size=784) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | nn.init.xavier_normal_(m.weight.data,) 41 | m.bias.data.zero_() 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.Linear): 46 | nn.init.xavier_normal_(m.weight.data) 47 | m.bias.data.zero_() 48 | 49 | 50 | def forward(self, x): 51 | x = self.features(x) 52 | bs, c, w, h = x.shape[0:4] 53 | 54 | bswh = bs*w*h 55 | x = x.permute(0,2,3,1) 56 | x = x.contiguous().view(-1,c) 57 | 58 | x1 = self.Linear_dataproj_k(x) 59 | x2 = self.Linear_dataproj2_k(x) 60 | 61 | bi = x1.mul(x2) 62 | 63 | bi = bi.view(-1, 1, opt.NUM_CLUSTER, opt.RANK_ATOMS) 64 | bi = torch.squeeze(torch.sum(bi, 3)) 65 | 66 | bi = self.sc(bi) 67 | 68 | bi = bi.view(bs,h*w,-1) 69 | bi = bi.permute(0,2,1) 70 | bi = torch.squeeze(self.Avgpool(bi)) 71 | 72 | bi = 
torch.sqrt(F.relu(bi)) - torch.sqrt(F.relu(-bi)) 73 | bi = F.normalize(bi, p=2, dim=1) 74 | 75 | y = self.Linear_predict(bi) 76 | return y 77 | 78 | 79 | if __name__ == '__main__': 80 | print(1) 81 | 82 | -------------------------------------------------------------------------------- /models/SC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable as V 4 | import torch.nn.functional as F 5 | from config import opt 6 | 7 | class SC(nn.Module): 8 | def __init__(self,beta): 9 | super(SC, self).__init__() 10 | self.device=torch.device("cuda") 11 | self.beta=beta 12 | # if opt.learn_beta: 13 | # self.beta = nn.Parameter(torch.tensor(beta)) 14 | # else: 15 | # self.beta = beta 16 | #self.B=nn.Parameter(torch.randn(10,20))#c*c - > num_cluster 17 | 18 | 19 | 20 | 21 | def forward(self, input): 22 | 23 | zero = torch.zeros(input.shape).to(self.device) 24 | output = torch.mul(torch.sign(input),torch.max((torch.abs(input)-self.beta/2),zero)) 25 | 26 | 27 | return output 28 | 29 | if __name__ == '__main__': 30 | a = SC(1) 31 | 32 | input = torch.randn(2,3) 33 | print(input) 34 | out = a(input) 35 | print(out) 36 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .FBC import FBC 2 | -------------------------------------------------------------------------------- /models/biDDSCnips.py: -------------------------------------------------------------------------------- 1 | 2 | #coding:utf-8 3 | from torch import nn 4 | import torch 5 | #from .BasicModule import BasicModule 6 | from .BasicModule import BasicModule 7 | import torch.nn.functional as F 8 | import math 9 | import sys, os 10 | sys.path.append(os.path.abspath(os.path.dirname(__file__)+os.path.sep+"..")) 11 | from config import opt 12 | import utils.myUtil as utils 13 | 
class FBC(BasicModule):
    """Factorized Bilinear Coding (FBC) classifier on a VGG-16 backbone.

    Pipeline per image: VGG-16 conv features (pool5 removed) -> two linear
    projections whose element-wise product forms a rank-factorized bilinear
    term -> sum over rank atoms -> soft-threshold sparse coding (SC) ->
    spatial average pooling -> signed square-root + L2 normalization ->
    linear classifier over ``opt.class_num`` classes.
    """

    def __init__(self):
        super(FBC, self).__init__()
        # NOTE(review): the module header never imports torchvision, so the
        # original code raised NameError here; import locally to fix it
        # without touching the (collapsed) file-level import block.
        import torchvision
        self.features = torchvision.models.vgg16(pretrained=False).features
        # Remove pool5 so conv5_3 activations keep their spatial resolution.
        self.features = torch.nn.Sequential(*list(self.features.children())[:-1])

        self.device = torch.device("cuda")
        # Joint embedding = (#dictionary atoms) x (rank of each atom).
        self.output_dim = self.JOINT_EMB_SIZE = opt.RANK_ATOMS * opt.NUM_CLUSTER
        self.input_dim = opt.down_chennel

        # Two projections whose Hadamard product factorizes the bilinear map.
        self.Linear_dataproj_k = nn.Linear(opt.down_chennel, self.JOINT_EMB_SIZE)
        self.Linear_dataproj2_k = nn.Linear(opt.down_chennel, self.JOINT_EMB_SIZE)

        self.Linear_predict = nn.Linear(opt.NUM_CLUSTER, opt.class_num)

        self.sc = SC(beta=opt.BETA)
        # 224px input -> 14x14=196 spatial positions after conv5;
        # 448px input -> 28x28=784.
        if opt.res == 224:
            self.Avgpool = nn.AvgPool1d(kernel_size=196)
            print('avg 196')
        elif opt.res == 448:
            print('avg 448')
            self.Avgpool = nn.AvgPool1d(kernel_size=784)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                # VGG conv layers have biases, but guard anyway so the init
                # loop does not crash on bias-free layers.
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch; presumably (bs, 3, opt.res, opt.res) so the
                AvgPool kernel matches the conv feature map — TODO confirm.

        Returns:
            Class logits of shape (bs, opt.class_num).
        """
        x = self.features(x)                      # (bs, c, w, h)
        bs, c, w, h = x.shape[0:4]

        # Flatten spatial positions into the batch axis: one descriptor
        # per location.
        x = x.permute(0, 2, 3, 1)
        x = x.contiguous().view(-1, c)            # (bs*w*h, c)

        x1 = self.Linear_dataproj_k(x)
        x2 = self.Linear_dataproj2_k(x)
        bi = x1.mul(x2)                           # factorized bilinear term

        bi = bi.view(-1, 1, opt.NUM_CLUSTER, opt.RANK_ATOMS)
        # Sum over the rank atoms, then drop the explicit singleton dim.
        # squeeze(1) (not a bare squeeze) keeps the batch axis alive when
        # bs*w*h == 1 — the original bare squeeze collapsed it.
        bi = torch.sum(bi, 3).squeeze(1)          # (bs*w*h, NUM_CLUSTER)

        bi = self.sc(bi)                          # sparse coding (soft threshold)

        bi = bi.view(bs, h * w, -1)
        bi = bi.permute(0, 2, 1)                  # (bs, NUM_CLUSTER, h*w)
        # Average over spatial positions; squeeze(2) (not a bare squeeze)
        # preserves the batch dim for bs == 1.
        bi = self.Avgpool(bi).squeeze(2)          # (bs, NUM_CLUSTER)

        # Signed square-root then L2 normalization (standard bilinear
        # pooling post-processing).
        bi = torch.sqrt(F.relu(bi)) - torch.sqrt(F.relu(-bi))
        bi = F.normalize(bi, p=2, dim=1)

        y = self.Linear_predict(bi)
        return y


if __name__ == '__main__':
    print(1)